Weight errors

How much do estimated CCA and PLS weights vary across repetitions, depending on sample size?

Setup

[1]:
import os

import numpy as np
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit
from sklearn.utils import check_random_state

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.generative_model import setup_model
from gemmr.sample_size import *
from gemmr.estimators import *
from gemmr.sample_analysis import *
from gemmr.plot import heatmap_n_req
from gemmr.util import nPerFtr2n

import matplotlib
import matplotlib.patheffects as path_effects
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
# Silence warnings that are noisy but harmless for this notebook's figures.
warnings.simplefilter('ignore', category=MatplotlibDeprecationWarning)
warnings.filterwarnings(
    'ignore',
    message='aspect is not supported for Axes with xscale=log, yscale=linear',
    category=UserWarning,
)  # holoviews emits this for log-linear plots
[2]:
# Discrete colour lists sampled from the colormaps configured in my_config:
# 3 colours for correlation levels, 4 for samples-per-feature levels.
r_clrs, n_per_ftr_clrs = (
    hv.Palette(cmap, samples=n_samples).values
    for cmap, n_samples in ((cmap_r, 3), (cmap_n_per_ftr, 4))
)

Analyze simulated data with 100 features per set and a true correlation of 0.3

[ ]:
r_true = 0.3  # true between-set correlation of the simulated model
px = 100      # number of features per set (py is set equal to px below)

# Multipliers of n_per_ftr_typical for the larger of the two sample sizes.
n_mul_cca = 10
n_mul_pls = 10


def _analyze_weights(model, n_mul, axPlusay_range):
    """Run ``analyze_model_parameters`` for ``model`` at two sample sizes.

    The two sample sizes are ``n_per_ftr_typical`` and ``n_mul * n_per_ftr_typical``
    samples per feature, with ``px`` features in each set and true correlation
    ``r_true``. All other settings are shared between the CCA and PLS analyses;
    only the model name, the multiplier and ``axPlusay_range`` differ.
    """
    return analyze_model_parameters(
        model,
        n_perm=0,
        n_rep=100,
        n_Sigmas=1,
        n_test=0,  # NOTE(review): previously 1000; 0 presumably skips test-set evaluation
        pxs=[px],
        py='px',
        n_per_ftrs=[n_per_ftr_typical, n_mul * n_per_ftr_typical],
        rs=(r_true,),
        axPlusay_range=axPlusay_range,
        random_state=0,
        qx=.9,
        qy=.9,
    )


res_cca_100 = _analyze_weights('cca', n_mul_cca, axPlusay_range=(0, 0))
res_pls_100 = _analyze_weights('pls', n_mul_pls, axPlusay_range=(-2, -2))

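Illustration of weight uncertainty: shaded bands span the central 95% of the estimated weights across repetitions, with features sorted by their true weight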

[4]:
panels_weights_100 = dict()

# Lower / upper quantiles across repetitions that delimit the shaded bands.
qs = (.025, .975)


def _weight_band(xs, w_q, n_per_ftr, label, color):
    """Shaded area between the lower and upper weight quantiles at one sample size."""
    return hv.Area(
        (xs,
         w_q.sel(n_per_ftr=n_per_ftr, quantile=qs[0]).values,
         w_q.sel(n_per_ftr=n_per_ftr, quantile=qs[-1]).values),
        vdims=['y', 'y2'],
        label=label,
    ).opts(color=color)


for model, res, n_mul in [
    ('cca', res_cca_100, n_mul_cca),
    ('pls', res_pls_100, n_mul_pls),
]:
    _res = res.sel(mode=0, Sigma_id=0, px=px, r=r_true)
    _w_true = _res.x_weights_true
    # Sort features by their true weight so the "true" curve is monotone and
    # the estimated-weight bands can be compared against it feature by feature.
    w_true = _w_true.sortby(_w_true, ascending=False)
    w_q = _res.x_weights.quantile(qs, 'rep').sortby(_w_true, ascending=False)

    xs = np.arange(len(w_true))
    w_lim = float(np.abs(w_q).max())  # symmetric y-range covering all bands

    panels_weights_100[model] = (
        _weight_band(xs, w_q, n_per_ftr_typical,
                     r'$n_\mathrm{typical}$', n_per_ftr_clrs[0])
        * _weight_band(xs, w_q, n_mul * n_per_ftr_typical,
                       r'$%i\cdot n_\mathrm{typical}$' % n_mul, n_per_ftr_clrs[1])
        * hv.Curve(w_true.values, label='true').opts(color=n_per_ftr_clrs[-1])
        * hv.HLine(0).opts(color='white', linestyle='--', linewidth=1)
    ).redim(
        x='Feature id',
        y='Weight'
    ).opts(
        opts.Area(alpha=.8, linewidth=0, show_legend=False),
        opts.Curve(show_legend=False),
        opts.Overlay(xlim=(0, px), ylim=(-w_lim, w_lim))
    )

(
    panels_weights_100['cca']
    + panels_weights_100['pls']
).opts(*fig_opts)
[4]:

How similar are estimated CCA and PLS weights across repetitions?

[5]:
# Per-model outcome summaries stored under the 'cossimWeightStats' tag, restricted
# to the first mode; presumably cosine similarities of weight estimates.
res_cossim_weights = {
    model: load_outcomes(model, tag='cossimWeightStats').sel(mode=0)
    for model in ('cca', 'pls')
}

What’s in the outcome data files?

[6]:
print_ds_stats(res_cossim_weights["cca"])  # dimensions/coordinates of the CCA file
n_rep            100
n_per_ftr        [  3   8  16  32  64 128 256 512]
r                [0.3]
px               [100]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 1, r: 1)>
array([[10]])
Coordinates:
  * r        (r) float64 0.3
  * px       (px) int64 100

power           not calculated
[7]:
print_ds_stats(res_cossim_weights["pls"])  # dimensions/coordinates of the PLS file
n_rep            100
n_per_ftr        [  3   8  16  32  64 128 256 512]
r                [0.3]
px               [100]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 1, r: 1)>
array([[100]])
Coordinates:
  * r        (r) float64 0.3
  * px       (px) int64 100

power           not calculated
[8]:
def plot_cossim_weights_stats(res, outcome='x_cossim_weights_stats', color=None):
    """Plot weight stability against the number of samples per feature.

    Parameters
    ----------
    res : xarray.Dataset
        Outcome dataset reduced to a single ``px`` and ``r``. Its ``outcome``
        variable must have ``Sigma_id``, ``n_per_ftr`` and ``stat`` dimensions,
        and ``stat`` must include a ``'mean'`` entry.
    outcome : str
        Name of the data variable to plot.
    color : optional
        Colour used for all curves.

    Returns
    -------
    holoviews.Overlay
        One faint dashed curve per ``Sigma_id`` (presumably one per covariance
        matrix) and a solid curve of their average. The x axis is
        log-scaled and y is limited to [0, 1].
    """
    stats = res[outcome]
    mean_stats = stats.sel(stat='mean')

    panel = hv.Overlay()
    # Faint curve for every individual Sigma_id ...
    for sigma_id in stats.Sigma_id.values:
        panel *= hv.Curve(
            (stats.n_per_ftr, mean_stats.sel(Sigma_id=sigma_id))
        ).opts(linewidth=1, linestyle='--', alpha=.1)

    # ... and the average across Sigma_id drawn on top.
    panel *= hv.Curve((stats.n_per_ftr, mean_stats.mean('Sigma_id')))

    return panel.redim(
        x='Samples per feature',
        y='Weight stability'
    ).opts(
        opts.Curve(color=color),
        opts.Overlay(ylim=(0, 1), logx=True)
    )
[9]:
# Weight stability for CCA (left) and PLS (right); logx is already set by the helper.
(
    plot_cossim_weights_stats(res_cossim_weights['cca'].sel(px=100, r=0.3), color=clr_cca)
    + plot_cossim_weights_stats(res_cossim_weights['pls'].sel(px=100, r=0.3), color=clr_pls)
).opts(*fig_opts)
[9]:

How many samples are required to obtain at most 10% weight error?

[10]:
# Main outcome datasets (first mode only). The PLS file is the one tagged
# 'axPlusay-2', i.e. ax+ay range (-2, -2) according to its printed stats.
ds_cca, ds_pls = (
    load_outcomes('cca').sel(mode=0),
    load_outcomes('pls', tag='axPlusay-2').sel(mode=0),
)
[11]:
print_ds_stats(ds=ds_cca) if False else print_ds_stats(ds_cca)  # CCA outcome summary
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[12]:
print_ds_stats(ds_pls)  # PLS outcome summary (dimensions, coordinates, n_Sigmas)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0,  0,  0, 25, 25],
       [ 0,  0,  0, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[13]:
# Required samples per feature for every outcome metric, for CCA and PLS in turn.
cca_n_req_per_ftr, pls_n_req_per_ftr = (
    calc_n_required_all_metrics(dataset, search_dim='n_per_ftr')
    for dataset in (ds_cca, ds_pls)
)
[14]:
# Required sample size for weight error <= 10%, CCA (left) and PLS (right).
# nPerFtr2n presumably converts samples per feature into a total sample size
# using the number of features py.
fig_required_sample_size_heatmap = (
    heatmap_n_req(nPerFtr2n(cca_n_req_per_ftr['weightError'], ds_cca.py)).relabel(r'Weight error $\leq$ 10%')
    + heatmap_n_req(nPerFtr2n(pls_n_req_per_ftr['weightError'], ds_pls.py)).relabel(r'Weight error $\leq$ 10%').opts(colorbar=True, yaxis='bare')
).cols(
    2
).opts(*fig_opts).opts(
    opts.QuadMesh(logx=True, logz=True, ylim=(0, 1),
                  clim=(1e1, 1e5), cmap='inferno',
                  sublabel_position=(-.45, .95),
                  fontsize=dict(labels=8, ticks=7, legend=7, title=8),
                  ),
    opts.Layout(fig_inches=(3.42, None))
)

fig_required_sample_size_heatmap
[14]:

Assemble figure

[15]:
px = 100  # features per set for the assembled figure; same value as in the analysis cell above
[17]:
def set_axis_position(plot, element):
    """Snap a panel's axes onto a fixed grid of 4 columns x 3 rows.

    Takes the ``(plot, element)`` arguments of a holoviews plot hook. It reads
    the current axes position and picks the grid slot whose threshold bucket
    contains the lower-left corner. It then moves the axes there with a fixed
    width of .14 and height of .2 (presumably figure-fraction units, as for
    matplotlib ``Axes.set_position``). ``element`` is not used.

    NOTE(review): not referenced by the figure assembled below.
    """
    def _snap(value, edges, slots):
        # First slot whose upper edge lies above ``value``; the last slot otherwise.
        for edge, slot in zip(edges, slots):
            if value < edge:
                return slot
        return slots[-1]

    axis = plot.handles['axis']
    corner = axis.get_position()

    left = _snap(corner.x0, (.1, .35, .6), (.05, .26, .55, .77))
    bottom = _snap(corner.y0, (.1, .5), (.05, .05 + .3, .05 + .6))

    axis.set_position((left, bottom, .14, .2))


def legend_samples_per_feature(plot, element):
    """Plot hook that draws a small text legend for the two weight bands.

    Writes the two samples-per-feature values to the right of the axes, each
    in the colour of its shaded band. The values are ``n_per_ftr_typical``
    (colour ``n_per_ftr_clrs[0]``) and ``n_mul_pls * n_per_ftr_typical``
    (colour ``n_per_ftr_clrs[1]``). A 'smpls/ftr' caption goes below them.
    The labels are derived from the same values that define the PLS bands, so
    they stay in sync if those settings change. ``element`` is not used.
    """
    ax = plot.handles['axis']
    fontdict = dict(size=8)
    n_lo_label = '%i' % int(n_per_ftr_typical)
    n_hi_label = '%i' % int(n_mul_pls * n_per_ftr_typical)

    lo_text = ax.text(1.05, .9, n_lo_label, ha='left', color=n_per_ftr_clrs[0], transform=ax.transAxes, fontdict=fontdict)
    ax.text(1.25, .9, n_hi_label, ha='right', color=n_per_ftr_clrs[1], transform=ax.transAxes, fontdict=fontdict)
    ax.text(1.15, .8, 'smpls/ftr', ha='center', transform=ax.transAxes, fontdict=fontdict)

    # Thin outline to make the first label more visible.
    lo_text.set_path_effects([
        path_effects.Stroke(linewidth=.1, foreground='black'),
        path_effects.Normal()
    ])

# Panel titles derived from the simulated configuration so they stay in sync with it.
_weights_title = r'$r_\mathrm{true}=%g$, %i ftrs/set' % (r_true, px)
_n_req_title = r'Weight error $\leq$ 10%'

fig = (
    (
        panels_weights_100['cca']
        * hv.Text(90, -0.2, 'true', halign='right', valign='top', fontsize=8).opts(color=n_per_ftr_clrs[-1])
    ).relabel(_weights_title).opts(
        xlim=(0, px), ylim=(-.25, .25), sublabel_position=(-.35, .95), hooks=[suptitle_cca]
    )
    + (
        panels_weights_100['pls']
        * hv.Text(10, 0.25, 'true', halign='left', valign='top', fontsize=8).opts(color=n_per_ftr_clrs[-1])
    ).relabel(_weights_title).opts(
        xlim=(0, px), ylim=(-.25, .25), yaxis='bare', sublabel_position=(-.3, .95), hooks=[suptitle_pls, legend_samples_per_feature]
    )
    + plot_cossim_weights_stats(res_cossim_weights['cca'].sel(px=100, r=0.3), color=clr_cca).opts(sublabel_position=(-.3, 1))
    + plot_cossim_weights_stats(res_cossim_weights['pls'].sel(px=100, r=0.3), color=clr_pls).opts(ylim=(0, 1), yaxis='bare', sublabel_position=(-.25, 1))
    + heatmap_n_req(nPerFtr2n(cca_n_req_per_ftr['weightError'], ds_cca.py)).relabel(_n_req_title).opts(sublabel_position=(-.3, .95))
    + heatmap_n_req(nPerFtr2n(pls_n_req_per_ftr['weightError'], ds_pls.py)).relabel(_n_req_title).opts(colorbar=True, yaxis='bare', sublabel_position=(-.25, .95))
).cols(
    2
).opts(*fig_opts).opts(
    opts.QuadMesh(logx=True, logz=True, xlim=(2, 256+64), ylim=(0, 1),
                  clim=(1e1, 1e5), cmap='inferno',
                  fontsize=dict(labels=8, ticks=7, legend=7, title=8),
                  ),
    opts.Layout(hspace=.25, vspace=.525)
)

# Make sure the output directory exists so a fresh checkout can Run All.
os.makedirs('fig', exist_ok=True)
hv.save(fig, 'fig/fig3_weight_errors.pdf')

fig
[17]: