Samples per feature dependence

How do statistical power, the estimated association strength, and the errors in weights, scores and loadings depend on sample size, measured as the number of samples per feature?

Setup

[1]:
import os

import numpy as np
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import mean_metric_curve

import matplotlib
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
[2]:
# True between-set correlations highlighted in the figures; len(rs) also fixes how many
# colours are sampled for r_clrs in the next cell.
rs = [0.1, 0.3, 0.5]
[3]:
# One colour per entry of `rs`, sampled from the colormap `cmap_r` (from my_config);
# r_clrs[i] is used for the label of rs[i] in the power panels.
r_clrs = hv.Palette(
    cmap_r,
    samples=len(rs),
).values
[4]:
# Outcome datasets, restricted to mode 0.
# The PLS outcomes are loaded under the 'axPlusay-2' tag (its stats below report ax+ay = -2).
ds_cca = load_outcomes('cca').sel({'mode': 0})
ds_pls = load_outcomes('pls', tag='axPlusay-2').sel({'mode': 0})

What’s in the outcome data files?

[5]:
# Summarise the CCA outcome grid (repetitions, samples per feature, r, px, covariance matrices per cell)
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[6]:
# Same summary for the PLS outcomes; note the zero entries (no covariance matrices) for some (px, r) cells
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0,  0,  0, 25, 25],
       [ 0,  0,  0, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated

Figure

[7]:
def format_assocStrength_axis(plot, element):
    """HoloViews plot hook that sets log-scale ticks on the y-axis of association-strength panels.

    Minor ticks at 2..9 times each power of ten are drawn without labels; major ticks use
    ``LogLocator(subs=(1,))`` and are labelled with ``LogFormatterSciNotation``.

    Parameters
    ----------
    plot : HoloViews matplotlib plot
        Its ``handles['axis']`` (a matplotlib Axes) is modified in place.
    element : HoloViews element
        Unused; required by the HoloViews hook signature.
    """
    ticker = matplotlib.ticker
    y_axis = plot.handles['axis'].yaxis

    # minor ticks: unlabelled, at 2..9 x each decade
    y_axis.set_minor_formatter(ticker.NullFormatter())
    y_axis.set_minor_locator(ticker.LogLocator(subs=(2, 3, 4, 5, 6, 7, 8, 9)))

    # major ticks: decades, scientific-notation labels
    y_axis.set_major_formatter(ticker.LogFormatterSciNotation())
    y_axis.set_major_locator(ticker.LogLocator(subs=(1,)))


# Mean true PLS association strength (covariance) for each highlighted r_true, averaged
# over px and Sigma_id; drawn as dashed reference lines in the PLS association panel.
pls_true_assocs = {
    r: float(ds_pls.between_assocs_true.sel(r=r).mean('px').mean('Sigma_id'))
    for r in (0.1, 0.3, 0.5)
}

fig = (
    # --- power ---
    # the 'typical' labels sit at y=1.1, i.e. above the axes (ylim tops out at 1.01)
    (
        mean_metric_curve(ds_cca.power, ylabel='Power')
        * hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
        * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')

    ).opts(xlabel='', ylim=(None, 1.01), hooks=[suptitle_cca])
    + (
        mean_metric_curve(ds_pls.power, ylabel='Power')
        * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')
    ).opts(xlabel='', ylabel='', ylim=(None, 1.01), hooks=[suptitle_pls])
    # --- association strength ---
    # dashed lines mark the true values; the 'true' label positions are hand-placed next to them
    + (
        mean_metric_curve(ds_cca.between_assocs, ylabel='Association strength')
        * hv.HLine(.1)
        * hv.HLine(.3)
        * hv.HLine(.5)
        * (
            hv.Text(10, .105, 'true', fontsize=7, halign='left', valign='bottom')
            * hv.Text(300, .29, 'true', fontsize=7, halign='right', valign='top')
            * hv.Text(300, .48, 'true', fontsize=7, halign='right', valign='top')
        ).opts(opts.Text(color=hv.Palette(cmap_r)))
        * hv.Text(3, 1.1, 'corr. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', yticks=5, ylim=(.09, 1.075), hooks=[format_assocStrength_axis])
    + (
        # the inner ylabel is overridden by ylabel='' on the overlay below
        mean_metric_curve(ds_pls.between_assocs, ylabel='Assoc strength2')
        * hv.HLine(pls_true_assocs[0.1])
        * hv.HLine(pls_true_assocs[0.3])
        * hv.HLine(pls_true_assocs[0.5])
        * (
            hv.Text(10, .055, 'true', fontsize=7, halign='left', valign='bottom')
            * hv.Text(10, .108, 'true', fontsize=7, halign='left', valign='top')
            * hv.Text(300, .19, 'true', fontsize=7, halign='right', valign='top')
        ).opts(opts.Text(color=hv.Palette(cmap_r)))
        * hv.Text(3, 1.1, 'cov. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, 1.075), hooks=[format_assocStrength_axis])
    # --- other error metrics ---
    + mean_metric_curve(mk_weightError(ds_cca), ylabel='Weight error').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_weightError(ds_pls), ylabel='Weight error').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_cca), ylabel='Score error').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_pls), ylabel='Score error').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_cca), ylabel='Loading error').opts(ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_pls), ylabel='Loading error').opts(ylabel='', ylim=(None, 1))
).redim(
    n_per_ftr='Samples per feature'
).cols(
    2
).opts(*fig_opts).opts(
    opts.Curve(color=hv.Palette(cmap_r)),
    opts.HLine(color=hv.Palette(cmap_r), linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
    opts.Layout(hspace=.35),
)

# make sure the output directory exists (e.g. on a fresh checkout) before saving
os.makedirs('fig', exist_ok=True)
hv.save(fig, 'fig/fig2_samples_per_feature_dependence.pdf')

fig
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[7]:

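Alternative view

The same quantities, with dashed reference lines at a power of 0.9 and at an error of 0.1. In this view, the association-strength row shows the output of `mk_betweenAssocRelError` (the relative error of the between-set association) instead of the association strength itself.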

[8]:
def error_panel(metric, metric_label, **curve_opts):
    """Mean error curve vs. samples per feature, overlaid with a dashed reference line at 0.1.

    Parameters
    ----------
    metric
        Passed straight to ``mean_metric_curve`` (e.g. the result of ``mk_weightError``).
    metric_label : str
        y-axis label given to ``mean_metric_curve``.
    **curve_opts
        Extra plot options (e.g. ``xlabel=''``, ``ylabel=''``), applied to the curve
        together with ``ylim=(None, 1)``.

    Returns
    -------
    The curve (drawn with ``n_per_ftr_typical=None``, like every panel of this figure)
    multiplied by ``hv.HLine(.1)``.
    """
    curve = mean_metric_curve(metric, ylabel=metric_label, n_per_ftr_typical=None)
    return curve.opts(ylim=(None, 1), **curve_opts) * hv.HLine(.1)


fig = (
    # --- power (dashed line at 0.9) ---
    (
        mean_metric_curve(ds_cca.power, ylabel='Power', n_per_ftr_typical=None)
        * hv.HLine(.9)
        * hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])

    ).relabel('CCA').opts(xlabel='')
    + (
        mean_metric_curve(ds_pls.power, ylabel='Power', n_per_ftr_typical=None).relabel('PLS').opts(xlabel='', ylabel='')
        * hv.HLine(.9)
    )
    # --- association strength relative error (dashed line at 0.1) ---
    # NOTE(review): the ylabel and lower ylim look copied from the absolute association-strength
    # panels of the main figure, although the plotted metric is mk_betweenAssocRelError — confirm.
    + (
        # n_per_ftr_typical=None for consistency with every other panel in this figure
        mean_metric_curve(mk_betweenAssocRelError(ds_cca), ylabel='Association strength', n_per_ftr_typical=None)
        * hv.HLine(.1)
    ).opts(xlabel='', yticks=5, ylim=(.09, None))
    + (
        mean_metric_curve(mk_betweenAssocRelError(ds_pls), ylabel='Assoc strength2', n_per_ftr_typical=None)
        * hv.HLine(0.1)
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, None))
    # --- weight, score and loading errors (dashed line at 0.1) ---
    + error_panel(mk_weightError(ds_cca), 'Weight error', xlabel='')
    + error_panel(mk_weightError(ds_pls), 'Weight error', xlabel='', ylabel='')
    + error_panel(mk_scoreError(ds_cca), 'Score error', xlabel='')
    + error_panel(mk_scoreError(ds_pls), 'Score error', xlabel='', ylabel='')
    + error_panel(mk_loadingError(ds_cca), 'Loading error')
    + error_panel(mk_loadingError(ds_pls), 'Loading error', ylabel='')
).redim(
    n_per_ftr='Samples per feature'
).cols(
    2
).opts(*fig_opts).opts(
    opts.Curve(color=hv.Palette(cmap_r)),
    opts.HLine(color='black', linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
    opts.Layout(hspace=.35),
)

# make sure the output directory exists (e.g. on a fresh checkout) before saving
os.makedirs('fig', exist_ok=True)
hv.save(fig, 'fig/fig2alternative_samples_per_feature_dependence.png')

fig
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[8]: