Samples per feature dependence

How do power, association strength, and the errors in weights, scores and loadings depend on the sample size (expressed as samples per feature)?


import numpy as np
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit

from import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import mean_metric_curve

import matplotlib
import holoviews as hv
from holoviews import opts

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
# Values of r highlighted in the figures below; they correspond to the
# r_true legend entries 0.1 / 0.3 / 0.5 drawn with hv.Text in the power panel.
rs = [.1, .3, .5]
# One colour per entry of `rs`. `cmap_r` is not defined in the visible part of
# this file — presumably it comes from `from my_config import *` (TODO confirm).
r_clrs = hv.Palette(cmap_r, samples=len(rs)).values
# Outcome datasets for CCA and PLS, each restricted to mode 0.
# NOTE(review): `load_outcomes` is named in the `from import ...` line in the
# import section, whose module name is missing — restore it before running.
ds_cca = load_outcomes('cca').sel(mode=0)
# The 'axPlusay-2' tag presumably selects the runs with ax+ay = -2 (the stats
# dump shown further down reports an ax+ay range of (-2.00, -2.00) for one of
# the datasets) — TODO confirm.
ds_pls = load_outcomes('pls', tag='axPlusay-2').sel(mode=0)

What’s in the outcome data files?

n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25]])
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0,  0,  0, 25, 25],
       [ 0,  0,  0, 25, 25]])
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated


def format_assocStrength_axis(plot, element):
    """Plot hook attached to the association-strength panels.

    Passed via ``hooks=[format_assocStrength_axis]`` to the two
    association-strength overlays in this file. As written, it only looks up
    the y-axis object and implicitly returns ``None``; nothing is modified.

    NOTE(review): the body looks truncated — presumably some tick or label
    formatting was applied to ``yax`` originally. TODO confirm against the
    original notebook.

    Parameters
    ----------
    plot
        Object exposing ``handles['axis']``; presumably the HoloViews
        matplotlib plot instance handed to hooks — TODO confirm.
    element
        Second hook argument; unused here.
    """
    # y-axis of the panel's axis handle; not used after this assignment
    yax = plot.handles['axis'].yaxis

# Main figure: a sum of panels alternating between the CCA and PLS datasets —
# power, association strength, then weight / score / loading errors — each
# drawn with mean_metric_curve.
# NOTE(review): this cell does not parse as written. Several delimiters and the
# call that receives the output path on the last line are missing (e.g. no
# opening "(" for the first panel, no ")" closing the groups of 'true' labels,
# and nothing connecting the `n_per_ftr=...` line to the panels above it or to
# the opts.* entries below it). Restore these from the original notebook
# rather than guessing.
fig = (
    # --- power ---
        mean_metric_curve(ds_cca.power, ylabel='Power')
        # hand-placed legend: one text label per r value, coloured via r_clrs
        * hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
        # 'typical' label at x = n_per_ftr_typical; that name is not defined in
        # the visible part of this file (presumably a wildcard import — TODO confirm)
        * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')

    ).opts(xlabel='', ylim=(None, 1.01), hooks=[suptitle_cca])
    + (
        mean_metric_curve(ds_pls.power, ylabel='Power')
        * hv.Text(n_per_ftr_typical, 1.1, '  typical', fontsize=7, halign='center', valign='bottom')
    ).opts(xlabel='', ylabel='', ylim=(None, 1.01), hooks=[suptitle_pls])
    # --- association strength ---
    + (
        mean_metric_curve(ds_cca.between_assocs, ylabel='Association strength')
        # horizontal reference lines at 0.1 / 0.3 / 0.5, each labelled 'true' below
        * hv.HLine(.1)
        * hv.HLine(.3)
        * hv.HLine(.5)
        * (
            hv.Text(10, .105, 'true', fontsize=7, halign='left', valign='bottom')
            * hv.Text(300, .29, 'true', fontsize=7, halign='right', valign='top')
            * hv.Text(300, .48, 'true', fontsize=7, halign='right', valign='top')
        # NOTE(review): the ")" closing the label group above appears to be missing
        * hv.Text(3, 1.1, 'corr. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', yticks=5, ylim=(.09, 1.075), hooks=[format_assocStrength_axis])
    + (
        # NOTE(review): 'Assoc strength2' differs from the CCA panel's label and
        # is blanked by ylabel='' in the .opts() below; presumably the distinct
        # name is deliberate (e.g. to keep the two y-axes independent) — TODO confirm
        mean_metric_curve(ds_pls.between_assocs, ylabel='Assoc strength2')
        # reference lines at between_assocs_true for r = 0.1 / 0.3 / 0.5,
        # averaged over px and Sigma_id
        * hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.1).mean('px').mean('Sigma_id')))
        * hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.3).mean('px').mean('Sigma_id')))
        * hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.5).mean('px').mean('Sigma_id')))
        * (
            hv.Text(10, .055, 'true', fontsize=7, halign='left', valign='bottom')
            * hv.Text(10, .108, 'true', fontsize=7, halign='left', valign='top')
            * hv.Text(300, .19, 'true', fontsize=7, halign='right', valign='top')
        # NOTE(review): the ")" closing the label group above appears to be missing
        * hv.Text(3, 1.1, 'cov. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, 1.075), hooks=[format_assocStrength_axis])
    # --- other error metrics ---
    + mean_metric_curve(mk_weightError(ds_cca), ylabel='Weight error').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_weightError(ds_pls), ylabel='Weight error').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_cca), ylabel='Score error').opts(xlabel='', ylim=(None, 1))
    + mean_metric_curve(mk_scoreError(ds_pls), ylabel='Score error').opts(xlabel='', ylabel='', ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_cca), ylabel='Loading error').opts(ylim=(None, 1))
    + mean_metric_curve(mk_loadingError(ds_pls), ylabel='Loading error').opts(ylabel='', ylim=(None, 1))
    # NOTE(review): the line(s) that should wrap the keyword below (presumably a
    # relabelling of the n_per_ftr dimension) and open the options call are missing
    n_per_ftr='Samples per feature'
    # shared styling: dashed reference lines, log-log axes limited to x in (3, 300)
    opts.HLine(color=hv.Palette(cmap_r), linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
), 'fig/fig2_samples_per_feature_dependence.pdf')

/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/ RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)

Alternative view

# Alternative figure: the same CCA / PLS panel pairs as the main figure, but
# with single black reference lines (0.9 on the power panels, 0.1 elsewhere)
# and the association-strength panels fed by mk_betweenAssocRelError.
# NOTE(review): this cell does not parse as written. Most panel groups opened
# with "(" are never closed, and the call that receives the output path on the
# last line is missing. Restore these from the original notebook rather than
# guessing.
fig = (
        mean_metric_curve(ds_cca.power, ylabel='Power', n_per_ftr_typical=None)
        # reference line at power = 0.9
        * hv.HLine(.9)
        # hand-placed legend: one text label per r value, coloured via r_clrs
        * hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])

    + (
        mean_metric_curve(ds_pls.power, ylabel='Power', n_per_ftr_typical=None).relabel('PLS').opts(xlabel='', ylabel='')
        * hv.HLine(.9)
    + (
        # NOTE(review): plots the output of mk_betweenAssocRelError but keeps the
        # 'Association strength' label — confirm the label is still appropriate
        mean_metric_curve(mk_betweenAssocRelError(ds_cca), ylabel='Association strength')
        * hv.HLine(.1)
    ).opts(xlabel='', yticks=5, ylim=(.09, None))
    + (
        mean_metric_curve(mk_betweenAssocRelError(ds_pls), ylabel='Assoc strength2', n_per_ftr_typical=None)
        * hv.HLine(0.1)
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, None))
    # error panels: each gets a reference line at 0.1
    + (
        mean_metric_curve(mk_weightError(ds_cca), ylabel='Weight error', n_per_ftr_typical=None).opts(xlabel='', ylim=(None, 1))
        * hv.HLine(.1)
    + (
        mean_metric_curve(mk_weightError(ds_pls), ylabel='Weight error', n_per_ftr_typical=None).opts(xlabel='', ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    + (
        mean_metric_curve(mk_scoreError(ds_cca), ylabel='Score error', n_per_ftr_typical=None).opts(xlabel='', ylim=(None, 1))
        * hv.HLine(.1)
    + (
        mean_metric_curve(mk_scoreError(ds_pls), ylabel='Score error', n_per_ftr_typical=None).opts(xlabel='', ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    + (
        mean_metric_curve(mk_loadingError(ds_cca), ylabel='Loading error', n_per_ftr_typical=None).opts(ylim=(None, 1))
        * hv.HLine(.1)
    + (
        mean_metric_curve(mk_loadingError(ds_pls), ylabel='Loading error', n_per_ftr_typical=None).opts(ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    # NOTE(review): the line(s) that should wrap the keyword below (presumably a
    # relabelling of the n_per_ftr dimension) and open the options call are missing
    n_per_ftr='Samples per feature'
    # shared styling: dashed black reference lines, log-log axes limited to x in (3, 300)
    opts.HLine(color='black', linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
), 'fig/fig2alternative_samples_per_feature_dependence.png')

/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/ RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)