Weight errors

How much do estimated CCA and PLS weights vary across repetitions, depending on sample size?

Setup

[34]:
import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
warnings.filterwarnings(
    'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
)  # holoviews emits this for log-linear plots
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)

import os

import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit
from sklearn.utils import check_random_state

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.generative_model import GEMMR, PLSgm, JointCovarianceModelCCA, JointCovarianceModelPLS
from gemmr.generative_model.base import WeightNotFoundError
from gemmr.sample_size import *
from gemmr.estimators import *
from gemmr.sample_analysis import *
from gemmr.plot import heatmap_n_req
from gemmr.util import nPerFtr2n, subset_ds
from gemmr.plot import polar_hist

import matplotlib
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import matplotlib.patheffects as path_effects
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)

from my_config import *
[35]:
r_clrs = hv.Palette(cmap_r, samples=3).values
n_per_ftr_clrs = hv.Palette(cmap_n_per_ftr, samples=4).values
clr_random = 'darkslateblue'
clr_perm = 'gray'

Analyze data with 100 features per set and a true between-set correlation of 0.3

[ ]:
r_true = 0.3
px = 100

n_mul_cca = 10
res_cca_100 = analyze_model_parameters(
    'cca',
    n_perm=0,
    n_rep=100,
    n_Sigmas=1,
    n_test=0,
    pxs=[px],
    py='px',
    n_per_ftrs=[n_per_ftr_typical, n_mul_cca*n_per_ftr_typical],
    rs=(r_true,),
    powerlaw_decay=(-1, -1),
    random_state=0,
    qx=.9,
    qy=.9,
)


n_mul_pls = 10
res_pls_100 = analyze_model_parameters(
    'pls',
    n_perm=0,
    n_rep=100,
    n_Sigmas=1,
    n_test=0,
    pxs=[px],
    py='px',
    n_per_ftrs=[n_per_ftr_typical, n_mul_pls*n_per_ftr_typical],
    rs=(r_true,),
    powerlaw_decay=(-1, -1),
    random_state=0,
    qx=.9,
    qy=.9,
)

Illustration of weight uncertainty

[37]:
panels_weights_100 = dict()

qs = (.025, .975)

for model, res, n_mul, n_mul_clr in [
    ('cca', res_cca_100, n_mul_cca, n_per_ftr_clrs[1]),
    ('pls', res_pls_100, n_mul_pls, n_per_ftr_clrs[2])
]:

    _res = res.sel(mode=0, Sigma_id=0, px=px, r=0.3)
    _w_true = _res.x_weights_true
    w_true = _w_true.sortby(_w_true, ascending=False)
    w_q = _res.x_weights.quantile(qs, 'rep').sortby(_w_true, ascending=False)

    panels_weights_100[model] = (
        hv.Area(
            (np.arange(len(w_true)),
             w_q.sel(n_per_ftr=n_per_ftr_typical, quantile=qs[0]).values,
             w_q.sel(n_per_ftr=n_per_ftr_typical, quantile=qs[-1]).values),
            vdims=['y', 'y2'],
            label=r'$n_\mathrm{typical}$'
        ).opts(color=n_per_ftr_clrs[1])
        * hv.Area(
            (np.arange(len(w_true)),
             w_q.sel(n_per_ftr=n_mul*n_per_ftr_typical, quantile=qs[0]).values,
             w_q.sel(n_per_ftr=n_mul*n_per_ftr_typical, quantile=qs[-1]).values),
            vdims=['y', 'y2'],
            label=r'$%i\cdot n_\mathrm{typical}$' % n_mul
        ).opts(color=n_per_ftr_clrs[2])
        * hv.Curve(w_true.values, label='true').opts(color=n_per_ftr_clrs[-1])
        * hv.HLine(0).opts(color='white', linestyle='--', linewidth=1)
    ).redim(
        x='Feature id',
        y='Weight'
    ).opts(
        opts.Area(alpha=.8, linewidth=0, show_legend=False),
        opts.Curve(show_legend=False),
        opts.Overlay(xlim=(0, 100), ylim=(-np.abs(w_q).max(), np.abs(w_q).max()))
    )

(
    panels_weights_100['cca']
    + panels_weights_100['pls']
).opts(*fig_opts)
[37]:

Sample size dependence

[ ]:
px = 100
r_true = 0.3
n_per_ftrs = [4, 64, 4096]

res_cca_100_ = analyze_model_parameters(
    'cca',
    n_perm=0,
    n_rep=10,
    n_Sigmas=1,
    n_test=0,
    pxs=[px],
    py='px',
    n_per_ftrs=n_per_ftrs,
    rs=(r_true,),
    powerlaw_decay=(-1, -1),
    random_state=0,
)

res_pls_100_ = analyze_model_parameters(
    'pls',
    n_perm=0,
    n_rep=10,
    n_Sigmas=1,
    n_test=0,
    pxs=[px],
    py='px',
    n_per_ftrs=n_per_ftrs,
    rs=(r_true,),
    powerlaw_decay=(-1, -1),
    random_state=0,
)
[39]:
# average across reps (in contrast, below: rep=0)
res_cca_100__ = res_cca_100_.sel(px=100, r=r_true, Sigma_id=0, mode=0)
res_pls_100__ = res_pls_100_.sel(px=100, r=r_true, Sigma_id=0, mode=0)

delta_xw_cca = res_cca_100__.x_weights.mean('rep') - res_cca_100__.x_weights_true
delta_xw_pls = res_pls_100__.x_weights.mean('rep') - res_pls_100__.x_weights_true

(
    (
        hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[0]), label=f'{n_per_ftrs[0]}')
        * hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[1]), label=f'{n_per_ftrs[1]}')
        * hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[2]), label=f'{n_per_ftrs[2]}')
        * hv.Text(72, -.33, 'smpls / ftr', fontsize=7)
    ).opts(hooks=[legend_frame_off, suptitle_cca], legend_position='bottom_right')
    + (
        hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[0]), label=f'{n_per_ftrs[0]}')
        * hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[1]), label=f'{n_per_ftrs[1]}')
        * hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[2]), label=f'{n_per_ftrs[2]}')
    ).opts(hooks=[legend_frame_off, suptitle_pls], show_legend=False, legend_position='top_right', ylabel='')
).redim(
    x_feature='Principal component',
    y='$X$ Weight - true $X$ weight'
).opts(*fig_opts).opts(
    opts.Overlay(padding=.1, legend_position='bottom_left', ylim=(-.41, None), legend_cols=2),
    opts.Curve(linewidth=1)
)
[39]:
[40]:
# rep = 0

res_cca_100__ = res_cca_100_.sel(px=100, r=r_true, Sigma_id=0, mode=0)
res_pls_100__ = res_pls_100_.sel(px=100, r=r_true, Sigma_id=0, mode=0)

delta_xw_cca = res_cca_100__.x_weights.sel(rep=0) - res_cca_100__.x_weights_true
delta_xw_pls = res_pls_100__.x_weights.sel(rep=0) - res_pls_100__.x_weights_true

fig_delta_weight_convergence = (
    (
        hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[0]), label=f'{n_per_ftrs[0]}')
        * hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[1]), label=f'{n_per_ftrs[1]}')
        * hv.Curve(delta_xw_cca.sel(n_per_ftr=n_per_ftrs[2]), label=f'{n_per_ftrs[2]}')
        * hv.Text(50, -.53, 'smpls / ftr', fontsize=7, halign='left')
    ).opts(hooks=[legend_frame_off, suptitle_cca], legend_position='bottom_right')
    + (
        hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[0]), label=f'{n_per_ftrs[0]}')
        * hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[1]), label=f'{n_per_ftrs[1]}')
        * hv.Curve(delta_xw_pls.sel(n_per_ftr=n_per_ftrs[2]), label=f'{n_per_ftrs[2]}')
    ).opts(hooks=[legend_frame_off, suptitle_pls], show_legend=False, legend_position='top_right', ylabel='')
).redim(
    x_feature='Principal component',
    y='$X$ Weight - true $X$ weight'
).opts(*fig_opts).opts(
    opts.Overlay(padding=.1, legend_position='bottom_left', legend_cols=2),
    opts.Curve(linewidth=1)
)

fig_delta_weight_convergence
[40]:

How many samples are required to obtain at most 10% weight error?
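
As a point of reference, here is a minimal sketch of such a criterion, under the assumption that weight error is quantified as one minus the absolute cosine similarity between estimated and true weights (calc_n_required_all_metrics, used below, encapsulates the actual definition):

import numpy as np

def weight_error(w_est, w_true):
    # 1 - |cossim| between estimated and true weights: sign-invariant,
    # 0 for perfect recovery, 1 for orthogonal vectors
    cos = w_est @ w_true / (np.linalg.norm(w_est) * np.linalg.norm(w_true))
    return 1 - abs(cos)

def n_required(errors_by_n, max_error=.1, q=95):
    # smallest sample size whose q-th percentile error across repetitions
    # stays at or below max_error; errors_by_n: {n: array of per-rep errors}
    for n in sorted(errors_by_n):
        if np.percentile(errors_by_n[n], q) <= max_error:
            return n
    return np.nan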

[41]:
ds_cca = load_outcomes('sweep_cca_cca_random_sum+-2+-2').sel(mode=0)
ds_pls = load_outcomes('sweep_pls_pls_random_sum+-2+-2').sel(mode=0)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[42]:
ds_cca = ds_cca.sel(px=ds_cca.px<128)
ds_pls = ds_pls.sel(px=ds_pls.px<128)
[43]:
ds_cca = subset_ds(ds_cca, n_keep=25)
ds_pls = subset_ds(ds_pls, n_keep=25)
[44]:
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[45]:
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[46]:
cca_n_req_per_ftr = calc_n_required_all_metrics(ds_cca, search_dim='n_per_ftr')
pls_n_req_per_ftr = calc_n_required_all_metrics(ds_pls, search_dim='n_per_ftr')
[47]:
fig_required_sample_size_heatmap = (
    heatmap_n_req(nPerFtr2n(cca_n_req_per_ftr['weightError'], ds_cca.py)).relabel(r'Weight error $\leq$ 10%')
    + heatmap_n_req(nPerFtr2n(pls_n_req_per_ftr['weightError'], ds_pls.py)).relabel(r'Weight error $\leq$ 10%').opts(colorbar=True, yaxis='bare')
).cols(
    2
).opts(*fig_opts).opts(
    opts.QuadMesh(logx=True, logz=True, ylim=(0, .8),
                  clim=(1e1, 1e5), cmap='inferno',
                  sublabel_position=(-.45, .95),
                  fontsize=dict(labels=8, ticks=7, legend=7, title=8),
                 ),
    opts.Layout(fig_inches=(3.42, 1.))
)

fig_required_sample_size_heatmap
[47]:

How similar are estimated CCA and PLS weights across repetitions?
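
The outcome variable plotted below, x_weights_pairwise_cossim_stats, summarizes pairwise cosine similarities between weight vectors estimated in different repetitions. A minimal sketch of the underlying quantity (the exact aggregation stored in the outcome files may differ):

import numpy as np
from scipy.spatial.distance import pdist

def weight_stability(W):
    # W: (n_rep, p) array, one estimated weight vector per repetition
    W = W / np.linalg.norm(W, axis=1, keepdims=True)  # unit-normalize rows
    cossim = 1 - pdist(W, metric='cosine')            # pairwise cosine similarities
    return np.abs(cossim).mean()                      # sign-invariant mean

For unrelated random 100-dimensional vectors this evaluates to about 0.08, so values near 0 indicate essentially uncorrelated weight estimates.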

[48]:
res_cossim_weights = dict(
    cca=load_outcomes('weightStability_cca_cca_random_sum+-2+-2').sel(mode=0),
    pls=load_outcomes('weightStability_pls_pls_random_sum+-2+-2').sel(mode=0),
)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'

What’s in the outcome data files?

[49]:
res_cossim_weights['cca'] = subset_ds(res_cossim_weights['cca'], n_keep=100)
res_cossim_weights['pls'] = subset_ds(res_cossim_weights['pls'], n_keep=100)
[50]:
print_ds_stats(res_cossim_weights['cca'])
n_rep            100
n_per_ftr        [  3   8  16  32  64 128 256 512]
r                [0.3]
px               [100]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 1, r: 1)>
array([[100]])
Coordinates:
  * r        (r) float64 0.3
  * px       (px) int64 100

power           not calculated
[51]:
print_ds_stats(res_cossim_weights['pls'])
n_rep            100
n_per_ftr        [  3   8  16  32  64 128 256 512]
r                [0.3]
px               [100]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 1, r: 1)>
array([[100]])
Coordinates:
  * r        (r) float64 0.3
  * px       (px) int64 100

power           not calculated
[52]:
def plot_cossim_weights_stats(res, outcome='x_weights_pairwise_cossim_stats', color=None):

    stats = res[outcome]

    panel = hv.Overlay()
    for i in stats.Sigma_id.values:
        panel *= hv.Curve(
            (stats.n_per_ftr, stats.sel(stat='mean', Sigma_id=i))
        ).opts(linewidth=1, linestyle='--', alpha=.1)

    panel *= (
        hv.Curve(
            (stats.n_per_ftr, stats.sel(stat='mean').mean('Sigma_id'))
        )
    )

    return panel.redim(
        x='Samples per feature',
        y='Weight stability'
    ).opts(
        opts.Area(alpha=.33, color=color),
        opts.Curve(color=color),
        opts.Overlay(ylim=(0, 1), logx=True)
    )
[53]:
(
    plot_cossim_weights_stats(res_cossim_weights['cca'].sel(px=100, r=0.3), color=clr_cca)
    + (
        plot_cossim_weights_stats(res_cossim_weights['pls'].sel(px=100, r=0.3), color=clr_pls)
    )
).opts(*fig_opts).opts(
    opts.Overlay(logx=False)
)
[53]:

What do the weight vectors look like in the principal component coordinate system?
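
Weights are compared to the principal component axes of the corresponding within-set covariance via cosine similarities, i.e. the coefficients of the unit-norm weight vector in the PC basis. A sketch of this quantity (the weights_pc_cossim addon used below computes the analogue within analyze_model_parameters):

import numpy as np

def weight_pc_cossim(Sigma_xx, w):
    # cosine similarity of the weight vector w with each PC axis of the
    # (hypothetical) within-set covariance matrix Sigma_xx
    evals, evecs = np.linalg.eigh(Sigma_xx)  # eigenvalues in ascending order
    evecs = evecs[:, ::-1]                   # reorder so PC1 comes first
    return evecs.T @ (w / np.linalg.norm(w))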

[54]:
# --- Takes a while to run ---

# from gemmr.sample_analysis.addon import weights_pc_cossim


# pxs = [2, 4, 8, 16, 32, 64]
# r_trues = [.1, .3, .5]
# n_per_ftrs = [3, 5, 8, 16, 32, 64, 128, 200, ]
# n_rep = 100
# n_perm = 0
# n_Sigmas = 10

# res_cca_100_ = analyze_model_parameters(
#     'cca',
#     n_perm=n_perm,
#     n_rep=n_rep,
#     n_Sigmas=n_Sigmas,
#     n_test=0,#1000,
#     pxs=pxs,
#     py='px',
#     n_per_ftrs=n_per_ftrs,
#     rs=r_trues,
#     powerlaw_decay=(-1, -1),
#     random_state=0,
#     addons=[weights_pc_cossim],
# )
# res_cca_100_.to_netcdf('pcBias_cca_cca_random_sum+-2+-2.nc')

# res_pls_100_ = analyze_model_parameters(
#     'pls',
#     n_perm=n_perm,
#     n_rep=n_rep,
#     n_Sigmas=n_Sigmas,
#     n_test=0,#1000,
#     pxs=pxs,
#     py='px',
#     n_per_ftrs=n_per_ftrs,
#     rs=r_trues,
#     powerlaw_decay=(-1, -1),
#     random_state=0,
#     addons=[weights_pc_cossim],
# )
# res_pls_100_.to_netcdf('pcBias_pls_pls_random_sum+-2+-2.nc')
[55]:
res = dict(
    cca=xr.load_dataset('./pcBias_cca_cca_random_sum+-2+-2.nc').sel(mode=0),
    pls=xr.load_dataset('./pcBias_pls_pls_random_sum+-2+-2.nc').sel(mode=0),
)

What’s in the outcome data files?

[56]:
print_ds_stats(res['cca'])
n_rep            100
n_per_ftr        [  3   5   8  16  32  64 128 200]
r                [0.1 0.3 0.5]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 3)>
array([[10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5
  * px       (px) int64 2 4 8 16 32 64

power           not calculated
[57]:
print_ds_stats(res['pls'])
n_rep            100
n_per_ftr        [  3   5   8  16  32  64 128 200]
r                [0.1 0.3 0.5]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 3)>
array([[10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10],
       [10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5
  * px       (px) int64 2 4 8 16 32 64

power           not calculated
[58]:
def plot_weight_pc_cossim(ds, n_pcs=None, color=None, qs=(.025, .975)):

    if n_pcs is None:
        n_pcs = ds.px.values.max()

    sim = np.abs(ds.x_weights_pc_cossim.sel(px=n_pcs))

    other_dims = [d for d in sim.dims if d != 'x_pc']
    sim = sim.stack(it=other_dims)

    # baseline: |cossim| of a uniformly random direction with a fixed axis;
    # one coordinate x of a random unit vector in n_pcs dimensions satisfies
    # (x + 1) / 2 ~ Beta((n_pcs-1)/2, (n_pcs-1)/2)
    random_data = np.abs(scipy.stats.beta.rvs((n_pcs-1)/2, (n_pcs-1)/2, size=(10000)) * 2 - 1)

    panel = (
        hv.HLine(random_data.mean())
        * hv.Curve(
            (np.arange(1, len(sim)+1), sim.mean('it')),
            label=r'$\langle$synthetic' + '\n' + r'datasets$\rangle$'
        )
        * hv.Area(
            (np.arange(1, len(sim)+1), sim.quantile(qs[0], 'it'), sim.quantile(qs[1], 'it')),
            vdims=['y', 'y2']
        )
    ).opts(
        opts.Curve(color=color, xlim=(0.01, n_pcs+1)),
        opts.Area(linewidth=0, alpha=.5, color=color),
        opts.HLine(color=clr_random, linestyle='-', linewidth=1),
    ).redim(
        x='Principal component',
        y='Weight PC similarity'
    )
    return panel
[59]:
def hook_weight_pc_cossim_legend(plot, element):

    ax = plot.handles['axis']

    patch_rnd = mpatches.Patch(color=clr_random, label='random')
    patch_typical = mpatches.Patch(color=n_per_ftr_clrs[1], label='{} samples/ftr'.format(n_per_ftr_typical))
    patch_200 = mpatches.Patch(color=n_per_ftr_clrs[-1], label='200 samples/ftr')

    ax.legend(
        handles=[
            patch_rnd, patch_typical, patch_200,
        ],
        ],
        frameon=False, fontsize=7,
        handlelength=1, ncol=1, columnspacing=.5,
        loc='upper right', bbox_to_anchor=(1.125, 1.0))


panels_cca_weight_pc_cossim = (
    plot_weight_pc_cossim(res['cca'].sel(r=0.3, n_per_ftr=200), color=n_per_ftr_clrs[-1])
    * plot_weight_pc_cossim(res['cca'].sel(r=0.3, n_per_ftr=n_per_ftr_typical), color=n_per_ftr_clrs[1])
).relabel(
    r'$r_\mathrm{true}=$0.3, 64 ftrs/set'
).opts(
    logx=False, logy=False, show_legend=True,
    hooks=[hook_weight_pc_cossim_legend,
    ]
)


panels_pls_weight_pc_cossim = (
    plot_weight_pc_cossim(res['pls'].sel(r=0.3, n_per_ftr=200), color=n_per_ftr_clrs[-1])
    * plot_weight_pc_cossim(res['pls'].sel(r=0.3, n_per_ftr=n_per_ftr_typical), color=n_per_ftr_clrs[1])
).relabel(
    r'$r_\mathrm{true}=$0.3, 64 ftrs/set'
).opts(
    show_legend=False, logx=False, logy=False, ylabel='',
)


(
    panels_cca_weight_pc_cossim
    + panels_pls_weight_pc_cossim
).opts(*fig_opts)
[59]:
[60]:
hv.save(fig_required_sample_size_heatmap, 'fig/figS_weight_error_required_sample_size_heatmap.pdf')

fig_required_sample_size_heatmap
[60]:

How strong is the PC bias depending on sample size and true between-set correlation?

We define “PC bias” as the mean cosine similarity between estimated weights and the first principal component axis, minus the similarity expected for a randomly oriented weight vector. The mean is taken across repetitions, dimensionalities, and different instantiations of the joint covariance matrix.
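
The random baseline rests on the fact that a single coordinate x of a uniformly random unit vector in p dimensions satisfies (x + 1)/2 ~ Beta((p-1)/2, (p-1)/2). A quick numerical check of the baseline values used below:

import numpy as np
import scipy.stats

for p in (2, 8, 64):
    x = scipy.stats.beta.rvs((p - 1) / 2, (p - 1) / 2, size=10000, random_state=0) * 2 - 1
    print(p, np.abs(x).mean())  # expected |cossim| shrinks roughly like 1/sqrt(p)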

[61]:
panels_nPerFtrs_cossim = {model: hv.Overlay() for model in res}
qs = (.025, .975)
rs = (.1, .5)

plotted_data = {}

for model in res:

    # expected |cossim| between a random direction and a fixed axis, per px;
    # this is the baseline subtracted from the observed PC1 similarities
    mean_beta_rvs = xr.DataArray(pd.Series(
        {px: np.abs(scipy.stats.beta.rvs((px-1)/2, (px-1)/2, size=(10000)) * 2 - 1).mean()
         for px in res[model].px.values},
    )).rename(dim_0='px')

    stack_dims = ['rep', 'Sigma_id', 'px']
    ensemble_info = f'INFO: {model.upper()} ensemble dims: '
    for d in stack_dims:
        ensemble_info += f'n_{d}={len(res[model][d])}; '
    print(ensemble_info)

    delta_pc1bias = np.maximum(
        (np.abs(res[model].x_weights_pc_cossim.sel(x_pc=0)) - mean_beta_rvs),
        (np.abs(res[model].y_weights_pc_cossim.sel(y_pc=0)) - mean_beta_rvs)
    ).stack(stacked=stack_dims)

    _plotted_data = {}
    for r in rs:
        panels_nPerFtrs_cossim[model] *= (
            hv.Curve(delta_pc1bias.sel(r=r).mean('stacked'), label=r'$r_\mathrm{true}=%.1f$' % r)
            * hv.Area(
                (delta_pc1bias.n_per_ftr, delta_pc1bias.sel(r=r).quantile(qs[0], 'stacked'), delta_pc1bias.sel(r=r).quantile(qs[1], 'stacked')),
                vdims=['y', 'y2']
            )
        )
        _plotted_data[r] = pd.concat(
            [
                delta_pc1bias.sel(r=r).mean('stacked').rename('mean').to_series(),
                delta_pc1bias.sel(r=r).quantile(qs[0], 'stacked').rename(f'{100*qs[0]:.1f}%').to_series(),
                delta_pc1bias.sel(r=r).quantile(qs[1], 'stacked').rename(f'{100*qs[1]:.1f}%').to_series(),
            ],
            axis=1
        )

    _plotted_data = pd.concat(_plotted_data, axis=1)
    _plotted_data.columns.names = ['r_true', 'curve']
    plotted_data[model.upper() + '_avgWeightPC1Similarity'] = _plotted_data

fig_nPerFtrs_cossim = (
    (
        panels_nPerFtrs_cossim['cca']
        * hv.Text(40, .65, '0.1', fontsize=7, halign='center')
        # * hv.Text(40, .25, '0.3', fontsize=7, halign='center').opts(color=r_clrs[1])
        * hv.Text(40, .5, '0.5', fontsize=7, halign='center')
    ).opts(
        opts.Overlay(show_legend=False),
        opts.Text(color=hv.Palette(cmap_r))
    ) * (
        hv.Text(40, .8, r'$r_\mathrm{true}=$', fontsize=7, halign='center').opts(color='black')
    )
    + panels_nPerFtrs_cossim['pls'].opts(show_legend=False, ylabel='')
).redim(
    n_per_ftr='Samples per feature',
    y=hv.Dimension('cossim_weights_1stPCaxis', label='Avg. weight PC1 similarity\n(data — random)')
).opts(*fig_opts).opts(
    opts.Curve(color=hv.Palette(cmap_r)),
    opts.Area(linewidth=0, alpha=.25, color=hv.Palette(cmap_r)),
    opts.Overlay(logx=False, show_legend=False)
)

fig_nPerFtrs_cossim
INFO: CCA ensemble dims: n_rep=100; n_Sigma_id=10; n_px=6;
INFO: PLS ensemble dims: n_rep=100; n_Sigma_id=10; n_px=6;
[61]:
[62]:
save_source_data(plotted_data, 'fig4')

Assemble figure

[63]:
# holoviews hooks

def set_axis_position(plot, element):
    ax = plot.handles['axis']
    bbox = ax.get_position()

    if bbox.x0 < .1:
        x0 = .05
    elif bbox.x0 < .35:
        x0 = .26
    elif bbox.x0 < .6:
        x0 = .55
    else:
        x0 = .77

    if bbox.y0 < .1:
        y0 = .05
    elif bbox.y0 < .5:
        y0 = .05 + .3
    else:
        y0 = .05 + .6

    ax.set_position((x0, y0, .14, .2))


def legend_samples_per_feature(plot, element):
    ax = plot.handles['axis']
    fontdict=dict(size=8)
    ax.text(1.05, .775, '5', ha='left', color=n_per_ftr_clrs[1], transform=ax.transAxes, fontdict=fontdict)
    ax.text(1.25, .775, '50', ha='right', color=n_per_ftr_clrs[2], transform=ax.transAxes, fontdict=fontdict)
    ax.text(1.15, .9, 'smpls/ftr', ha='center', transform=ax.transAxes, fontdict=fontdict)


def axpos_middlerow(plot, element):
    # disabled via the early return; the body below is kept for optional
    # manual repositioning of middle-row axes
    return
    ax = plot.handles['axis']
    bbox = ax.get_position()
    ax.set_position((bbox.x0, bbox.y0+.025, bbox.width, bbox.height))


def axpos_bottomrow(plot, element):
    # disabled via the early return; kept for optional manual repositioning
    # of bottom-row axes
    return
    ax = plot.handles['axis']
    bbox = ax.get_position()
    ax.set_position((bbox.x0+.025, .05, .85*bbox.width, .9*bbox.height))


def synth_perm_legend(plot, element):

    ax = plot.handles['axis']

    line_synt = Line2D([0], [0], color='black', linewidth=1, linestyle='-', label=r'synth')
    line_null = Line2D([0], [0], color='black', linewidth=1, linestyle='--', label=r'perm')

    ax.legend(handles=[line_synt, line_null], frameon=False, fontsize=7,
              handlelength=1, ncol=1,
              loc='upper right', bbox_to_anchor=(1., 1.03))
[64]:
fig = (
    fig_delta_weight_convergence

    + plot_cossim_weights_stats(res_cossim_weights['cca'].sel(px=100, r=0.3), outcome='x_weights_pairwise_cossim_stats', color=clr_cca).opts(
        logx=False, logy=False
    )
    + (
        plot_cossim_weights_stats(res_cossim_weights['pls'].sel(px=100, r=0.3), color=clr_pls)
    ).opts(ylim=(0, 1), yaxis='bare', logx=False, logy=False)

    + fig_nPerFtrs_cossim.Overlay.I.opts(opts.Overlay(hooks=[axpos_bottomrow, legend_frame_off]))
    + fig_nPerFtrs_cossim.Overlay.II.opts(opts.Overlay(hooks=[axpos_bottomrow]))
).cols(
    2
).opts(*fig_opts).opts(
    opts.QuadMesh(logx=True, logz=True, ylim=(0, .8),
                  clim=(1e1, 1e5), cmap='inferno',
                  fontsize=dict(labels=8, ticks=7, legend=7, title=8),
                  aspect='auto'
                 ),
    opts.Overlay(aspect='auto'),
    opts.Layout(hspace=.25, vspace=.7, fig_inches=(3.42, 4.25), sublabel_position=(-.3, .95))
)

hv.save(fig, 'fig/fig3_weights.pdf')

fig
[64]: