Loadings and Cross-loadings
We explore how many samples are necessary to obtain stable cross-loadings, and compare this to the number required for stable loadings.
Setup
[24]:
import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
warnings.filterwarnings(
'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
) # holoviews emits this for log-linear plots
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)
import numpy as np
import xarray as xr
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist
from sklearn.utils import check_random_state
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.generative_model import *
from gemmr.sample_analysis import *
from gemmr.estimators import SVDCCA
from gemmr.sample_size import *
from gemmr.metrics import *
from gemmr.util import subset_ds
from matplotlib.patches import Rectangle
import holoviews as hv
from holoviews import opts
hv.renderer('matplotlib').param.set_param(dpi=120)
hv.extension('matplotlib')
from my_config import *
[2]:
# Load the CCA and PLS simulation sweeps (the stats below show ax+ay fixed at -2).
ds_cca, ds_pls = (
    load_outcomes('sweep_cca_cca_random_sum+-2+-2'),
    load_outcomes('sweep_pls_pls_random_sum+-2+-2'),
)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[3]:
# Keep only configurations with fewer than 128 features per dataset (px < 128).
ds_cca, ds_pls = (_ds.sel(px=_ds.px < 128) for _ds in (ds_cca, ds_pls))
[4]:
# Thin both sweeps with n_keep=25 (presumably 25 Sigmas per parameter cell,
# cf. the n_Sigmas tables printed below — confirm against gemmr.util.subset_ds).
ds_cca, ds_pls = [subset_ds(_ds, n_keep=25) for _ds in (ds_cca, ds_pls)]
[5]:
# Sanity check relied on later (ptot = 2 * px): wherever py is finite it equals px.
# Entries masked to NaN by .where() are truthy, so they do not fail .all().
for _ds in (ds_cca, ds_pls):
    assert (_ds.py == _ds.px).where(np.isfinite(_ds.py)).all()
What’s in the outcome data?
[6]:
# Summarize the CCA outcome dataset: dimensions, parameter grid and Sigma counts per (px, r).
print_ds_stats(ds_cca)
n_modes 1
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7]
px [ 2 4 8 16 32 64]
ax+ay range (-2.00, -2.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[ 0, 25, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7
* px (px) int64 2 4 8 16 32 64
power calculated
[7]:
# Summarize the PLS outcome dataset: dimensions, parameter grid and Sigma counts per (px, r).
print_ds_stats(ds_pls)
n_modes 1
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7]
px [ 2 4 8 16 32 64]
ax+ay range (-2.00, -2.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[ 0, 25, 25, 25],
[ 0, 25, 25, 25],
[ 0, 0, 0, 10]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7
* px (px) int64 2 4 8 16 32 64
power calculated
True loadings vs true cross-loadings
How similar are loadings and cross-loadings in principle? Here we compare their true values, which were calculated directly from the assumed joint covariance matrices (i.e., they were not estimated from data).
[8]:
def _corr(x, y, corr_fun=pearsonr):
mask = np.isfinite(x) & np.isfinite(y)
if mask.sum() >= 2:
return corr_fun(x[mask], y[mask])[0]
else:
return np.nan
def corrs_loading_crossloadings_true(ds, corr_fun=pearsonr):
return xr.apply_ufunc(
_corr,
ds.x_loadings_true,
ds.x_crossloadings_true,
input_core_dims=[['x_feature'], ['x_feature']],
vectorize=True,
kwargs=dict(corr_fun=corr_fun)
)
[9]:
def _mean_true_corr(ds, corr_fun, name):
    """Loadings-vs-cross-loadings correlation (true values), averaged over Sigma_id and renamed."""
    per_sigma = corrs_loading_crossloadings_true(ds, corr_fun)
    return per_sigma.mean('Sigma_id').rename(name)


# NOTE(review): the Spearman variants are not used further down in this notebook.
cca_pearsonr_loadings_crossloadings_true = _mean_true_corr(
    ds_cca, pearsonr, 'Pearson correlation between\ntrue loadings and\ntrue cross-loadings')
cca_spearmanr_loadings_crossloadings_true = _mean_true_corr(
    ds_cca, spearmanr, 'Spearman correlation between\ntrue loadings and\ntrue cross-loadings')
pls_pearsonr_loadings_crossloadings_true = _mean_true_corr(
    ds_pls, pearsonr, 'Pearson correlation between\ntrue loadings and\ntrue cross-loadings')
pls_spearmanr_loadings_crossloadings_true = _mean_true_corr(
    ds_pls, spearmanr, 'Spearman correlation between\ntrue loadings and\ntrue cross-loadings')
[10]:
def _px_to_ptot(da):
    """Relabel dimension ``px`` as ``ptot`` holding the total feature count ``2 * px``.

    Doubling is valid because py == px (checked by the assertions in the setup
    cells). If ``da`` has no ``px`` dimension it is returned unchanged, so
    re-running this cell does not fail on already-converted arrays.
    """
    if 'px' not in da.dims:
        return da
    return da.assign_coords(px=2 * da.px).rename(px='ptot')


cca_pearsonr_loadings_crossloadings_true = _px_to_ptot(cca_pearsonr_loadings_crossloadings_true)
pls_pearsonr_loadings_crossloadings_true = _px_to_ptot(pls_pearsonr_loadings_crossloadings_true)
[11]:
# Panel c: Pearson correlation between true loadings and true cross-loadings (CCA),
# averaged over Sigma_id, as a function of total number of features and r_true.
fig_true_cca = (
    hv.QuadMesh(cca_pearsonr_loadings_crossloadings_true, kdims=['ptot', 'r']).opts(hooks=[suptitle_cca])
).redim(
    ptot='Number of features',
    r=r'$r_\mathrm{true}$'  # raw string: '\m' is an invalid escape sequence in a plain literal
).opts(*fig_opts).opts(
    opts.QuadMesh(logx=True, cmap='inferno', clim=(0, 1), fig_inches=(1.7, None), sublabel_format='c', sublabel_position=(-.5, .85)),
)
fig_true_cca
[11]:
[12]:
handles = []


def horizontal_colorbar(plot, element):
    """Matplotlib hook that moves a plot's colorbar axes below the figure.

    The colorbar axes is aligned horizontally with the main axes (same left
    edge and width), placed at figure y = -0.2 and given 5% of the axes height.
    Plots without a colorbar axes are left untouched.

    NOTE(review): neither ``handles`` nor this hook is referenced elsewhere in
    this notebook. The hook only repositions the colorbar; it does not change
    its orientation.
    """
    cax = plot.handles.get('cax')
    if cax is None:  # plot has no colorbar: nothing to move
        return
    bbox = plot.handles['axis'].get_position()
    cax.set_position([bbox.x0, -.2, bbox.width, 0.05 * bbox.height])
[13]:
# Panel d: as panel c, for PLS; unlike panel c this panel enables a colorbar.
fig_true_pls = (
    hv.QuadMesh(pls_pearsonr_loadings_crossloadings_true, kdims=['ptot', 'r'])
).redim(
    ptot='Number of features',
    r=r'$r_\mathrm{true}$'  # raw string: '\m' is an invalid escape sequence in a plain literal
).opts(*fig_opts).opts(
    opts.QuadMesh(
        logx=True, cmap='inferno', clim=(0, 1), clabel="corr(loadings, cross-loadings)", fig_inches=(1.7, None), colorbar=True,
        # NOTE(review): colorbar_opts also passes label "" next to clabel above — confirm which label should show.
        colorbar_opts={"orientation":"vertical", "label": ""}, ylabel='', hooks=[suptitle_pls],
        sublabel_format='d', sublabel_position=(-.5, .85)),
)
fig_true_pls
[13]:
How many samples are required to obtain stable cross-loadings, compared with the number required for stable loadings?
[14]:
target_error = 0.1


def _calc_n_reqs(ds, target_error):
    """Required samples per feature for all standard metrics plus the cross-loading error.

    Returns the result of ``calc_n_required_all_metrics`` (on mode 0) with an
    extra ``'crossloadingError'`` entry. For that entry the mean (over ``rep``)
    cross-loading error must lie within [-target_error, target_error].

    NOTE(review): the standard metrics use ``ds.sel(mode=0)``, but the
    cross-loading error is computed on the full ``ds``. With n_modes == 1
    (see the stats above) this is presumably equivalent — confirm.
    """
    n_reqs = calc_n_required_all_metrics(ds.sel(mode=0), target_error=target_error, search_dim='n_per_ftr')
    mean_crossloading_error = mk_crossloadingError(ds).mean('rep')
    n_reqs['crossloadingError'] = calc_n_required(
        mean_crossloading_error, -target_error, target_error, search_dim='n_per_ftr')
    return n_reqs


n_reqs_cca = _calc_n_reqs(ds_cca, target_error)
n_reqs_pls = _calc_n_reqs(ds_pls, target_error)
[15]:
pxs_cca = n_reqs_cca['combined'].px
pxs_pls = n_reqs_pls['combined'].px


def rel_diff(n_reqs, ref_metric):
    """Relative difference in required sample size: cross-loadings vs. a reference metric.

    Computes ``(n_cross - n_ref) / n_ref``, where
    ``n_cross = n_reqs['crossloadingError']`` and ``n_ref = n_reqs[ref_metric]``.
    The result is averaged over ``Sigma_id``, restricted to ``px > 2`` and
    renamed for plotting. Positive values mean cross-loadings need more
    samples than the reference metric.

    NOTE(review): the reason for excluding px == 2 is not stated here — confirm.
    """
    n_cross = n_reqs['crossloadingError']
    n_ref = n_reqs[ref_metric]
    return (
        ((n_cross - n_ref) / n_ref)
        .mean('Sigma_id')
        .sel(px=n_ref.px > 2)
        .rename('Required samples per feature\n(cross-loadings - loadings)/loadings')
    )
[16]:
def _with_total_features(da):
    """Turn dimension ``px`` into ``ptot`` = 2 * px (total features, since py == px)."""
    return da.assign_coords(px=2 * da.px).rename(px='ptot')


_rel_diff_cca = _with_total_features(rel_diff(n_reqs_cca, 'loadingError'))
_rel_diff_pls = _with_total_features(rel_diff(n_reqs_pls, 'loadingError'))
[17]:
# Panel e: relative difference in required samples per feature, cross-loadings vs.
# loadings (CCA). symmetric=True centres the diverging RdBu_r colormap on 0, so
# positive values (cross-loadings need more samples) fall at the red end.
fig_crossloadings_cca = (
    hv.QuadMesh(_rel_diff_cca, kdims=['ptot', 'r'])
).redim(
    ptot='Number of features',
    r=r'$r_\mathrm{true}$'  # raw string: '\m' is an invalid escape sequence in a plain literal
).opts(*fig_opts).opts(
    opts.QuadMesh(
        logx=True, logz=False, cmap='RdBu_r', symmetric=True, clim=(-.65, .65), hooks=[suptitle_cca],
        fig_inches=(1.7, None), ylabel='', colorbar=False, colorbar_opts={"orientation":"vertical"},
        sublabel_format='e', sublabel_position=(-.5, .85))
)
fig_crossloadings_cca
[17]:
[18]:
# Panel f: as panel e, for PLS; this panel shows the colorbar for panels e and f
# (both use clim=(-.65, .65)).
fig_crossloadings_pls = (
    hv.QuadMesh(_rel_diff_pls, kdims=['ptot', 'r'])
).redim(
    ptot='Number of features',
    r=r'$r_\mathrm{true}$'  # raw string: '\m' is an invalid escape sequence in a plain literal
).opts(*fig_opts).opts(
    opts.QuadMesh(
        logx=True, logz=False, cmap='RdBu_r', symmetric=True, clim=(-.65, .65), hooks=[suptitle_pls],
        fig_inches=(1.7, None), ylabel='', colorbar=True, colorbar_opts={"orientation":"vertical"},
        sublabel_format='f', sublabel_position=(-.5, .85))
)
fig_crossloadings_pls
[18]:
Loadings schematic
[19]:
# Toy data for the schematic: n subjects and 4 features per dataset. Feature j
# is weight_j * latent score + unit Gaussian noise.
# NOTE(review): the "r" shown in the panels is this weight; the population
# correlation between feature and score is r / sqrt(1 + r**2).
rng = check_random_state(42)
n = 100

rXs = [.5, .9, .3, .8]
tX = rng.normal(size=n)
X = np.outer(tX, rXs) + rng.normal(size=(n, 4))

rYs = [.7, .2, .8, .9]
tY = rng.normal(size=n)
Y = np.outer(tY, rYs) + rng.normal(size=(n, 4))

X.shape, Y.shape
[19]:
((100, 4), (100, 4))
[20]:
def calc_extent(X, t):
    """Largest absolute value in the feature matrix ``X`` or the score vector ``t``.

    Equal to the former per-column min/max formulation: the maximum of |X|
    over all entries is always attained at some column's minimum or maximum.
    """
    return max(np.abs(X).max(), np.abs(t).max())


# extent is used below for figure layout: shared symmetric axis limits and
# text positions in the schematic panels
extentX = calc_extent(X, tX)
extentY = calc_extent(Y, tY)
print(extentX, extentY)
extent = max(extentX, extentY)
4.190859707165826 4.323160858691789
[21]:
def hook_loadings(plot, element):
    """Matplotlib hook that decorates the standalone loadings schematic.

    HoloViews calls it for every Overlay subplot. Panels in the right-most
    column (axes x0 > .75 in figure coordinates) get a '...' drawn left of
    their x-range, separating them from the preceding feature panels. The
    top-right panel additionally draws the shaded "X loadings" / "Y loadings"
    bands. Their figure-coordinate positions are tuned to this 2-row layout.
    """
    ax = plot.handles['axis']
    fig = plot.handles['fig']
    pos = ax.get_position()
    in_last_column = pos.x0 > .75
    if in_last_column:
        ax.text(-extent-2.05, 0, '...', transform=ax.transData)
    if in_last_column and pos.y0 > .5:
        # X loadings band (top row)
        rect = Rectangle((0.12, 0.90), .85, .05, fill=True, alpha=.1, linewidth=0, linestyle='--',
                         facecolor='steelblue', transform=fig.transFigure)
        fig.add_artist(rect)
        fig.text(.95, .96, '$X$ loadings', ha='right', va='bottom', fontsize=8, color='steelblue')
        # Y loadings band (bottom row)
        rect = Rectangle((0.12, 0.41), .85, .05, fill=True, alpha=.1, linewidth=0, linestyle='--',
                         facecolor='rebeccapurple', transform=fig.transFigure)
        fig.add_artist(rect)
        fig.text(.95, .46, '$Y$ loadings', ha='right', va='bottom', fontsize=8, color='rebeccapurple')


def _feature_panel(values, scores, r, kdim, vdim, color):
    """Scatter of one feature against the latent scores, annotated with its weight ``r``."""
    return (
        hv.Scatter((values, scores), kdims=kdim, vdims=vdim).opts(color=color)
        * hv.Text(extent, extent+.5, '$r=%.1f$'%r, halign='right', valign='top', fontsize=8)
    )


figX, figY = hv.Layout(), hv.Layout()
for ri in range(len(rXs)):
    # X row: the last column stands for the p_X-th (last) feature
    var_num = ri + 1 if ri < len(rXs) - 1 else r'$p_X$'
    panel = _feature_panel(X[:, ri], tX, rXs[ri], '$X$ feature {}'.format(var_num), '$X$ scores', 'steelblue')
    if ri > 0:
        panel = panel.opts(yaxis='bare', sublabel_format='')
    else:
        panel = panel.opts(sublabel_format='a')
    figX += panel

    # Y row: same layout; the first panel also carries the 'subject' annotation
    var_num = ri + 1 if ri < len(rYs) - 1 else r'$p_Y$'
    panel = _feature_panel(Y[:, ri], tY, rYs[ri], '$Y$ feature {}'.format(var_num), '$Y$ scores', 'rebeccapurple')
    if ri == 0:
        panel *= hv.Text(1.8, 2.8, 'subject', halign='left', valign='center', fontsize=8)
        panel = panel.opts(sublabel_format='b')
    else:
        panel = panel.opts(yaxis='bare', sublabel_format='')
    figY += panel

fig_schematic = (
    figX
    + figY
).cols(
    4
).opts(*fig_opts).opts(
    opts.Scatter(s=3, xticks=0, yticks=0),
    opts.Overlay(xlim=(-extent, extent), ylim=(-extent, extent), hooks=[hook_loadings]),
    opts.Layout(fig_inches=(7, None), hspace=.3, vspace=.25)
)
fig_schematic
[21]:
[22]:
def hook_loadings_combined(plot, element):
    """Matplotlib hook that decorates the schematic inside the combined figure.

    It draws the same elements as the standalone schematic hook: a '...' left
    of the right-most column, plus the shaded "X loadings" / "Y loadings"
    bands drawn from the top-right schematic panel. The bands' figure-coordinate
    positions are tuned to the 3-row combined layout. The hook has its own
    name so it does not shadow the standalone ``hook_loadings``.
    """
    ax = plot.handles['axis']
    fig = plot.handles['fig']
    pos = ax.get_position()
    in_last_column = pos.x0 > .75
    if in_last_column:
        ax.text(-extent-2.05, 0, '...', transform=ax.transData)
    if in_last_column and pos.y0 > .5:
        # X loadings band
        rect = Rectangle((0.12, 0.92), .85, .05, fill=True, alpha=.1, linewidth=0, linestyle='--',
                         facecolor='steelblue', transform=fig.transFigure)
        fig.add_artist(rect)
        fig.text(.95, .97, '$X$ loadings', ha='right', va='bottom', fontsize=8, color='steelblue')
        # Y loadings band
        rect = Rectangle((0.12, 0.58), .85, .05, fill=True, alpha=.1, linewidth=0, linestyle='--',
                         facecolor='rebeccapurple', transform=fig.transFigure)
        fig.add_artist(rect)
        fig.text(.95, .63, '$Y$ loadings', ha='right', va='bottom', fontsize=8, color='rebeccapurple')


# Assemble the full supplementary figure. The first .opts() call on each
# sub-figure uses clone=True, so the combined-layout tweaks (hook, sublabel
# positions, xlim) go onto copies. The standalone figures exported in the next
# cell therefore keep the styling they were given in their own cells.
fig_combined = (
    fig_schematic.opts(*fig_opts, clone=True).opts(
        opts.Scatter(s=3, xticks=0, yticks=0),
        opts.Overlay(xlim=(-extent, extent), ylim=(-extent, extent), hooks=[hook_loadings_combined]),
    )
    + fig_true_cca.opts(*fig_opts, clone=True).opts(
        opts.QuadMesh(logx=True, cmap='inferno', clim=(0, 1), sublabel_format='c', sublabel_position=(-.15, .92)),
    )
    + fig_true_pls.opts(*fig_opts, clone=True).opts(
        opts.QuadMesh(
            logx=True, cmap='inferno', clim=(0, 1), colorbar=True, colorbar_opts={"orientation":"vertical"},
            ylabel='', hooks=[suptitle_pls], sublabel_format='d', sublabel_position=(-.15, .92)),
    )
    + fig_crossloadings_cca.opts(*fig_opts, clone=True).opts(
        opts.QuadMesh(
            logx=True, logz=False, cmap='RdBu_r', symmetric=True, xlim=(6, None), clim=(-.65, .65),
            hooks=[suptitle_cca], ylabel='', colorbar=False, colorbar_opts={"orientation":"vertical"},
            sublabel_format='e', sublabel_position=(-.15, .92))
    )
    + fig_crossloadings_pls.opts(*fig_opts, clone=True).opts(
        opts.QuadMesh(
            logx=True, logz=False, cmap='RdBu_r', symmetric=True, xlim=(6, None), clim=(-.65, .65),
            hooks=[suptitle_pls], ylabel='', colorbar=True, colorbar_opts={"orientation":"vertical"},
            sublabel_format='f', sublabel_position=(-.15, .92))
    )
).cols(
    4
).opts(*fig_opts).opts(
    opts.Layout(fig_inches=(7, None), hspace=.9, vspace=.6)
)
hv.save(fig_combined, 'fig/figS_loadings.pdf')
fig_combined
[22]:
[23]:
# Export every panel on its own: the schematic as SVG, the heatmaps as PDF.
_panel_files = [
    (fig_schematic, 'fig/figS_loadings_AB.svg'),
    (fig_true_cca, 'fig/figS_loadings_C.pdf'),
    (fig_true_pls, 'fig/figS_loadings_D.pdf'),
    (fig_crossloadings_cca, 'fig/figS_loadings_E.pdf'),
    (fig_crossloadings_pls, 'fig/figS_loadings_F.pdf'),
]
for _fig, _path in _panel_files:
    hv.save(_fig, _path)