Sparse PLS
[2]:
import warnings
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)
import numpy as np
import xarray as xr
import pandas as pd
from scipy.stats import zscore, pearsonr
from sklearn.decomposition import PCA
from sklearn.utils import check_random_state
from sklearn.model_selection import KFold
from gemmr.generative_model import GEMMR, JointCovarianceModelPLS
from gemmr.generative_model.base import PLSgm, CCAgm
from gemmr.estimators import SVDCCA, SVDPLS
from gemmr.estimators.r_estimators import SparsePLS
from gemmr.sample_analysis.analyzers import analyze_model
from gemmr.sample_analysis import addon, postproc
from gemmr.util import _calc_true_loadings
import holoviews as hv
hv.extension('matplotlib')
from holoviews import opts
hv.renderer('matplotlib').param.set_param(dpi=120)
from my_config import *
from tqdm.notebook import tqdm, trange
[2]:
# Colour used for sparse-PLS results. The PLS and CCA colours (clr_pls,
# clr_cca) are not defined in this notebook; presumably they are provided by
# `from my_config import *`.
clr_spls = 'mediumvioletred'

# Estimator name -> colour used when plotting that estimator's results.
clrs = {
    'SVDPLS': clr_pls,
    'SVDCCA': clr_cca,
    'SparsePLS': clr_spls,
}
[3]:
# Witten et al. (2009): Fig. 4 and Appendix 3
# Toy data: two orthonormal sample-space components (w1, w2) are mixed into
# X through the sparse feature patterns u1, u2 and into Y through v1, v2,
# with independent Gaussian noise added to each matrix.
n = 50
random_seed = 0
rng = np.random.default_rng(random_seed)
ex = rng.normal(scale=.3, size=(n, 100))
ey = rng.normal(scale=.3, size=(n, 100))

# Q factor of a random matrix: its first two (orthonormal) columns are the
# latent components, kept as (n, 1) column vectors.
W = np.linalg.qr(rng.normal(size=(n, n)))[0]
w1, w2 = W[:, [0]], W[:, [1]]

# X patterns are non-zero on the first 40 features, Y patterns on the last 40.
u1 = np.r_[[1]*20, [-1]*20, [0]*60].reshape(-1, 1)
u2 = np.r_[[-1]*10, [1]*10, [-1]*10, [1]*10, [0]*60].reshape(-1, 1)
v1 = np.r_[[0]*60, [-1]*20, [1]*20].reshape(-1, 1)
v2 = np.r_[[0]*60, [1]*10, [-1]*10, [1]*10, [-1]*10].reshape(-1, 1)

X = w1 @ u1.T + w2 @ u2.T + ex
# Fix: the second component of Y uses Y's own pattern v2. It previously reused
# X's pattern u2, which left v2 unused.
# NOTE(review): outputs stored further down in this notebook were produced
# before this fix and will change on re-run.
Y = w1 @ v1.T + w2 @ v2.T + ey

# Sample cross-covariance between X and Y features (data are not centred here).
Sxy = X.T @ Y / n
# Show the first left singular vector of the cross-covariance Sxy, drawn with
# point markers only (no connecting line).
hv.Curve(np.linalg.svd(Sxy)[0][:, 0]).opts(linestyle='', marker='.')
[3]:
[4]:
# Grid of candidate penalties; the fit below does not use it and passes its
# own explicit penalty lists instead.
penalties = np.arange(.1, 1.01, .1)

# Configure the sparse PLS estimator, then fit it to the toy data.
spls_estimator = SparsePLS(
    penaltyxs=[.2, .3, .4],
    penaltyys=[.2, .3, .4],
    penalty_pairing='product',
    niter=100,
    cv=5,
)
scca = spls_estimator.fit(X, Y)

# Report the fitted correlation(s), then plot the first column of the fitted
# X rotations with point markers only.
print(scca.corrs_)
hv.Curve(scca.x_rotations_[:, 0]).opts(linestyle='', marker='.')
[0.91153997]
[4]:
[5]:
# Build a joint covariance matrix from the toy patterns u1, u2, v1, v2 and the
# sample size n defined in an earlier cell.
rng = np.random.default_rng(seed=0)

# Within-set covariances: rank-2 signal plus 0.9 * identity.
Sxx = 4*u1@u1.T/n + u2@u2.T/n + (.9)*np.eye(len(u1))
Syy = 4*v1@v1.T/n + v2@v2.T/n + (.9)*np.eye(len(v1))

# Small random rank-1 perturbations of the cross-covariance. The draws are made
# in the same order as they are added below.
noise_row = rng.normal(scale=.003, size=(1, len(u1)))
noise_col = rng.normal(scale=.003, size=(len(v1), 1))
Sxy = 4*u1@v1.T/n + u2@v2.T/n + u1 @ noise_row + noise_col @ v1.T

# Assemble the joint covariance [[Sxx, Sxy], [Sxy', Syy]].
S = np.block([
    [Sxx, Sxy],
    [Sxy.T, Syy],
])

jcov = JointCovarianceModelPLS(S, len(u1), m=1)
jcov.true_corrs_
[5]:
array([0.78143672])
[6]:
# Number of features in each synthetic weight profile.
px = 64

# Pieces shared by several profiles.
feature_rank = np.arange(1, px + 1)
inverse_rank = feature_rank ** -1.
n_on, n_off = px // 8, 7 * px // 8

# Named weight profiles, each a length-px vector.
weights = {
    'uniform': np.ones(px),
    'random': np.random.default_rng(0).uniform(-1, 1, size=px),
    'decay': inverse_rank,
    'rise': inverse_rank[::-1],
    'step_down': np.concatenate([np.ones(n_on), np.zeros(n_off)]),
    'step_up': np.concatenate([np.zeros(n_off), np.ones(n_on)]),
    'spikes': np.random.default_rng(0).binomial(n=1, p=1./8, size=px),
    'decaying_spikes': np.random.default_rng(0).binomial(n=1, p=1./8, size=px) * inverse_rank,
}

# Scale every profile to unit Euclidean norm.
weights = {
    name: profile / np.linalg.norm(profile, keepdims=True)
    for name, profile in weights.items()
}
# One curve panel per weight profile, each labelled with the profile's name.
fig = hv.Layout()
for profile_name, profile in weights.items():
    fig += hv.Curve(profile).relabel(profile_name)
fig
[6]:
[7]:
# rpy2.robjects.r evaluates a string of R code in the R session used by rpy2.
import rpy2.robjects
# Set the R_MAX_VSIZE environment variable from within R.
# NOTE(review): presumably intended to raise R's vector-memory limit for the
# SparsePLS fits; R may only read this variable at startup, so confirm that
# setting it here actually takes effect.
rpy2.robjects.r("Sys.setenv('R_MAX_VSIZE'=128000000000)")
[7]:
BoolVector with 1 elements.
1 |
[8]:
# ## Takes a while to run. Uncomment this cell to regenerate 'pls_vs_spls.nc', which the next cell loads.
# ax, ay = -1., -1.
# penalties = [.1, .3, .5]
# ress = []
# for w_lbl in weights: # ['decay', 'step_down']:
# w = weights[w_lbl]
# plsgm = PLSgm(w.reshape(-1, 1), w.reshape(-1, 1), ax=ax, ay=ay, r_between=0.3)
# res = analyze_model(
# plsgm,
# [
# SVDPLS(),
# SparsePLS(penaltyxs=penalties, penaltyys=penalties, penalty_pairing='zip')
# ],
# n_per_ftrs=(1, 16, 256, ),
# check_convergence=False,
# n_rep=25,
# n_test=1000,
# addons=[
# addon.test_scores, addon.weights_true_cossim, addon.loadings_true_pearson, addon.test_scores_true_pearson, addon.test_scores_true_spearman, addon.sparseCCA_penalties, #addon.cv
# ],
# postprocessors=[postproc.weights_pairwise_cossim_stats, postproc.weights_pairwise_jaccard_stats],
# #true_loadings=_calc_true_loadings(plsgm.Sigma_, plsgm.px, plsgm.x_rotations_[:, :1], plsgm.y_rotations_[:, :1]),
# cvs=[('kfold5', KFold(5))],
# scorers=addon.mk_scorers_for_cv(),
# mk_test_statistics=addon.mk_test_statistics_scores,
# random_state=0
# )
# res['weight'] = w_lbl
# ress.append(res)
# res_pls = xr.concat(ress, 'weight')
# res_pls.to_netcdf('pls_vs_spls.nc')
[9]:
# Load the PLS vs. sparse-PLS results from 'pls_vs_spls.nc', the file written
# by the commented-out analysis cell above.
res_pls = xr.open_dataset('pls_vs_spls.nc')
[10]:
def avg_weights(xw):
    """Rescale weights so that their mean over ``rep`` has unit norm.

    Parameters
    ----------
    xw : array-like with named dimensions
        Must be 2-dimensional and support ``.mean('rep')`` and division by a
        scalar (e.g. an ``xarray.DataArray`` with a ``rep`` dimension).

    Returns
    -------
    ``xw`` divided by the norm of its mean across ``rep``; individual
    replicates are not averaged.
    """
    assert xw.ndim == 2
    mean_norm = np.linalg.norm(xw.mean('rep'))
    return xw / mean_norm
def plot_comparison(xw_pls, xw_spls, xw_true, show_xlabel=True, qs=(.025, .975)):
    """Overlay PLS and sparse-PLS weight estimates on the true weights.

    Parameters
    ----------
    xw_pls, xw_spls : xarray.DataArray
        Estimated weights with a ``rep`` dimension and an ``x_feature``
        coordinate. They are not modified.
    xw_true : array
        True weight vector; must be compatible with ``mean.values @ xw_true``.
    show_xlabel : bool
        If False, ``xlabel=''`` is passed to the overlay options; otherwise
        ``xlabel=None``.
    qs : tuple of two floats
        Lower and upper quantile (across ``rep``) delimiting the shaded bands.

    Returns
    -------
    holoviews overlay with one quantile band and one mean curve per estimator,
    plus the ground-truth curve.
    """
    xlabel = None if show_xlabel else ''

    xw_pls_mean = xw_pls.mean('rep')
    xw_spls_mean = xw_spls.mean('rep')

    # Align each estimator's sign with the ground truth. The replicates AND
    # their mean are flipped together so the mean curve always lies on the same
    # side as its quantile band. New arrays are created instead of negating in
    # place, so the caller's data are left untouched.
    if xw_pls_mean.values @ xw_true < 0:
        xw_pls, xw_pls_mean = -xw_pls, -xw_pls_mean
    if xw_spls_mean.values @ xw_true < 0:
        xw_spls, xw_spls_mean = -xw_spls, -xw_spls_mean

    return (
        hv.Area((xw_pls.x_feature, xw_pls.quantile(qs[0], 'rep'), xw_pls.quantile(qs[1], 'rep')), vdims=['y', 'y2']).opts(color=clr_pls)
        * hv.Area((xw_spls.x_feature, xw_spls.quantile(qs[0], 'rep'), xw_spls.quantile(qs[1], 'rep')), vdims=['y', 'y2']).opts(color=clr_spls)
        * hv.Curve(xw_pls_mean, label='PLS').opts(color=clr_pls, linewidth=2.5)
        * hv.Curve(xw_spls_mean, label='SPLS').opts(color=clr_spls, linestyle='--', linewidth=2.5)
        * hv.Curve(xw_true, label='Ground truth').opts(color='black', linewidth=2, linestyle=':')
    ).redim(
        x='PC',
        y='Weight'
    ).opts(
        opts.Area(linewidth=1, alpha=.2),
        opts.Overlay(padding=.02, xlabel=xlabel, ylim=(-1, 1), sublabel_position=(-.45, .95))
    )
def plot_stability(res, show_xlabel=True):
    """Plot pairwise weight similarity against samples per feature.

    For every estimator in ``res.estr`` this draws the band between the
    ``'q2.5%'`` and ``'q97.5%'`` entries of
    ``res.x_weights_pairwise_cossim_stats`` (mode 0) together with the
    ``'mean'`` curve.

    Parameters
    ----------
    res : dataset providing ``estr`` and ``x_weights_pairwise_cossim_stats``
        The latter is indexed by ``mode``, ``estr``, ``stat`` and carries an
        ``n_per_ftr`` coordinate.
    show_xlabel : bool
        If False, ``xlabel=''`` is passed to the overlay options; otherwise
        ``xlabel=None``.

    Returns
    -------
    holoviews overlay with a log-scaled x axis.
    """
    xlabel = None if show_xlabel else ''
    line_styles = {'SVDPLS': '-', 'SparsePLS': '--'}
    stats = res.x_weights_pairwise_cossim_stats

    panel = hv.Overlay()
    for estr in res.estr.values:
        band = hv.Area(
            (
                stats.n_per_ftr,
                stats.sel(mode=0, estr=estr, stat='q2.5%'),
                stats.sel(mode=0, estr=estr, stat='q97.5%'),
            ),
            vdims=['y', 'y2'],
        ).opts(color=clrs[estr])
        mean_curve = hv.Curve(
            stats.sel(mode=0, stat='mean', estr=estr)
        ).opts(color=clrs[estr], linestyle=line_styles[estr])
        panel *= (band * mean_curve)

    renamed = panel.redim(
        n_per_ftr='Samples / feature',
        x_weights_pairwise_cossim_stats='Weight stability',
    )
    return renamed.opts(
        opts.Area(linewidth=0, alpha=.3),
        opts.Overlay(xlim=(.1, None), ylim=(0, 1), logx=True, xlabel=xlabel, sublabel_position=(-.75, .95))
    )
def plot_weight_type(res, show_xlabel=True, show_titel=True, show_legend=True, n_per_ftrs=None):
    """Build one figure row: weight comparisons at several sample sizes plus stability.

    Parameters
    ----------
    res : dataset
        Results for a single weight profile; must provide ``x_weights``,
        ``x_weights_true``, ``x_weights_true_cossim`` and whatever
        ``plot_stability`` reads.
    show_xlabel : bool
        Forwarded to ``plot_comparison`` and ``plot_stability``.
    show_titel : bool
        If True each comparison panel is titled '<n> samples / feature';
        if False the title is the empty string.
    show_legend : bool
        Whether the first comparison panel shows a legend; later panels never do.
    n_per_ftrs : sequence or None
        Samples-per-feature values to plot. If None, the first, second and
        last values for which ``x_weights_true_cossim`` has no missing entries
        are used.

    Returns
    -------
    holoviews Layout arranged in 4 columns.
    """
    if n_per_ftrs is None:
        available = res.dropna('n_per_ftr', how='any', subset=['x_weights_true_cossim']).n_per_ftr.values
        n_per_ftrs = [available[0], available[1], available[-1]]

    res_mode0 = res.sel(mode=0)
    xw_true = res_mode0.x_weights_true.values

    fig = hv.Layout()
    for i_panel, n_per_ftr in enumerate(n_per_ftrs):
        if show_titel:
            # Print whole numbers without a trailing '.0'.
            n_label = int(n_per_ftr) if int(n_per_ftr) == n_per_ftr else n_per_ftr
            title = f'{n_label} samples / feature'
        else:
            # Fix: honour show_titel=False. Both branches previously produced
            # the same title, so the flag had no effect.
            title = ''

        # The y-label and legend belong to the left-most panel only. Anchor
        # this to the panels actually drawn rather than to
        # res.n_per_ftr.values[0], which may not be among them.
        is_first_panel = i_panel == 0
        ylabel = None if is_first_panel else ''

        panel = plot_comparison(
            res_mode0.x_weights.sel(estr='SVDPLS', n_per_ftr=n_per_ftr),
            res_mode0.x_weights.sel(estr='SparsePLS', n_per_ftr=n_per_ftr),
            xw_true,
            show_xlabel=show_xlabel
        ).opts(title=title)
        fig += panel.opts(ylabel=ylabel, show_legend=show_legend and is_first_panel)

    fig += plot_stability(res, show_xlabel=show_xlabel)

    # Return the configured layout directly, so the result does not depend on
    # whether cols()/opts() modify `fig` in place or return a new object.
    return fig.cols(
        4
    ).opts(*fig_opts).opts(
        opts.Overlay(hooks=[legend_frame_off]),
        opts.Layout(fig_inches=(7, 10))
    )
def plot_all(res, n_per_ftrs=None):
    """Stack one ``plot_weight_type`` row per entry of ``res.weight``.

    Titles and the legend are requested for the first weight profile only, and
    x-axis labels for the last one only.

    Parameters
    ----------
    res : dataset with a ``weight`` coordinate
    n_per_ftrs : sequence or None
        Forwarded unchanged to ``plot_weight_type``.

    Returns
    -------
    holoviews Layout containing all rows.
    """
    weight_labels = res.weight.values
    first_label, last_label = weight_labels[0], weight_labels[-1]

    fig = hv.Layout()
    for w in weight_labels:
        is_first_row = bool(w == first_label)
        is_last_row = bool(w == last_label)
        fig += plot_weight_type(
            res.sel(weight=w),
            show_xlabel=is_last_row,
            show_titel=is_first_row,
            show_legend=is_first_row,
            n_per_ftrs=n_per_ftrs,
        )

    return fig.opts(*fig_opts).opts(
        opts.Layout(fig_inches=(7, None))
    )
[11]:
# Build the full comparison figure from the loaded results. With
# n_per_ftrs=None, plot_weight_type chooses the sample sizes itself.
fig = plot_all(res_pls, n_per_ftrs=None)
fig
[11]:
[12]:
# Write the figure to a PDF under the relative 'fig/' directory.
# NOTE(review): assumes that directory already exists; confirm before running.
hv.save(fig, 'fig/figS_spls.pdf')