Sparse PLS

[2]:
import warnings
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)

import numpy as np
import xarray as xr
import pandas as pd

from scipy.stats import zscore, pearsonr

from sklearn.decomposition import PCA
from sklearn.utils import check_random_state
from sklearn.model_selection import KFold

from gemmr.generative_model import GEMMR, JointCovarianceModelPLS
from gemmr.generative_model.base import PLSgm, CCAgm
from gemmr.estimators import SVDCCA, SVDPLS
from gemmr.estimators.r_estimators import SparsePLS
from gemmr.sample_analysis.analyzers import analyze_model
from gemmr.sample_analysis import addon, postproc
from gemmr.util import _calc_true_loadings

import holoviews as hv
hv.extension('matplotlib')
from holoviews import opts
hv.renderer('matplotlib').param.set_param(dpi=120)

from my_config import *

from tqdm.notebook import tqdm, trange
[2]:
# Colour used for each estimator in the figures below.
# clr_pls / clr_cca are not defined in this notebook; presumably they come from
# the `my_config` star import.
clr_spls = 'mediumvioletred'
clrs = {
    'SVDPLS': clr_pls,
    'SVDCCA': clr_cca,
    'SparsePLS': clr_spls,
}
[3]:
# Witten et al. (2009): Fig. 4 and Appendix 3
# Two latent factors w1, w2 (orthonormal columns from a QR decomposition, one
# entry per sample) drive X through the sparse X-side patterns u1, u2 (first 40
# of 100 features) and Y through the sparse Y-side patterns v1, v2 (last 40 of
# 100 features); both blocks get i.i.d. Gaussian noise with sd .3.
n = 50
random_seed = 0

rng = np.random.default_rng(random_seed)

ex = rng.normal(scale=.3, size=(n, 100))
ey = rng.normal(scale=.3, size=(n, 100))
W = np.linalg.qr(rng.normal(size=(n, n)))[0]
w1, w2 = W[:, [0]], W[:, [1]]
u1 = np.r_[[1]*20, [-1]*20, [0]*60].reshape(-1, 1)
u2 = np.r_[[-1]*10, [1]*10, [-1]*10, [1]*10, [0]*60].reshape(-1, 1)
v1 = np.r_[[0]*60, [-1]*20, [1]*20].reshape(-1, 1)
v2 = np.r_[[0]*60, [1]*10, [-1]*10, [1]*10, [-1]*10].reshape(-1, 1)

X = w1 @ u1.T + w2 @ u2.T + ex
# Y must use its own second pattern v2 (not X's u2), otherwise v2 is unused and
# Y's second component would share X's sparsity pattern.
Y = w1 @ v1.T + w2 @ v2.T + ey

# Sample cross-covariance between X and Y features.
Sxy = X.T @ Y / n

# Leading left singular vector of the sample cross-covariance (X-side weights).
hv.Curve(np.linalg.svd(Sxy)[0][:, 0]).opts(linestyle='', marker='.')
[3]:
[4]:
# NOTE(review): `penalties` is not used anywhere in this notebook (the only other
# reference is inside the commented-out cell, which redefines it).
penalties = np.arange(.1, 1.01, .1)

# Fit sparse PLS on the simulated (X, Y) over a small grid of X/Y penalties.
# Presumably 'product' pairs every X penalty with every Y penalty, and cv=5
# selects the pair by 5-fold cross-validation -- confirm against gemmr's SparsePLS.
scca = SparsePLS(
    penaltyxs=[.2, .3, .4],
    penaltyys=[.2, .3, .4],
    penalty_pairing='product',
    niter=100,
    cv=5,
).fit(X, Y)
print(scca.corrs_)

# First column of the estimated X rotations.
hv.Curve(scca.x_rotations_[:, 0]).opts(linestyle='', marker='.')
[0.91153997]
[4]:
[5]:
# Build a joint (X, Y) covariance from the sparse patterns u1, u2, v1, v2 and the
# sample size n defined in the Witten et al. cell, and wrap it in a PLS
# generative model.
rng = np.random.default_rng(seed=0)

# Within-block covariances: two rank-1 components (first one 4x stronger) plus .9 * I.
Sxx = 4*u1@u1.T/n + u2@u2.T/n + .9*np.eye(len(u1))
Syy = 4*v1@v1.T/n + v2@v2.T/n + .9*np.eye(len(v1))

# Cross-covariance (px x py): the same two components plus small random
# perturbations of the first one. The perturbation row must have py entries and
# the column px entries so that both terms are px x py even when px != py.
# The two draws are evaluated left to right (row first, then column).
Sxy = (
    4*u1@v1.T/n
    + u2@v2.T/n
    + u1 @ rng.normal(scale=.003, size=(1, len(v1)))
    + rng.normal(scale=.003, size=(len(u1), 1)) @ v1.T
)

S = np.vstack([
    np.hstack([Sxx, Sxy]),
    np.hstack([Sxy.T, Syy]),
])

# Second argument is presumably the number of X features (len(u1)) -- confirm
# against gemmr's JointCovarianceModelPLS.
jcov = JointCovarianceModelPLS(S, len(u1), m=1)
jcov.true_corrs_
[5]:
array([0.78143672])
[6]:
px = 64

# Candidate ground-truth weight patterns over px features. Each pattern is
# rescaled to unit Euclidean norm right after construction.
weights = {
    'uniform': np.ones(px),
    'random': np.random.default_rng(0).uniform(-1, 1, size=px),
    'decay': np.arange(1, px + 1) ** -1.,
    'rise': (np.arange(1, px + 1) ** -1.)[::-1],
    'step_down': np.r_[np.ones(px // 8), np.zeros(7 * px // 8)],
    'step_up': np.r_[np.zeros(7 * px // 8), np.ones(1 * px // 8)],
    # 'spikes' and 'decaying_spikes' draw from a fresh seed-0 generator with the
    # same call, so they share the same spike locations.
    'spikes': np.random.default_rng(0).binomial(n=1, p=1. / 8, size=px),
    'decaying_spikes': (
        np.random.default_rng(0).binomial(n=1, p=1. / 8, size=px)
        * (np.arange(1, px + 1) ** -1.)
    ),
}
weights = {
    label: vec / np.linalg.norm(vec, keepdims=True)
    for label, vec in weights.items()
}

# One panel per pattern, labelled with its name.
fig = hv.Layout()
for label, vec in weights.items():
    fig += hv.Curve(vec).relabel(label)

fig
[6]:
[7]:
# Ask the embedded R session (via rpy2) to allow a larger vector heap by setting
# R_MAX_VSIZE to 128000000000 (presumably bytes, i.e. ~128 GB). This is presumably
# needed by the R-backed SparsePLS estimator (gemmr.estimators.r_estimators).
# NOTE(review): R normally reads R_MAX_VSIZE at startup, and the embedded R has
# probably already been started by the r_estimators import in the first cell, so
# this Sys.setenv call may not change the running session's limit -- confirm.
import rpy2.robjects
rpy2.robjects.r("Sys.setenv('R_MAX_VSIZE'=128000000000)")
[7]:
BoolVector with 1 elements.
1
[8]:
# ## Takes a while to run

# ax, ay = -1., -1.

# penalties = [.1, .3, .5]

# ress = []
# for w_lbl in weights:  # ['decay', 'step_down']:
#     w = weights[w_lbl]
#     plsgm = PLSgm(w.reshape(-1, 1), w.reshape(-1, 1), ax=ax, ay=ay, r_between=0.3)
#     res = analyze_model(
#         plsgm,
#         [
#             SVDPLS(),
#             SparsePLS(penaltyxs=penalties, penaltyys=penalties, penalty_pairing='zip')
#         ],
#         n_per_ftrs=(1, 16, 256, ),
#         check_convergence=False,
#         n_rep=25,
#         n_test=1000,
#         addons=[
#             addon.test_scores, addon.weights_true_cossim, addon.loadings_true_pearson, addon.test_scores_true_pearson, addon.test_scores_true_spearman, addon.sparseCCA_penalties, #addon.cv
#             ],
#         postprocessors=[postproc.weights_pairwise_cossim_stats, postproc.weights_pairwise_jaccard_stats],
#         #true_loadings=_calc_true_loadings(plsgm.Sigma_, plsgm.px, plsgm.x_rotations_[:, :1], plsgm.y_rotations_[:, :1]),
#         cvs=[('kfold5', KFold(5))],
#         scorers=addon.mk_scorers_for_cv(),
#         mk_test_statistics=addon.mk_test_statistics_scores,
#         random_state=0
#     )
#     res['weight'] = w_lbl
#     ress.append(res)
# res_pls = xr.concat(ress, 'weight')

# res_pls.to_netcdf('pls_vs_spls.nc')
[9]:
# Load the precomputed PLS vs. sparse-PLS results (written by the commented-out
# analysis cell) fully into memory and close the file immediately, so the NetCDF
# file is not held open and can be overwritten when the analysis is re-run.
res_pls = xr.load_dataset('pls_vs_spls.nc')
[10]:
def avg_weights(xw):
    """Rescale weights so that their mean across repetitions has unit norm.

    Despite the name, the weights themselves are NOT averaged: every repetition
    is kept and divided by the Euclidean norm of ``xw.mean('rep')``.

    Parameters
    ----------
    xw : xarray.DataArray
        2-D weights with a ``'rep'`` dimension.

    Returns
    -------
    Same type as ``xw / float``, with the same shape as ``xw``.

    Raises
    ------
    AssertionError
        If ``xw`` is not 2-dimensional.

    NOTE(review): not called anywhere in this notebook.
    """
    assert xw.ndim == 2
    mean_norm = np.linalg.norm(xw.mean('rep'))
    return xw / mean_norm


def _align_sign(xw, xw_true):
    """Return ``(xw, xw.mean('rep'))``, both negated if the mean points away from ``xw_true``.

    Returns new objects (``-xw``) rather than modifying ``xw`` in place, so a
    DataArray that is a view into a larger dataset is never mutated.
    """
    xw_mean = xw.mean('rep')
    if xw_mean.values @ xw_true < 0:
        return -xw, -xw_mean
    return xw, xw_mean


def plot_comparison(xw_pls, xw_spls, xw_true, show_xlabel=True, qs=(.025, .975)):
    """Overlay PLS and sparse-PLS weight estimates on the ground-truth weights.

    For each estimator, a shaded band spans the ``qs`` quantiles across
    repetitions (dimension ``'rep'``), and a line shows the mean across
    repetitions. The true weights are drawn as a dotted black line.

    Parameters
    ----------
    xw_pls, xw_spls : xarray.DataArray
        Weight estimates with a ``'rep'`` dimension and an ``x_feature`` coordinate.
        They are not modified.
    xw_true : numpy.ndarray
        Ground-truth weight vector. It is plotted as-is and used as the sign reference.
    show_xlabel : bool
        If False, the x-axis label is blanked.
    qs : tuple of float
        Lower and upper quantile of the shaded band.

    Returns
    -------
    holoviews.Overlay
    """
    xlabel = None if show_xlabel else ''

    # Weights are only identified up to sign: orient each estimator so that its
    # mean has a non-negative dot product with the truth. The band and the mean
    # curve are flipped together so that they stay consistent.
    xw_pls, xw_pls_mean = _align_sign(xw_pls, xw_true)
    xw_spls, xw_spls_mean = _align_sign(xw_spls, xw_true)

    return (
        hv.Area((xw_pls.x_feature, xw_pls.quantile(qs[0], 'rep'), xw_pls.quantile(qs[1], 'rep')), vdims=['y', 'y2']).opts(color=clr_pls)
        * hv.Area((xw_spls.x_feature, xw_spls.quantile(qs[0], 'rep'), xw_spls.quantile(qs[1], 'rep')), vdims=['y', 'y2']).opts(color=clr_spls)
        * hv.Curve(xw_pls_mean, label='PLS').opts(color=clr_pls, linewidth=2.5)
        * hv.Curve(xw_spls_mean, label='SPLS').opts(color=clr_spls, linestyle='--', linewidth=2.5)
        * hv.Curve(xw_true, label='Ground truth').opts(color='black', linewidth=2, linestyle=':')
    ).redim(
        x='PC',
        y='Weight'
    ).opts(
        opts.Area(linewidth=1, alpha=.2),
        opts.Overlay(padding=.02, xlabel=xlabel, ylim=(-1, 1), sublabel_position=(-.45, .95))
    )


def plot_stability(res, show_xlabel=True):
    """Plot weight stability against the number of samples per feature.

    For each estimator in ``res.estr``, a shaded band spans the ``'q2.5%'`` to
    ``'q97.5%'`` statistics, and a line shows the ``'mean'`` statistic of
    ``res.x_weights_pairwise_cossim_stats`` (mode 0).

    Parameters
    ----------
    res : xarray.Dataset
        Must contain ``x_weights_pairwise_cossim_stats`` with ``mode``, ``estr``,
        ``stat`` and ``n_per_ftr`` dimensions/coordinates. Every ``estr`` must be
        ``'SVDPLS'`` or ``'SparsePLS'`` (the line-style table below).
    show_xlabel : bool
        If False, the x-axis label is blanked.

    Returns
    -------
    holoviews.Overlay
    """
    xlabel = None if show_xlabel else ''
    linestyles = {'SVDPLS': '-', 'SparsePLS': '--'}
    stats = res.x_weights_pairwise_cossim_stats

    panel = hv.Overlay()
    for estr in res.estr.values:
        band = hv.Area(
            (stats.n_per_ftr,
             stats.sel(mode=0, estr=estr, stat='q2.5%'),
             stats.sel(mode=0, estr=estr, stat='q97.5%')),
            vdims=['y', 'y2'],
        ).opts(color=clrs[estr])
        mean_line = hv.Curve(
            stats.sel(mode=0, stat='mean', estr=estr)
        ).opts(color=clrs[estr], linestyle=linestyles[estr])
        panel *= band * mean_line

    relabelled = panel.redim(
        n_per_ftr='Samples / feature',
        x_weights_pairwise_cossim_stats='Weight stability',
    )
    return relabelled.opts(
        opts.Area(linewidth=0, alpha=.3),
        opts.Overlay(xlim=(.1, None), ylim=(0, 1), logx=True, xlabel=xlabel, sublabel_position=(-.75, .95)),
    )


def plot_weight_type(res, show_xlabel=True, show_titel=True, show_legend=True, n_per_ftrs=None):
    """Plot one figure row for a single ground-truth weight pattern.

    The row has one ``plot_comparison`` panel per sample size in ``n_per_ftrs``,
    followed by the ``plot_stability`` panel.

    Parameters
    ----------
    res : xarray.Dataset
        Results for one weight pattern. Must contain ``x_weights``,
        ``x_weights_true``, ``x_weights_true_cossim`` and
        ``x_weights_pairwise_cossim_stats``, indexed by ``mode``, ``estr`` and
        ``n_per_ftr``.
    show_xlabel : bool
        Show x-axis labels (used for the bottom row only).
    show_titel : bool
        If True, title each comparison panel with its sample size. If False,
        the titles are blank.
    show_legend : bool
        Show the legend on the first comparison panel. The other panels never
        show a legend.
    n_per_ftrs : sequence of float, optional
        Sample sizes (samples per feature) to plot. By default the first, second
        and last of those for which ``x_weights_true_cossim`` has no NaNs are
        used, or all of them if there are at most three.

    Returns
    -------
    holoviews.Layout
        Laid out in 4 columns. ``fig_opts`` and ``legend_frame_off`` are not
        defined in this notebook; presumably they come from ``my_config``.
    """
    if n_per_ftrs is None:
        available = res.dropna('n_per_ftr', how='any', subset=['x_weights_true_cossim']).n_per_ftr.values
        # With <= 3 candidates show all of them: picking [0], [1], [-1] would
        # duplicate a panel (2 candidates) or raise IndexError (0 or 1).
        if len(available) > 3:
            n_per_ftrs = [available[0], available[1], available[-1]]
        else:
            n_per_ftrs = list(available)

    res_mode0 = res.sel(mode=0)
    xw_true = res_mode0.x_weights_true.values

    fig = hv.Layout()
    for i, n_per_ftr in enumerate(n_per_ftrs):
        if show_titel:
            # Print integral sample sizes without a trailing ".0".
            label = int(n_per_ftr) if int(n_per_ftr) == n_per_ftr else n_per_ftr
            title = f'{label} samples / feature'
        else:
            title = ''

        # Only the left-most panel of the row carries the y-axis label and
        # (if requested) the legend. This is decided by position, so it also
        # works when the first selected size is not the first coordinate value.
        is_first = i == 0
        ylabel = None if is_first else ''

        panel = plot_comparison(
            res_mode0.x_weights.sel(estr='SVDPLS', n_per_ftr=n_per_ftr),
            res_mode0.x_weights.sel(estr='SparsePLS', n_per_ftr=n_per_ftr),
            xw_true,
            show_xlabel=show_xlabel
        ).opts(title=title)

        fig += panel.opts(ylabel=ylabel, show_legend=bool(show_legend and is_first))

    fig += plot_stability(res, show_xlabel=show_xlabel)
    fig.cols(
        4
    ).opts(*fig_opts).opts(
        opts.Overlay(hooks=[legend_frame_off]),
        opts.Layout(fig_inches=(7, 10))
    )
    return fig

def plot_all(res, n_per_ftrs=None):
    """Stack one ``plot_weight_type`` row per weight pattern in ``res.weight``.

    Only the top row gets panel titles and the legend, and only the bottom row
    gets x-axis labels. Rows are matched by weight label, not by position.

    Parameters
    ----------
    res : xarray.Dataset
        Results with a ``weight`` dimension (one entry per weight pattern).
    n_per_ftrs : sequence of float, optional
        Passed through to ``plot_weight_type``.

    Returns
    -------
    holoviews.Layout
    """
    weight_labels = res.weight.values

    rows = hv.Layout()
    for w in weight_labels:
        is_top_row = bool(w == weight_labels[0])
        is_bottom_row = bool(w == weight_labels[-1])
        rows += plot_weight_type(
            res.sel(weight=w),
            show_xlabel=is_bottom_row,
            show_titel=is_top_row,
            show_legend=is_top_row,
            n_per_ftrs=n_per_ftrs,
        )

    rows = rows.opts(*fig_opts)
    return rows.opts(opts.Layout(fig_inches=(7, None)))
[11]:
# Full supplementary figure: one row per weight pattern, using the default
# selection of sample sizes.
fig = plot_all(res_pls)
fig
[11]:
[12]:
# Save the supplementary figure. Create the output directory first, because
# writing to a non-existent directory fails on a fresh checkout.
from pathlib import Path

Path('fig').mkdir(exist_ok=True)
hv.save(fig, 'fig/figS_spls.pdf')