Required sample sizes¶
How many samples are required to obtain stable CCA and PLS estimates?
Setup¶
[1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform
from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import heatmap_n_req
from gemmr.util import nPerFtr2n
import matplotlib
import holoviews as hv
from holoviews import opts
# Use the matplotlib backend for all HoloViews figures and render them at 120 dpi.
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)
from my_config import *
import warnings
from matplotlib import MatplotlibDeprecationWarning
# Globally silence matplotlib deprecation warnings.
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
# Silence only this specific UserWarning (matched by message text).
# NOTE: the NaN-related RuntimeWarnings seen in cell outputs below
# ("Mean of empty slice", "All-NaN slice encountered", ...) are NOT filtered here.
warnings.filterwarnings(
'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
) # holoviews emits this for log-linear plots
Load data¶
Load pre-calculated outcomes of analyzed synthetic data
[2]:
# Pre-computed outcome datasets for CCA and PLS.
ds_cca = load_outcomes('cca')
# The PLS outcomes are selected by tag; the stats printed below report
# ax+ay == -2 for this dataset, matching the 'axPlusay-2' tag.
ds_pls = load_outcomes('pls', tag='axPlusay-2')
What’s in the outcome data files?
[3]:
# Summarize the CCA outcome dataset (dimensions, coordinate values, n_Sigmas per (px, r)).
print_ds_stats(ds_cca)
n_modes 1
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32 64 128]
ax+ay range (0.00, 0.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32 64 128
power calculated
[4]:
# Summarize the PLS outcome dataset (dimensions, coordinate values, n_Sigmas per (px, r)).
print_ds_stats(ds_pls)
n_modes 1
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32 64 128]
ax+ay range (-2.00, -2.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 0, 0, 25, 25],
[ 0, 0, 0, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32 64 128
power calculated
Required sample sizes¶
Let’s calculate the required sample sizes, using the functionality from `gemmr` that is star-imported above. The function `calc_n_required_all_metrics` produces sample size estimates based on 5 metrics (power, relative error of between-set association strength, weight error, loading error, score error), as well as on a combination of these 5 metrics (i.e. the maximum across all 5 metrics).
[5]:
# Required sample size, searched along the 'n_per_ftr' dimension (i.e. expressed
# as samples per feature), for every metric plus the combined criterion.
# Only the first mode (mode=0) is analyzed.
# NOTE(review): the "Mean of empty slice" warning is presumably caused by the
# (px, r) cells with n_Sigmas == 0 shown in the dataset stats above.
cca_n_req_per_ftr = calc_n_required_all_metrics(ds_cca.sel(mode=0), search_dim='n_per_ftr')
pls_n_req_per_ftr = calc_n_required_all_metrics(ds_pls.sel(mode=0), search_dim='n_per_ftr')
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
Illustration of the required sample size based on the combined criterion
[6]:
# Side-by-side heatmaps of the required sample size (combined criterion):
# CCA on the left, PLS on the right (only the PLS panel shows a colorbar).
# nPerFtr2n presumably converts the per-feature requirement into a total sample
# size using each dataset's `py` coordinate — TODO confirm against gemmr.util.
# The QuadMesh/Layout options are applied after `*fig_opts` (from the config
# star-import, presumably) and therefore presumably override it.
fig_required_sample_size_heatmap = (
heatmap_n_req(nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py)).opts(hooks=[suptitle_cca])
+ heatmap_n_req(nPerFtr2n(pls_n_req_per_ftr['combined'], ds_pls.py)).opts(colorbar=True, hooks=[suptitle_pls])
).cols(
2
).opts(*fig_opts).opts(
opts.QuadMesh(logx=True, logz=True, #xlim=(2, 256+64), ylim=(0, 1),
clim=(1e1, 1e5), cmap='inferno',  # color scale fixed to 10..1e5 samples
sublabel_position=(-.45, .95),
fontsize=dict(labels=8, ticks=7, legend=7, title=8),
),
#opts.Overlay(),
opts.Layout(fig_inches=(3.42, None))
)
fig_required_sample_size_heatmap
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
[6]:
It’s somewhat hard to read off the numbers from the heatmaps, so we replot the required sample sizes for individual values of \(r_\mathrm{true}\).
[7]:
def plot_n_req(n_req, qs=(.025, .975), color=None, label='mean'):
    """Plot the mean required sample size with a quantile band around it.

    Parameters
    ----------
    n_req : xarray.DataArray
        Required sample sizes with a ``Sigma_id`` dimension (aggregated here)
        and a ``ptot`` coordinate (used as the x values).
    qs : sequence of float
        Quantiles computed across ``Sigma_id``; the band spans from ``qs[0]``
        to ``qs[-1]``.
    color : optional
        Color shared by the mean curve and the shaded band.
    label : str
        Legend label for the mean curve.

    Returns
    -------
    holoviews.Overlay
        Mean curve overlaid with the shaded quantile area, both axes logarithmic.
    """
    lower_q, upper_q = qs[0], qs[-1]
    band = n_req.quantile(qs, 'Sigma_id')

    mean_curve = hv.Curve((n_req.ptot, n_req.mean('Sigma_id')), label=label)
    spread = hv.Area(
        (band.ptot, band.sel(quantile=lower_q), band.sel(quantile=upper_q)),
        vdims=['y', 'y2'],
    )

    overlay = (mean_curve * spread).redim(
        x='Number of features',
        y='Required sample size',
    )
    return overlay.opts(
        opts.Curve(color=color),
        opts.Area(alpha=.25, color=color),
        opts.Overlay(logx=True, logy=True),
    )
[8]:
# One panel per r_true: required total sample size (combined criterion) vs.
# number of features, with CCA and PLS overlaid.
# The per-feature -> total conversion does not depend on r, so it is done once
# here instead of being recomputed inside the loop for every r.
cca_n_req_combined = nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py)
pls_n_req_combined = nPerFtr2n(pls_n_req_per_ftr['combined'], ds_pls.py)
panels_req_sample_size = {}
for r in [.1, .3, .5, .7, .9]:
    panels_req_sample_size[r] = (
        plot_n_req(cca_n_req_combined.sel(r=r), color=clr_cca, label='CCA')
        * plot_n_req(pls_n_req_combined.sel(r=r), color=clr_pls, label='PLS')
    ).opts(
        opts.Overlay(logx=True, logy=True, show_legend=True, hooks=[legend_frame_off])
    )
hv.NdLayout(
    panels_req_sample_size,
    kdims=hv.Dimension('r_true', label=r'$r_\mathrm{true}$')
).cols(
    5
)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[8]:
Required sample sizes depending on used metric¶
[9]:
# Compare the required sample size implied by each metric, at r_true = 0.3.
r = 0.3
# One overlay per model, built up below with `*=`.
panels = dict(cca=hv.Overlay(), pls=hv.Overlay())
# Display text for each metric key.
metric_labels = dict(
    power='power',
    betweenAssoc='assoc. strength error',
    weightError='weight error',
    scoreError='score error',
    loadingError='loading error',
    combined='combined'
)
# One Dark2 color per entry in metric_labels.
# NOTE(review): colors (and the PLS text labels) are looked up by the enumeration
# position / key of `n_req_per_ftr`, so this assumes it has at most
# len(metric_labels) entries, all of which are keys of metric_labels — otherwise
# IndexError / KeyError. Confirm against calc_n_required_all_metrics.
metric_clrs = hv.Palette('Dark2', samples=len(metric_labels)).values
for model, ds, n_req_per_ftr in [
    ('cca', ds_cca, cca_n_req_per_ftr),
    ('pls', ds_pls, pls_n_req_per_ftr),
]:
    for metric_i, metric in enumerate(n_req_per_ftr):
        # Dash the combined criterion to set it apart from the individual metrics.
        if metric == 'combined':
            linestyle = '--'
        else:
            linestyle = '-'
        n_req = nPerFtr2n(n_req_per_ftr[metric], ds.py).sel(r=r)
        # Mean over Sigma_id; feature counts without an estimate are dropped.
        panels[model] *= hv.Curve(n_req.mean('Sigma_id').dropna('ptot'), label=metric).opts(linestyle=linestyle, color=metric_clrs[metric_i])
        # Text labels (instead of a legend) only in the PLS panel, right-aligned at
        # x=256 and spaced by a constant factor of 2.25 along the log-scaled y axis.
        if model == 'pls':
            panels[model] *= hv.Text(256, 3 * (2.25)**metric_i, metric_labels[metric], fontsize=7, halign='right', valign='bottom').opts(color=metric_clrs[metric_i])
fig_req_sample_size_by_metric = (
    panels['cca']
    + panels['pls'].opts(ylabel='')
).redim(
    ptot='Number of features',
    n_required='Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics'
).opts(*fig_opts).opts(
    opts.Curve(linewidth=1),
    opts.Overlay(logx=True, logy=True, show_legend=False)
)
fig_req_sample_size_by_metric
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[9]:
Required sample size per variable¶
Another way to condense the information shown in the heatmaps above is to investigate the number of required samples per feature.
[10]:
def plot_n_per_ftr(n_req_per_ftr, color=None, label='', rng=None):
    """Plot the required number of samples per feature against r_true.

    Parameters
    ----------
    n_req_per_ftr : xarray.DataArray
        Required samples per feature with (at least) the dimensions
        ``Sigma_id``, ``r`` and ``px``.
    color : optional
        Color of the per-r mean bars and of the scatter points.
    label : str
        Label attached to the per-r mean bars.
    rng : numpy.random.Generator or None
        Source of the horizontal jitter applied to the scatter points. ``None``
        (the default) uses the global ``numpy.random`` state, as before. Pass a
        seeded generator for a reproducible figure.

    Returns
    -------
    holoviews.Overlay
        A dashed black curve of the px-averaged mean, one short horizontal bar
        per r value at that mean, and a jittered scatter of the individual
        (r, px) means, with both axes logarithmic.
    """
    if rng is None:
        rng = np.random
    # Average over covariance matrices. The name 'y' is the value dimension
    # relabelled by the redim below.
    n_per_ftr = n_req_per_ftr.mean('Sigma_id').rename('y')
    # Flatten (r, px) into a single dimension: one scatter point per combination.
    n_per_ftr_stacked = n_per_ftr.stack(it=('r', 'px'))
    panel = hv.Curve(n_per_ftr.mean('px')).opts(
        color='black', linestyle='--', linewidth=1
    )
    for r in n_per_ftr.r.values:
        # Short horizontal bar from 0.85*r to 1.15*r (multiplicative because the
        # x axis is logarithmic) at the px-averaged value for this r.
        panel *= hv.Curve(
            ([.85*r, 1.15*r], [float(n_per_ftr.sel(r=r).mean('px'))]*2),
            kdims='r',
            label=label,
        ).opts(
            color=color,
            linewidth=1
        )
    # Multiplicative +/-5% jitter on r so points for different px don't overlap.
    jitter = rng.uniform(.95, 1.05, size=n_per_ftr_stacked.r.shape)
    panel *= hv.Scatter(
        (n_per_ftr_stacked.r.values * jitter, n_per_ftr_stacked),
        kdims='r',
    ).opts(marker='.', color=color)
    panel = panel.redim(
        # Raw string: '\m' is not a valid escape sequence in a normal literal.
        r=r'$r_\mathrm{true}$',
        y='Req. samples per feature'
    ).opts(
        opts.Overlay(logx=True, logy=True)
    )
    return panel
[11]:
# Required samples per feature vs. r_true for PLS and CCA in a single overlay.
# The legend is hidden; model identity is conveyed by color (clr_pls / clr_cca).
panels_req_n_per_variable = (
    plot_n_per_ftr(pls_n_req_per_ftr['combined'], color=clr_pls, label='PLS').opts(yaxis='bare')#.relabel('PLS')
    * plot_n_per_ftr(cca_n_req_per_ftr['combined'], color=clr_cca, label='CCA').opts(show_legend=False)#.relabel('CCA')
    #* hv.Text(1, 2000, 'PLS', halign='right', valign='top', fontsize=8).opts(color=clr_pls)
    #* hv.Text(1, 1000, 'CCA', halign='right', valign='top', fontsize=8).opts(color=clr_cca)
).opts(*fig_opts).opts(
    opts.Scatter(s=15),  # marker size of the jittered per-(r, px) points
    opts.Overlay(logx=True, logy=True, padding=.02, show_legend=False, fig_inches=(1.7, None)),
)
panels_req_n_per_variable
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[11]:
Figure¶
[12]:
# Main figure (saved as PDF):
#   left  - required sample size vs. number of features at r_true = 0.3,
#           CCA and PLS overlaid, with in-plot text labels instead of a legend;
#   right - required samples per feature vs. r_true.
fig = (
    (
        panels_req_sample_size[0.3]
        # Raw string: '\m' is not a valid escape sequence in a normal literal.
        * hv.Text(256, 100, r'$r_\mathrm{true}=0.3$', halign='right', valign='bottom', fontsize=8)
        * hv.Text(5, 5e4, 'PLS', halign='left', valign='top', fontsize=8).opts(color=clr_pls)
        * hv.Text(5, 2.5e4, 'CCA', halign='left', valign='top', fontsize=8).opts(color=clr_cca)
    ).opts(logx=True, logy=True, show_legend=False, ylabel='Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics')
    + panels_req_n_per_variable.opts(show_legend=False)
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(sublabel_position=(-.3, .95), hspace=.6)
)
hv.save(fig, 'fig/fig8_required_sample_size.pdf')
fig
[12]:
[13]:
# Supplementary figure (saved as PDF): the required-sample-size heatmaps next to
# the per-metric comparison at r_true = 0.3, laid out in two columns.
figS = (
    fig_required_sample_size_heatmap
    + fig_req_sample_size_by_metric
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(vspace=.5)
)
hv.save(figS, 'fig/figS_required_sample_size.pdf')
figS
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
mask |= resdat <= 0
[13]: