Dependence on the number of samples per feature

How do statistical power, association strength, and the errors in weights, scores, and loadings depend on sample size, measured as the number of samples per feature?

Setup

[1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import mean_metric_curve
from gemmr.util import subset_ds

import matplotlib
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
[3]:
# True between-set association strengths; one plotting colour per entry is sampled below
rs = [0.1, 0.3, 0.5]
[4]:
# One colour per entry of ``rs`` (r_clrs[i] <-> rs[i]), sampled from ``cmap_r``
r_clrs = hv.Palette(
    cmap_r,
    samples=len(rs),
).values
[5]:
# Load the CCA and PLS parameter sweeps (in that order), keeping only mode 0
ds_cca, ds_pls = (
    load_outcomes(outcome_name).sel(mode=0)
    for outcome_name in (
        'sweep_cca_cca_random_sum+-2+-2',
        'sweep_pls_pls_random_sum+-2+-2',
    )
)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[6]:
# Restrict both sweeps to px < 128
ds_cca, ds_pls = [ds.sel(px=ds.px < 128) for ds in (ds_cca, ds_pls)]
[7]:
# Subset both sweeps with n_keep=25 (the stats below report 25 Sigmas per
# (px, r) cell, so presumably at most 25 covariance matrices are kept per cell)
ds_cca, ds_pls = [subset_ds(ds, n_keep=25) for ds in (ds_cca, ds_pls)]

What’s in the outcome data files?

[8]:
# Summarise the sweep grid (n_rep, n_per_ftr, r, px) and the number of
# covariance matrices per (px, r) cell of the CCA outcomes
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[9]:
# Same summary for the PLS outcomes; note the empty (0) cells for small r at large px
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated

Figure 2: dependence on the number of samples per feature

[10]:
def mean_metric_curve(metric, rs=(.1, .5), n_per_ftr_typical=5,
                      ylabel=None, qs=(.025, .975), plotted_data=None, panel_label=None):
    """Plot ensemble-mean curves with quantile bands vs. ``n_per_ftr``, one per ``r``.

    This notebook-local definition shadows ``gemmr.plot.mean_metric_curve``
    imported in the setup cell; all calls in this notebook use this version.

    Parameters
    ----------
    metric : xr.DataArray
        must have dimensions ``r`` and ``n_per_ftr``. All other dimensions are
        stacked into a single ensemble dimension over which the mean and the
        quantiles are computed. If there are no other dimensions, the ensemble
        has a single member and the quantile band collapses onto the mean.
    rs : tuple-like
        separate curves are plotted for each entry of ``rs``
    n_per_ftr_typical : int or None
        if not ``None``, a vertical line is plotted at this value
    ylabel : str or None
        if given, ``metric`` is renamed to it, i.e. it becomes the name of the
        plotted value dimension
    qs : tuple of 2 floats
        lower and upper quantile (fractions in [0, 1]) delimiting the shaded area
    plotted_data : dict or None
        if a dict, a ``pandas.DataFrame`` holding the plotted mean and quantile
        curves is stored in it under key ``panel_label``. Its index is
        ``n_per_ftr`` and its columns form a 2-level index (``r_true``, ``curve``).
    panel_label : hashable or None
        key under which the data are stored in ``plotted_data``

    Returns
    -------
    panel : hv.Overlay
    """
    ensemble_dim = 'DUMMYDIM'

    other_dims = [d for d in metric.dims if d not in ('r', 'n_per_ftr')]
    if other_dims:
        print('INFO: Ensemble dims: '
              + ''.join(f'n_{dim}={len(metric[dim])}; ' for dim in other_dims))
        metric = metric.stack({ensemble_dim: other_dims})
    else:
        # Nothing to pool over: add a length-1 ensemble dimension so that the
        # mean/quantile reductions over ``ensemble_dim`` below are well-defined
        metric = metric.expand_dims(ensemble_dim)

    if ylabel is not None:
        metric = metric.rename(ylabel)

    panel = hv.Overlay()
    curves_per_r = {}
    for r in rs:
        # Drop sample sizes for which any ensemble member is missing
        metric_r = metric.sel(r=r).dropna('n_per_ftr', how='any')

        # Reduce once; reused for both the plot and the exported data
        mean_curve = metric_r.mean(ensemble_dim)
        lower = metric_r.quantile(qs[0], ensemble_dim)
        upper = metric_r.quantile(qs[1], ensemble_dim)

        panel *= hv.Curve(mean_curve)
        panel *= hv.Area((metric_r.n_per_ftr, lower, upper), vdims=['y', 'y2'])

        if plotted_data is not None:
            # reset_coords(drop=True) discards leftover scalar coords (``r``,
            # ``quantile``, ...); each series is indexed by ``n_per_ftr`` only
            curves_per_r[r] = pd.concat(
                [
                    mean_curve.reset_coords(drop=True).rename('mean').to_series(),
                    lower.reset_coords(drop=True).rename(f'{100*qs[0]:.1f}%').to_series(),
                    upper.reset_coords(drop=True).rename(f'{100*qs[1]:.1f}%').to_series(),
                ],
                axis=1
            )

    # Guard: pd.concat raises on an empty mapping (e.g. empty ``rs``)
    if plotted_data is not None and curves_per_r:
        df = pd.concat(curves_per_r, axis=1)
        df.columns.names = ['r_true', 'curve']
        plotted_data[panel_label] = df

    if n_per_ftr_typical is not None:
        panel *= hv.VLine(n_per_ftr_typical)

    return panel
[11]:
def format_assocStrength_axis(plot, element, enabled=False):
    """Holoviews (matplotlib backend) hook for the y-axis of association-strength panels.

    When enabled, it installs log-scale ticks on the y-axis:
    - major ticks at powers of ten, formatted in scientific notation;
    - unlabeled minor ticks at 2-9 times those.

    The figure uses linear axes (``opts.Overlay(logy=False)``), for which
    these log-scale ticks are not appropriate. By default the hook is
    therefore a no-op.

    Parameters
    ----------
    plot : holoviews plot
        its ``handles['axis']`` is the matplotlib axes of the panel
    element : holoviews element
        unused; part of the ``(plot, element)`` hook call signature
    enabled : bool, default False
        whether to apply the log-scale tick formatting
    """
    if not enabled:
        return
    yax = plot.handles['axis'].yaxis
    yax.set_minor_formatter(matplotlib.ticker.NullFormatter())
    yax.set_minor_locator(matplotlib.ticker.LogLocator(subs=(2,3,4,5,6,7,8,9)))
    yax.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation())
    yax.set_major_locator(matplotlib.ticker.LogLocator(subs=(1,)))


# Figure 2: 2-column grid (CCA left, PLS right) of power, association strength,
# weight error, score error and loading error vs. samples per feature.
# Only px=8 is plotted; use slice(None) instead to pool over all px values
# (px is presumably the number of features per set; these sweeps have py == px).
px_plot = [8]  # slice(None)
# Collects the plotted mean/quantile curves of every panel, keyed by
# panel_label; written out with save_source_data at the end of this cell.
plotted_data = {}

fig = (
    # --- power ---
    (
        mean_metric_curve(ds_cca.sel(px=px_plot).power, ylabel='Power', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='CCA_power')
        # Legend-like labels; r_clrs[i] is the colour sampled for rs[i]
        * hv.Text(300, 0.4, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.3, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        # r_true=0.3 is not plotted (mean_metric_curve's default rs=(.1, .5)):
        # * hv.Text(300, 0.2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
        # * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')

    ).opts(xlabel='', ylim=(None, 1.01), hooks=[suptitle_cca])
    + (
        mean_metric_curve(ds_pls.sel(px=px_plot).power, ylabel='Power', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='PLS_power')
        # * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')
    ).opts(xlabel='', ylabel='', ylim=(None, 1.01), hooks=[suptitle_pls])
    # --- association strength ---
    + (
        mean_metric_curve(ds_cca.sel(px=px_plot).between_assocs, ylabel='Association strength', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='CCA_assocStrength')
        # Dashed reference lines at the true values, labelled 'true' (r_true = 0.1, 0.5)
        * hv.HLine(.1)
        # * hv.HLine(.3)
        * hv.HLine(.5)
        * (
            hv.Text(10, .095, 'true', fontsize=7, halign='left', valign='top')
            # * hv.Text(300, .29, 'true', fontsize=7, halign='right', valign='top')
            * hv.Text(300, .48, 'true', fontsize=7, halign='right', valign='top')
        ).opts(opts.Text(color=hv.Palette(cmap_r)))
        # Drawn at y=.8, above ylim=(0, .75) -- presumably a label above the axes
        * hv.Text(3, .8, 'corr. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', yticks=5, ylim=(.0, .75), hooks=[format_assocStrength_axis])
    + (
        # NOTE(review): value dimension is named 'Between-set correlation', but this
        # panel is labelled 'cov.' below; the name is hidden by ylabel='' -- confirm.
        mean_metric_curve(ds_pls.sel(px=px_plot).between_assocs, ylabel='Between-set correlation', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='PLS_assocStrength')
        # For PLS the reference value is not r_true itself: it is the mean of
        # between_assocs_true over the plotted px and all Sigma_id
        * hv.HLine(float(ds_pls.sel(px=px_plot, r=0.1).between_assocs_true.mean('px').mean('Sigma_id')))
        * hv.HLine(float(ds_pls.sel(px=px_plot, r=0.5).between_assocs_true.mean('px').mean('Sigma_id')))
        * (
            hv.Text(10, .035, 'true', fontsize=7, halign='left', valign='top')
            # * hv.Text(10, .108, 'true', fontsize=7, halign='left', valign='top')
            * hv.Text(290, .17, 'true', fontsize=7, halign='right', valign='top')
        ).opts(opts.Text(color=hv.Palette(cmap_r)))
        # Drawn at y=.47, above ylim=(0, .45) -- presumably a label above the axes
        * hv.Text(3, .47, 'cov. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(0, .45), hooks=[format_assocStrength_axis])
    # --- other error metrics ---
    # (mk_*Error helpers come from the star imports, presumably gemmr.metrics)
    + mean_metric_curve(mk_weightError(ds_cca.sel(px=px_plot)), ylabel='Weight error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='CCA_weightError').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_weightError(ds_pls.sel(px=px_plot)), ylabel='Weight error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='PLS_weightError').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_cca.sel(px=px_plot)), ylabel='Score error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='CCA_scoreError').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_pls.sel(px=px_plot)), ylabel='Score error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='PLS_scoreError').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_cca.sel(px=px_plot)), ylabel='Loading error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='CCA_loadingError').opts(ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_pls.sel(px=px_plot)), ylabel='Loading error', n_per_ftr_typical=None, plotted_data=plotted_data, panel_label='PLS_loadingError').opts(ylabel='', ylim=(None, 1))
).redim(
    n_per_ftr='Samples per feature'
).cols(
    2
).opts(*fig_opts).opts(
    # Colours of curves, bands and reference lines cycle through the cmap_r palette
    opts.Curve(color=hv.Palette(cmap_r)),
    opts.Area(linewidth=0, alpha=.3, color=hv.Palette(cmap_r)),
    opts.HLine(color=hv.Palette(cmap_r), linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    opts.Overlay(logx=False, logy=False, xlim=(3, 300), sublabel_position=(-.4, .95)),
    opts.Layout(hspace=.35),
)

# Save the figure and the plotted curves (fig_opts, cmap_r, suptitle_* and
# save_source_data presumably come from the my_config star import)
hv.save(fig, 'fig/fig2_samples_per_feature_dependence.pdf')
save_source_data(plotted_data, 'fig2')

fig
INFO: Ensemble dims: n_px=1; n_Sigma_id=25;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
[11]:

Relative errors in the estimated association strength

[12]:
# Supplementary figure: relative errors of the estimated association strength
# vs. samples per feature, for px=8 only. No source data are collected here
# (no plotted_data argument is passed to mean_metric_curve).
px = [8]

fig = (
    # --- Association strength
    (
        mean_metric_curve(mk_betweenAssocRelError(ds_cca.sel(px=px)), ylabel='Relative association\nstrength error', n_per_ftr_typical=None)
        # Dashed black reference line at 0.1 (presumably a 10% relative-error level)
        * hv.HLine(.1)
        # Legend-like labels; r_clrs[i] is the colour sampled for rs[i]
        * hv.Text(300, 8, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 6, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        # * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 4, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
    ).opts(xlabel='Samples / feature', yticks=5, ylim=(.09, None), title='CCA\nbetween-set correlation')
    + (
        # 'Assoc strength2' / 'Assoc strength3' are placeholder value names; the
        # displayed y-labels are blanked via ylabel=''
        mean_metric_curve(mk_betweenCorrRelError(ds_pls.sel(px=px)), ylabel='Assoc strength2', n_per_ftr_typical=None)
        * hv.HLine(0.1)
    ).opts(xlabel='Samples / feature', ylabel='', yticks=5, ylim=(5e-2, None), title='PLS\nbetween-set correlation')
    + (
        mean_metric_curve(mk_betweenAssocRelError(ds_pls.sel(px=px)), ylabel='Assoc strength3', n_per_ftr_typical=None)
        * hv.HLine(0.1)
    ).opts(xlabel='Samples / feature', ylabel='', yticks=5, ylim=(5e-2, None), title='PLS\nbetween-set covariance')
).redim(
    n_per_ftr='Samples per feature'
).cols(
    3
).opts(*fig_opts).opts(
    opts.Curve(color=hv.Palette(cmap_r)),
    opts.Area(color=hv.Palette(cmap_r), linewidth=0.5, alpha=.5),
    opts.HLine(color='black', linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    # NOTE(review): this layout-wide ylim=(-1, 12) presumably overrides the
    # per-panel ylim values set above -- confirm which limits are intended.
    opts.Overlay(logx=False, logy=False, xlim=(3, 300), ylim=(-1, 12), sublabel_position=(-.4, .95)),
    opts.Layout(hspace=.35, fig_inches=(3*1.7, None)),
)

hv.save(fig, 'fig/figS_assoc_strength_error_converges_to_0.pdf')

fig
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
INFO: Ensemble dims: n_px=1; n_Sigma_id=25; n_rep=100;
[12]: