Samples per feature dependence¶
How do power, association strength, and the errors in weights, scores, and loadings depend on the number of samples per feature?
Setup¶
[1]:
import numpy as np
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform
from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import mean_metric_curve
import matplotlib
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)
from my_config import *
import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
[2]:
# r_true values that receive their own colour / legend entry in the figures below
rs = [0.1, 0.3, 0.5]
[3]:
# One colour per entry of `rs`, sampled from the `cmap_r` colormap
# (`cmap_r` is not defined in this notebook — presumably it comes from `my_config`).
r_clrs = hv.Palette(
    cmap_r,
    samples=len(rs),
).values
[4]:
# Load the stored outcome datasets for CCA and PLS (the PLS one is identified
# by the tag 'axPlusay-2') and select mode=0 from each.
ds_cca, ds_pls = (
    outcomes.sel(mode=0)
    for outcomes in (
        load_outcomes('cca'),
        load_outcomes('pls', tag='axPlusay-2'),
    )
)
What’s in the outcome data files?
[5]:
# Print a summary of the CCA outcome dataset: the coordinate values and the
# number of covariance matrices (`n_Sigmas`) available per (px, r) combination.
print_ds_stats(ds_cca)
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32 64 128]
ax+ay range (0.00, 0.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32 64 128
power calculated
[6]:
# Print the same summary for the PLS outcome dataset; compare `n_per_ftr` and
# the `n_Sigmas` table with the CCA output above to see where data are missing.
print_ds_stats(ds_pls)
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32 64 128]
ax+ay range (-2.00, -2.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 0, 0, 25, 25],
[ 0, 0, 0, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32 64 128
power calculated
Figure¶
[7]:
def format_assocStrength_axis(plot, element):
    """HoloViews plot hook that reformats the y-axis ticks of a log-scaled panel.

    Major ticks are located with ``LogLocator(subs=(1,))`` and labelled in
    scientific notation; minor ticks are located at the multiples 2..9 and
    left unlabelled.

    Parameters
    ----------
    plot
        The HoloViews matplotlib plot object; its matplotlib axes are read
        from ``plot.handles['axis']``.
    element
        The element being plotted. Unused, but part of the hook signature
        (the function is passed via ``hooks=[...]``).

    Returns
    -------
    None
        The y-axis is modified in place.
    """
    # Import the submodule explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.ticker`` is reachable as an attribute; it
    # only works when some other package has already imported it.
    from matplotlib import ticker

    yax = plot.handles['axis'].yaxis
    # Minor ticks: drawn at 2..9 times each power of the base, without labels.
    yax.set_minor_formatter(ticker.NullFormatter())
    yax.set_minor_locator(ticker.LogLocator(subs=(2, 3, 4, 5, 6, 7, 8, 9)))
    # Major ticks: only at integer powers of the base, labelled as 10^k.
    yax.set_major_formatter(ticker.LogFormatterSciNotation())
    yax.set_major_locator(ticker.LogLocator(subs=(1,)))
# Main figure: a layout with 2 columns (see `.cols(2)` below). Panels are added
# alternately for CCA and PLS, so CCA ends up in the left column and PLS in the
# right one, with one row per metric.
fig = (
# --- power ---
(
mean_metric_curve(ds_cca.power, ylabel='Power')
# Hand-placed colour legend for r_true. The y positions are spaced by a
# constant factor (0.7) because the y-axis is logarithmic (logy=True below).
* hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
* hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
* hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
* hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
# Label at x = n_per_ftr_typical, y = 1.1, i.e. above the upper y-limit (1.01) set below.
* hv.Text(n_per_ftr_typical, 1.1, ' typical', fontsize=7, halign='center', valign='bottom')
# `suptitle_cca` is not defined in this notebook — presumably provided by `my_config`.
).opts(xlabel='', ylim=(None, 1.01), hooks=[suptitle_cca])
+ (
mean_metric_curve(ds_pls.power, ylabel='Power')
* hv.Text(n_per_ftr_typical, 1.1, ' typical', fontsize=7, halign='center', valign='bottom')
# `suptitle_pls` is not defined in this notebook — presumably provided by `my_config`.
).opts(xlabel='', ylabel='', ylim=(None, 1.01), hooks=[suptitle_pls])
# --- association strength ---
+ (
mean_metric_curve(ds_cca.between_assocs, ylabel='Association strength')
# Dashed reference lines at the r_true values 0.1, 0.3 and 0.5.
* hv.HLine(.1)
* hv.HLine(.3)
* hv.HLine(.5)
* (
# 'true' labels placed next to each reference line; colours are taken from the cmap_r palette.
hv.Text(10, .105, 'true', fontsize=7, halign='left', valign='bottom')
* hv.Text(300, .29, 'true', fontsize=7, halign='right', valign='top')
* hv.Text(300, .48, 'true', fontsize=7, halign='right', valign='top')
).opts(opts.Text(color=hv.Palette(cmap_r)))
* hv.Text(3, 1.1, 'corr. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
).opts(xlabel='', yticks=5, ylim=(.09, 1.075), hooks=[format_assocStrength_axis])
+ (
# NOTE(review): the label 'Assoc strength2' is never shown (ylabel='' below);
# presumably it only serves to give this panel a y dimension distinct from the
# CCA panel's — confirm before renaming.
mean_metric_curve(ds_pls.between_assocs, ylabel='Assoc strength2')
# Reference lines: the true association strength for r = 0.1, 0.3, 0.5,
# averaged over `px` and `Sigma_id`.
* hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.1).mean('px').mean('Sigma_id')))
* hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.3).mean('px').mean('Sigma_id')))
* hv.HLine(float(ds_pls.between_assocs_true.sel(r=0.5).mean('px').mean('Sigma_id')))
* (
# Hand-tuned positions of the 'true' labels for the PLS reference lines.
hv.Text(10, .055, 'true', fontsize=7, halign='left', valign='bottom')
* hv.Text(10, .108, 'true', fontsize=7, halign='left', valign='top')
* hv.Text(300, .19, 'true', fontsize=7, halign='right', valign='top')
).opts(opts.Text(color=hv.Palette(cmap_r)))
* hv.Text(3, 1.1, 'cov. ', fontsize=8, halign='center', valign='bottom').opts(color='black')
).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, 1.075), hooks=[format_assocStrength_axis])
# --- other error metrics ---
# Only the last row (loading error) keeps its x-label; PLS panels drop the y-label.
+ mean_metric_curve(mk_weightError(ds_cca), ylabel='Weight error').opts(xlabel='', ylim=(None, 1))
+ mean_metric_curve(mk_weightError(ds_pls), ylabel='Weight error').opts(xlabel='', ylabel='', ylim=(None, 1))
+ mean_metric_curve(mk_scoreError(ds_cca), ylabel='Score error').opts(xlabel='', ylim=(None, 1))
+ mean_metric_curve(mk_scoreError(ds_pls), ylabel='Score error').opts(xlabel='', ylabel='', ylim=(None, 1))
+ mean_metric_curve(mk_loadingError(ds_cca), ylabel='Loading error').opts(ylim=(None, 1))
+ mean_metric_curve(mk_loadingError(ds_pls), ylabel='Loading error').opts(ylabel='', ylim=(None, 1))
).redim(
# Relabel the `n_per_ftr` dimension for display.
n_per_ftr='Samples per feature'
).cols(
2
# `fig_opts` is not defined in this notebook — presumably provided by `my_config`.
).opts(*fig_opts).opts(
# Curves and reference lines are both coloured from the cmap_r palette.
opts.Curve(color=hv.Palette(cmap_r)),
opts.HLine(color=hv.Palette(cmap_r), linewidth=1, linestyle='--'),
opts.VLine(color='black', linestyle='--', linewidth=1),
# All panels are log-log and share the x-range 3..300.
opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
opts.Layout(hspace=.35),
)
# Write the figure to disk, then display it as the cell output.
hv.save(fig, 'fig/fig2_samples_per_feature_dependence.pdf')
fig
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[7]:
Alternative view
[8]:
# Alternative view of the figure: same 2-column CCA / PLS layout, but every
# panel carries a dashed black reference line (0.9 for power, 0.1 elsewhere),
# the association-strength row plots `mk_betweenAssocRelError(...)`, and
# `n_per_ftr_typical=None` is passed to every `mean_metric_curve` call.
fig = (
    # --- power ---
    (
        mean_metric_curve(ds_cca.power, ylabel='Power', n_per_ftr_typical=None)
        * hv.HLine(.9)
        # Hand-placed colour legend for r_true; positions are spaced by a
        # constant factor (0.7) because the y-axis is logarithmic.
        * hv.Text(300, 0.2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='top')
        * hv.Text(300, 0.2*(.7)**1, r'$0.5$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[2])
        * hv.Text(300, 0.2*(.7)**2, r'$0.3$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[1])
        * hv.Text(300, 0.2*(.7)**3, r'$0.1$', fontsize=7, halign='right', valign='top').opts(color=r_clrs[0])
    ).relabel('CCA').opts(xlabel='')
    + (
        mean_metric_curve(ds_pls.power, ylabel='Power', n_per_ftr_typical=None).relabel('PLS').opts(xlabel='', ylabel='')
        * hv.HLine(.9)
    )
    # --- association strength ---
    + (
        # `n_per_ftr_typical=None` is passed here like in every other panel of
        # this figure; previously this was the only call that omitted it (and
        # so presumably the only panel that still showed the "typical" marker).
        mean_metric_curve(mk_betweenAssocRelError(ds_cca), ylabel='Association strength', n_per_ftr_typical=None)
        * hv.HLine(.1)
    ).opts(xlabel='', yticks=5, ylim=(.09, None))
    + (
        # NOTE(review): 'Assoc strength2' is never shown (ylabel='' below);
        # presumably it only keeps this panel's y dimension distinct from the
        # CCA panel's — confirm before renaming.
        mean_metric_curve(mk_betweenAssocRelError(ds_pls), ylabel='Assoc strength2', n_per_ftr_typical=None)
        * hv.HLine(0.1)
    ).opts(xlabel='', ylabel='', yticks=5, ylim=(5e-2, None))
    # --- other error metrics ---
    + (
        mean_metric_curve(mk_weightError(ds_cca), ylabel='Weight error', n_per_ftr_typical=None).opts(xlabel='', ylim=(None, 1))
        * hv.HLine(.1)
    )
    + (
        mean_metric_curve(mk_weightError(ds_pls), ylabel='Weight error', n_per_ftr_typical=None).opts(xlabel='', ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    )
    + (
        mean_metric_curve(mk_scoreError(ds_cca), ylabel='Score error', n_per_ftr_typical=None).opts(xlabel='', ylim=(None, 1))
        * hv.HLine(.1)
    )
    + (
        mean_metric_curve(mk_scoreError(ds_pls), ylabel='Score error', n_per_ftr_typical=None).opts(xlabel='', ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    )
    + (
        mean_metric_curve(mk_loadingError(ds_cca), ylabel='Loading error', n_per_ftr_typical=None).opts(ylim=(None, 1))
        * hv.HLine(.1)
    )
    + (
        mean_metric_curve(mk_loadingError(ds_pls), ylabel='Loading error', n_per_ftr_typical=None).opts(ylabel='', ylim=(None, 1))
        * hv.HLine(.1)
    )
).redim(
    # Relabel the `n_per_ftr` dimension for display.
    n_per_ftr='Samples per feature'
).cols(
    2
).opts(*fig_opts).opts(
    opts.Curve(color=hv.Palette(cmap_r)),
    # Unlike the main figure, reference lines are plain black here.
    opts.HLine(color='black', linewidth=1, linestyle='--'),
    opts.VLine(color='black', linestyle='--', linewidth=1),
    # All panels are log-log and share the x-range 3..300.
    opts.Overlay(logx=True, logy=True, xlim=(3, 300), sublabel_position=(-.4, .95)),
    opts.Layout(hspace=.35),
)
# Write the figure to disk, then display it as the cell output.
hv.save(fig, 'fig/fig2alternative_samples_per_feature_dependence.png')
fig
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[8]: