Required sample sizes

How many samples are required to obtain stable CCA and PLS estimates?

Setup

[1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import heatmap_n_req
from gemmr.util import nPerFtr2n

import matplotlib
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
warnings.filterwarnings(
    'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
)  # holoviews emits this for log-linear plots

Load data

Load pre-calculated outcomes of analyzed synthetic data

[2]:
# Pre-computed outcomes for synthetic data: CCA, and PLS with tag 'axPlusay-2'
# (the dataset stats printed below report an ax+ay range of -2 for the PLS data).
ds_cca, ds_pls = load_outcomes('cca'), load_outcomes('pls', tag='axPlusay-2')

What’s in the outcome data files?

[3]:
# Overview of the CCA outcomes: parameter grid (n_per_ftr, r, px) and number of covariance matrices per (px, r) cell
print_ds_stats(ds_cca)
n_modes          1
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[4]:
# Overview of the PLS outcomes: parameter grid (n_per_ftr, r, px) and number of covariance matrices per (px, r) cell
print_ds_stats(ds_pls)
n_modes          1
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0,  0,  0, 25, 25],
       [ 0,  0,  0, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated

Required sample sizes

Let’s calculate the required sample size. This is implemented in gemmr.sample_size.interpolation: its function calc_n_required_all_metrics produces sample-size estimates based on 5 metrics (power, relative error of between-set association strength, weight error, loading error, score error), as well as on a combination of these 5 metrics (i.e. the maximum across all 5 metrics).

[5]:
# Required-sample-size estimates (searched along n_per_ftr, i.e. samples per
# feature) for every metric and the combined criterion, first mode only.
cca_n_req_per_ftr, pls_n_req_per_ftr = (
    calc_n_required_all_metrics(outcomes.sel(mode=0), search_dim='n_per_ftr')
    for outcomes in (ds_cca, ds_pls)
)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)

Illustration of the required sample size based on the combined criterion

[6]:
# Heatmaps of the required total sample size (combined criterion), one panel
# per method; samples-per-feature are converted to total samples via nPerFtr2n.
heatmap_cca = heatmap_n_req(
    nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py)
).opts(hooks=[suptitle_cca])
# Only the right-hand (PLS) panel carries a colorbar; both share the same clim.
heatmap_pls = heatmap_n_req(
    nPerFtr2n(pls_n_req_per_ftr['combined'], ds_pls.py)
).opts(colorbar=True, hooks=[suptitle_pls])

quadmesh_opts = opts.QuadMesh(
    logx=True, logz=True,
    clim=(1e1, 1e5), cmap='inferno',
    sublabel_position=(-.45, .95),
    fontsize=dict(labels=8, ticks=7, legend=7, title=8),
)
layout_opts = opts.Layout(fig_inches=(3.42, None))

fig_required_sample_size_heatmap = (
    (heatmap_cca + heatmap_pls)
    .cols(2)
    .opts(*fig_opts)
    .opts(quadmesh_opts, layout_opts)
)

fig_required_sample_size_heatmap
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
[6]:

It’s somewhat hard to read off the numbers from the heatmaps, so we replot them for individual values of \(r_\mathrm{true}\).

[7]:
def plot_n_req(n_req, qs=(.025, .975), color=None, label='mean'):
    """Plot required sample size against the total number of features.

    The mean over the ``Sigma_id`` dimension is drawn as a curve; the band
    between the ``qs[0]`` and ``qs[-1]`` quantiles (over ``Sigma_id``) is drawn
    as a shaded area. Both axes are logarithmic.

    Parameters
    ----------
    n_req : xarray.DataArray
        Required sample sizes with a ``Sigma_id`` dimension and a ``ptot``
        coordinate (presumably the output of ``nPerFtr2n`` -- confirm).
    qs : sequence of float
        Quantiles computed over ``Sigma_id``; the first and last delimit the band.
    color : optional
        Color used for both the curve and the band.
    label : str
        Label of the mean curve.

    Returns
    -------
    holoviews Overlay of the mean curve and the quantile band.
    """
    lo_q, hi_q = qs[0], qs[-1]
    quantiles = n_req.quantile(qs, 'Sigma_id')

    mean_curve = hv.Curve((n_req.ptot, n_req.mean('Sigma_id')), label=label)
    band = hv.Area(
        (quantiles.ptot, quantiles.sel(quantile=lo_q), quantiles.sel(quantile=hi_q)),
        vdims=['y', 'y2'],
    )

    overlay = (mean_curve * band).redim(x='Number of features', y='Required sample size')
    return overlay.opts(
        opts.Curve(color=color),
        opts.Area(alpha=.25, color=color),
        opts.Overlay(logx=True, logy=True),
    )
[8]:
# Required total sample size (combined criterion). The samples-per-feature ->
# total-samples conversion does not depend on r_true, so do it once instead of
# once per r value inside the loop.
cca_n_req_combined = nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py)
pls_n_req_combined = nPerFtr2n(pls_n_req_per_ftr['combined'], ds_pls.py)

panels_req_sample_size = {}
for r in [.1, .3, .5, .7, .9]:
    # CCA and PLS (mean curve + quantile band) overlaid for this r_true
    panels_req_sample_size[r] = (
        plot_n_req(cca_n_req_combined.sel(r=r), color=clr_cca, label='CCA')
        * plot_n_req(pls_n_req_combined.sel(r=r), color=clr_pls, label='PLS')
    ).opts(
        opts.Overlay(logx=True, logy=True, show_legend=True, hooks=[legend_frame_off])
    )


hv.NdLayout(
    panels_req_sample_size,
    kdims=hv.Dimension('r_true', label=r'$r_\mathrm{true}$')
).cols(
    5
)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1391: RuntimeWarning: All-NaN slice encountered
  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[8]:

Required sample sizes depending on the metric used

[9]:
r = 0.3

panels = dict(cca=hv.Overlay(), pls=hv.Overlay())

# Display names for the metric keys produced by calc_n_required_all_metrics.
metric_labels = dict(
    power='power',
    betweenAssoc='assoc. strength error',
    weightError='weight error',
    scoreError='score error',
    loadingError='loading error',
    combined='combined'
)
metric_clrs = hv.Palette('Dark2', samples=len(metric_labels)).values

models = [
    ('cca', ds_cca, cca_n_req_per_ftr),
    ('pls', ds_pls, pls_n_req_per_ftr),
]
for model, ds, n_req_per_ftr in models:
    for metric_i, metric in enumerate(n_req_per_ftr):
        # Colors follow the metrics' iteration order; the combined criterion is dashed.
        clr = metric_clrs[metric_i]
        linestyle = '--' if metric == 'combined' else '-'
        n_req = nPerFtr2n(n_req_per_ftr[metric], ds.py).sel(r=r)
        n_req_mean = n_req.mean('Sigma_id').dropna('ptot')
        panels[model] *= hv.Curve(n_req_mean, label=metric).opts(linestyle=linestyle, color=clr)

        if model != 'pls':
            continue
        # Text labels only on the PLS panel, stacked upwards by a factor 2.25 per metric.
        label_y = 3 * 2.25**metric_i
        panels[model] *= hv.Text(256, label_y, metric_labels[metric], fontsize=7, halign='right', valign='bottom').opts(color=clr)

n_req_ylabel = 'Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics'

fig_req_sample_size_by_metric = (
    (panels['cca'] + panels['pls'].opts(ylabel=''))
    .redim(ptot='Number of features', n_required=n_req_ylabel)
    .opts(*fig_opts)
    .opts(
        opts.Curve(linewidth=1),
        opts.Overlay(logx=True, logy=True, show_legend=False),
    )
)

fig_req_sample_size_by_metric
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[9]:

Required sample size per variable

Another way to condense the information shown in the heatmaps above is to investigate the number of required samples per feature.

[10]:
def plot_n_per_ftr(n_req_per_ftr, color=None, label='', random_state=None):
    """Plot required samples per feature against r_true.

    Draws, on log-log axes:

    1. a black dashed curve of the requirement averaged over ``Sigma_id`` and ``px``,
    2. for each r value a short horizontal segment (spanning 0.85*r..1.15*r) at
       the mean over ``px``,
    3. one horizontally jittered point per (r, px) combination.

    Parameters
    ----------
    n_req_per_ftr : xarray.DataArray
        Required samples per feature with (at least) dimensions ``Sigma_id``,
        ``r`` and ``px``.
    color : optional
        Color of the segments and points.
    label : str
        Label attached to the horizontal segments.
    random_state : None, int or numpy.random.RandomState, optional
        Source of the multiplicative x-jitter of the scatter points. ``None``
        (default) uses numpy's global random state, as before; pass a seed or a
        ``RandomState`` to obtain reproducible figures.

    Returns
    -------
    holoviews Overlay.
    """
    n_per_ftr = n_req_per_ftr.mean('Sigma_id').rename('y')  # rename for redim below
    n_per_ftr_stacked = n_per_ftr.stack(it=('r', 'px'))

    panel = hv.Curve(n_per_ftr.mean('px')).opts(
        color='black', linestyle='--', linewidth=1
    )

    for r in n_per_ftr.r.values:
        mean_over_px = float(n_per_ftr.sel(r=r).mean('px'))
        panel *= hv.Curve(
            ([.85*r, 1.15*r], [mean_over_px]*2),
            kdims='r',
            label=label,
        ).opts(
            color=color,
            linewidth=1
        )

    # +/-5% multiplicative jitter in r so points for different px don't overlap
    if random_state is None:
        uniform = np.random.uniform
    elif isinstance(random_state, np.random.RandomState):
        uniform = random_state.uniform
    else:
        uniform = np.random.RandomState(random_state).uniform
    jitter = uniform(.95, 1.05, size=n_per_ftr_stacked.r.shape)

    panel *= hv.Scatter(
        (n_per_ftr_stacked.r.values * jitter, n_per_ftr_stacked),
        kdims='r',
    ).opts(marker='.', color=color)

    # raw string: '\m' is not a valid Python escape sequence
    panel = panel.redim(
        r=r'$r_\mathrm{true}$',
        y='Req. samples per feature'
    ).opts(
        opts.Overlay(logx=True, logy=True)
    )
    return panel


[11]:
# PLS is drawn first; it gets yaxis='bare', and the CCA overlay has its legend switched off.
pls_panel = plot_n_per_ftr(pls_n_req_per_ftr['combined'], color=clr_pls, label='PLS').opts(yaxis='bare')
cca_panel = plot_n_per_ftr(cca_n_req_per_ftr['combined'], color=clr_cca, label='CCA').opts(show_legend=False)

panels_req_n_per_variable = (pls_panel * cca_panel).opts(*fig_opts).opts(
    opts.Scatter(s=15),
    opts.Overlay(logx=True, logy=True, padding=.02, show_legend=False, fig_inches=(1.7, None)),
)

panels_req_n_per_variable
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[11]:

Figure

[12]:
import os

# Main figure. Left: required sample size vs. number of features at r_true = 0.3
# (CCA and PLS). Right: required samples per feature vs. r_true.
n_req_ylabel = 'Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics'

fig = (
    (
        panels_req_sample_size[0.3]
        # raw string: '\m' is not a valid Python escape sequence
        * hv.Text(256, 100, r'$r_\mathrm{true}=0.3$', halign='right', valign='bottom', fontsize=8)
        * hv.Text(5, 5e4, 'PLS', halign='left', valign='top', fontsize=8).opts(color=clr_pls)
        * hv.Text(5, 2.5e4, 'CCA', halign='left', valign='top', fontsize=8).opts(color=clr_cca)
    ).opts(logx=True, logy=True, show_legend=False, ylabel=n_req_ylabel)
    + panels_req_n_per_variable.opts(show_legend=False)
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(sublabel_position=(-.3, .95), hspace=.6)
)

# Make sure the output directory exists before saving.
os.makedirs('fig', exist_ok=True)
hv.save(fig, 'fig/fig8_required_sample_size.pdf')

fig
[12]:
[13]:
# Supplementary figure: the r_true x #features heatmaps alongside the
# per-metric required-sample-size curves at r_true = 0.3.
figS = fig_required_sample_size_heatmap + fig_req_sample_size_by_metric
figS = figS.cols(2).opts(*fig_opts).opts(opts.Layout(vspace=.5))


hv.save(figS, 'fig/figS_required_sample_size.pdf')

figS
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/matplotlib/colors.py:1110: RuntimeWarning: invalid value encountered in less_equal
  mask |= resdat <= 0
[13]: