Required sample sizes

How many samples are required to obtain stable CCA and PLS estimates?

Setup

[1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, ShuffleSplit

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import *
from gemmr.plot import heatmap_n_req
from gemmr.util import nPerFtr2n, subset_ds

import matplotlib
import holoviews as hv
from holoviews import opts
# Render all holoviews objects in this notebook with the matplotlib backend
hv.extension('matplotlib')
# Raise the resolution of rendered figures to 120 dpi.
# NOTE(review): newer `param` releases deprecate `set_param` in favour of
# `update` -- confirm against the pinned version.
hv.renderer('matplotlib').param.set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
# Silence matplotlib deprecation warnings for the whole notebook
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
# The message argument is a regular expression matched against the start of
# the warning text; only UserWarnings with this message are suppressed.
warnings.filterwarnings(
    'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
)  # holoviews emits this for log-linear plots
[2]:
# Collects the data behind the plotted panels (keyed by panel label).
# The plotting helpers below fill it; it is written to disk at the end.
plotted_data = dict()
[3]:
# helper functions for figS layout, used below

def figS_axpos(plot, elements):
    """Plot hook that moves left-hand panels further left in the figS layout.

    The axis is taken from ``plot.handles['axis']``. Depending on the left
    edge ``x0`` of its current position it is shifted left by 0.035
    (``x0 < 0.1``) or by 0.07 (``0.1 <= x0 < 0.4``). Axes whose left edge is
    at 0.4 or beyond are left untouched. Width, height and vertical position
    never change.

    Parameters
    ----------
    plot
        Object exposing the axis as ``plot.handles['axis']``.
    elements
        Unused; present because hooks are called with two arguments.
    """
    axis = plot.handles['axis']
    box = axis.get_position()

    if box.x0 < 0.1:
        shift = .035
    elif box.x0 < 0.4:
        shift = .07
    else:
        # panels in the right part of the figure keep their position
        return

    axis.set_position((box.x0 - shift, box.y0, box.width, box.height))

Load data

Load pre-calculated outcomes of analyzed synthetic data

[4]:
# Names of the pre-calculated outcome files, one parameter sweep per model
cca_outcome_name = 'sweep_cca_cca_random_sum+-2+-2'
pls_outcome_name = 'sweep_pls_pls_random_sum+-2+-2'

ds_cca = load_outcomes(cca_outcome_name)
ds_pls = load_outcomes(pls_outcome_name)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[5]:
# Keep only the sweep entries whose `px` coordinate lies below this bound
px_upper_bound = 128

ds_cca = ds_cca.sel(px=ds_cca.px < px_upper_bound)
ds_pls = ds_pls.sel(px=ds_pls.px < px_upper_bound)
[6]:
# Thin out both datasets with the same `n_keep`.
# NOTE(review): `n_keep=25` presumably caps the number of covariance matrices
# (`Sigma_id`) kept per (px, r) cell -- the `n_Sigmas` tables printed below
# never exceed 25 -- confirm against `gemmr.util.subset_ds`.
ds_cca = subset_ds(ds_cca, n_keep=25)
ds_pls = subset_ds(ds_pls, n_keep=25)

What’s in the outcome data files?

[7]:
# Summarise the parameter grid covered by the CCA outcomes
print_ds_stats(ds_cca)
n_modes          1
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[8]:
# Summarise the parameter grid covered by the PLS outcomes
print_ds_stats(ds_pls)
n_modes          1
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated

Required sample sizes

Let’s calculate the required sample size. This is implemented in gemmr.sample_size.interpolation. The function calc_n_required_all_metrics produces sample size estimates based on 5 metrics (power, relative error of between-set association strength, weight error, loading error, score error), as well as on a combination of these 5 metrics (i.e. the maximum across all 5 metrics).

[9]:
# Required samples per feature for the first mode (mode=0), searched along
# the `n_per_ftr` dimension of each outcome dataset
search_kwargs = dict(search_dim='n_per_ftr')

cca_first_mode = ds_cca.sel(mode=0)
pls_first_mode = ds_pls.sel(mode=0)

cca_n_req_per_ftr = calc_n_required_all_metrics(cca_first_mode, **search_kwargs)
pls_n_req_per_ftr = calc_n_required_all_metrics(pls_first_mode, **search_kwargs)

Illustration of the required sample size based on the combined criterion

[10]:
# Heatmaps of the required sample size for one metric, CCA first, then PLS
metric = 'combined'

heatmap_cca = heatmap_n_req(
    nPerFtr2n(cca_n_req_per_ftr[metric], ds_cca.py)
).opts(hooks=[figS_axpos, suptitle_cca])

# only the second panel carries the colorbar
heatmap_pls = heatmap_n_req(
    nPerFtr2n(pls_n_req_per_ftr[metric], ds_pls.py)
).opts(colorbar=True, hooks=[figS_axpos, suptitle_pls])

fig_required_sample_size_heatmap = (heatmap_cca + heatmap_pls).cols(2)
fig_required_sample_size_heatmap = fig_required_sample_size_heatmap.opts(*fig_opts)
fig_required_sample_size_heatmap = fig_required_sample_size_heatmap.opts(
    opts.QuadMesh(
        logx=True,
        logz=True,
        clim=(1e1, 1e5),
        cmap='inferno',
        sublabel_position=(-.45, .95),
        fontsize=dict(labels=8, ticks=7, legend=7, title=8),
    ),
    opts.Layout(fig_inches=(3.42, None)),
)

fig_required_sample_size_heatmap
[10]:

It’s somewhat hard to read off the numbers from the heatmaps, so we replot the required sample sizes separately for individual values of \(r_\mathrm{true}\).

[11]:
def plot_n_req(n_req, qs=(.025, .975), color=None, label='mean', plotted_data=None, panel_label=None):
    """Plot the mean required sample size with a quantile band.

    Parameters
    ----------
    n_req
        Required sample sizes with dimensions ``'ptot'`` and ``'Sigma_id'``.
        NOTE(review): the ``.to_pandas().rename(<str>)`` calls assume that
        reducing over ``'Sigma_id'`` leaves a 1-D object -- confirm.
    qs : sequence of float
        Quantiles computed across ``'Sigma_id'``. The first and last entries
        delimit the shaded band.
    color
        Colour applied to both the mean curve and the band.
    label : str
        Label of the mean curve.
    plotted_data : dict, optional
        If given, a DataFrame holding the mean and the two band quantiles is
        stored in it under ``panel_label``.
    panel_label
        Key used for the entry in ``plotted_data``.

    Returns
    -------
    Overlay of the mean curve and the quantile band, with log-log axes
    labelled 'Number of features' and 'Required sample size'.
    """
    q_lo, q_hi = qs[0], qs[-1]

    mean_n_req = n_req.mean('Sigma_id')
    # feature counts without any estimate are dropped before taking quantiles
    quantiles = n_req.dropna('ptot', how='all').quantile(qs, 'Sigma_id')
    band_lo = quantiles.sel(quantile=q_lo)
    band_hi = quantiles.sel(quantile=q_hi)

    if plotted_data is not None:
        named_columns = [
            ('mean', mean_n_req),
            (f'{100*q_lo:.1f}%', band_lo),
            (f'{100*q_hi:.1f}%', band_hi),
        ]
        plotted_data[panel_label] = pd.concat(
            [values.to_pandas().rename(name) for name, values in named_columns],
            axis=1
        )

    mean_curve = hv.Curve((n_req.ptot, mean_n_req), label=label)
    band = hv.Area((quantiles.ptot, band_lo, band_hi), vdims=['y', 'y2'])

    overlay = (mean_curve * band).redim(
        x='Number of features',
        y='Required sample size'
    )
    return overlay.opts(
        opts.Curve(color=color),
        opts.Area(alpha=.25, color=color),
        opts.Overlay(logx=True, logy=True)
    )
[12]:
# One panel per true correlation: CCA and PLS required sample sizes
# (combined criterion) overlaid. The plotted values are recorded in
# `plotted_data` under '<MODEL>_r<r>_req_n'.
panels_req_sample_size = {}
for r in (.1, .3, .5, .7):
    cca_panel = plot_n_req(
        nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py).sel(r=r),
        color=clr_cca,
        label='CCA',
        plotted_data=plotted_data,
        panel_label=f'CCA_r{r:.1f}_req_n',
    )
    pls_panel = plot_n_req(
        nPerFtr2n(pls_n_req_per_ftr['combined'], ds_pls.py).sel(r=r),
        color=clr_pls,
        label='PLS',
        plotted_data=plotted_data,
        panel_label=f'PLS_r{r:.1f}_req_n',
    )
    panels_req_sample_size[r] = (cca_panel * pls_panel).opts(
        opts.Overlay(logx=True, logy=True, show_legend=True, hooks=[legend_frame_off])
    )


r_true_dim = hv.Dimension('r_true', label=r'$r_\mathrm{true}$')
hv.NdLayout(panels_req_sample_size, kdims=r_true_dim).cols(5)
[12]:

Required sample sizes depending on used metric

[13]:
# Required sample size per metric at a single true correlation
r = 0.3

panels = dict(cca=hv.Overlay(), pls=hv.Overlay())

# Text shown next to the curves (PLS panel only), keyed by metric name
metric_labels = {
    'power': 'power',
    'betweenAssoc': 'assoc. strength error',
    'weightError': 'weight error',
    'scoreError': 'score error',
    'loadingError': 'loading error',
    'combined': 'combined',
}
metric_clrs = hv.Palette('Dark2', samples=len(metric_labels)).values

for model, ds_, n_req_per_ftr in (
    ('cca', ds_cca, cca_n_req_per_ftr),
    ('pls', ds_pls, pls_n_req_per_ftr),
):
    # colours are assigned by the position of the metric in `n_req_per_ftr`
    for metric_i, metric in enumerate(n_req_per_ftr):
        # the combined criterion is dashed, individual metrics are solid
        linestyle = '--' if metric == 'combined' else '-'
        metric_clr = metric_clrs[metric_i]

        n_req = nPerFtr2n(n_req_per_ftr[metric], ds_.py).sel(r=r)
        mean_n_req = n_req.mean('Sigma_id').dropna('ptot')
        panels[model] *= hv.Curve(mean_n_req, label=metric).opts(linestyle=linestyle, color=metric_clr)

        if model != 'pls':
            continue

        # in-panel legend: one text label per metric, stacked geometrically
        # so the spacing is even on the log y-axis
        text_y = 3 * (2.25)**metric_i
        panels[model] *= hv.Text(
            256, text_y, metric_labels[metric], fontsize=7, halign='right', valign='bottom'
        ).opts(color=metric_clr)

panel_cca = panels['cca'].opts(hooks=[figS_axpos, suptitle_cca])
panel_pls = panels['pls'].opts(ylabel='', hooks=[figS_axpos, suptitle_pls])

fig_req_sample_size_by_metric = (panel_cca + panel_pls).redim(
    ptot='Number of features',
    n_required='Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics'
)
fig_req_sample_size_by_metric = fig_req_sample_size_by_metric.opts(*fig_opts).opts(
    opts.Curve(linewidth=1),
    opts.Overlay(logx=True, logy=True, show_legend=False)
)

fig_req_sample_size_by_metric
[13]:

Required sample size per variable

Another way to condense the information shown in the heatmaps above is to investigate the number of required samples per feature.

[14]:
def plot_n_per_ftr(n_req_per_ftr, color=None, label='', plotted_data=None, panel_label=None, rng=None):
    """Plot required samples per feature as a function of the true correlation.

    The panel overlays
    * a dashed black curve of the mean across ``'px'``,
    * one short horizontal bar per ``r`` at that mean, spanning 0.85*r to 1.15*r,
    * a scatter of the individual (r, px) values, jittered by up to +-5% along
      the x-axis so that points sharing the same ``r`` do not overlap.

    Parameters
    ----------
    n_req_per_ftr
        Required samples per feature with dimensions ``'Sigma_id'``, ``'r'``
        and ``'px'``; it is averaged over ``'Sigma_id'`` first.
    color
        Colour of the bars and the scatter points.
    label : str
        Label attached to the per-``r`` bars.
    plotted_data : dict, optional
        If given, the averaged values (as a pandas object) are stored in it
        under ``panel_label``.
    panel_label
        Key used for the entry in ``plotted_data``.
    rng : optional
        Source of the scatter jitter; anything with a NumPy-style
        ``uniform(low, high, size=...)`` method, e.g.
        ``np.random.default_rng(seed)``. Defaults to the global ``np.random``
        state (the previous behaviour), which makes the jitter differ between
        runs unless that state is seeded.

    Returns
    -------
    Overlay with log-log axes, x labelled '$r_\\mathrm{true}$' and y labelled
    'Req. samples per feature'.
    """
    if rng is None:
        rng = np.random

    n_per_ftr = n_req_per_ftr.mean('Sigma_id').rename('y')  # rename for redim below
    n_per_ftr_stacked = n_per_ftr.stack(it=('r', 'px'))

    panel = hv.Curve(n_per_ftr.mean('px')).opts(
        color='black', linestyle='--', linewidth=1
    )

    for r in n_per_ftr.r.values:
        panel *= hv.Curve(
            ([.85*r, 1.15*r], [float(n_per_ftr.sel(r=r).mean('px'))]*2),
            kdims='r',
            label=label,
        ).opts(
            color=color,
            linewidth=1
        )

    # multiplicative jitter keeps the spread visually constant on the log x-axis
    jitter = rng.uniform(.95, 1.05, size=n_per_ftr_stacked.r.shape)
    panel *= hv.Scatter(
        (n_per_ftr_stacked.r.values * jitter, n_per_ftr_stacked),
        kdims='r',
    ).opts(marker='.', color=color)

    if plotted_data is not None:
        # the un-jittered values are what gets recorded
        plotted_data[panel_label] = n_per_ftr_stacked.unstack().to_pandas()

    panel = panel.redim(
        # raw string: '\m' is not a valid escape sequence in a normal literal
        r=r'$r_\mathrm{true}$',
        y='Req. samples per feature'
    ).opts(
        opts.Overlay(logx=True, logy=True)
    )
    return panel


[15]:
# Use only px, r for which both CCA and PLS data available!
cca_combined = cca_n_req_per_ftr['combined']
pls_combined = pls_n_req_per_ftr['combined']
mask = np.isfinite(cca_combined.mean('Sigma_id')) & np.isfinite(pls_combined.mean('Sigma_id'))

# PLS is drawn first, CCA on top of it
panel_pls = plot_n_per_ftr(
    pls_combined.where(mask),
    color=clr_pls,
    label='PLS',
    plotted_data=plotted_data,
    panel_label='PLS_req_n_per_ftr',
).opts(yaxis='bare')
panel_cca = plot_n_per_ftr(
    cca_combined.where(mask),
    color=clr_cca,
    label='CCA',
    plotted_data=plotted_data,
    panel_label='CCA_req_n_per_ftr',
).opts(show_legend=False)

panels_req_n_per_variable = (panel_pls * panel_cca).opts(*fig_opts).opts(
    opts.Scatter(s=15),
    opts.Overlay(logx=True, logy=True, xlim=(None, 1.), padding=.02, show_legend=False, fig_inches=(1.7, None)),
)

panels_req_n_per_variable
[15]:

Figure

[16]:
# Main figure: required sample size at r_true = 0.3 (first panel) next to the
# required samples per feature across all r_true (second panel)
fig = (
    (
        panels_req_sample_size[0.3]
        # raw string: '\m' is not a valid escape sequence in a normal literal
        * hv.Text(256, 100, r'$r_\mathrm{true}=0.3$', halign='right', valign='bottom', fontsize=8)
        # coloured text labels replace the legend, which is switched off below
        * hv.Text(5, 5e4, 'PLS', halign='left', valign='top', fontsize=8).opts(color=clr_pls)
        * hv.Text(5, 2.5e4, 'CCA', halign='left', valign='top', fontsize=8).opts(color=clr_cca)
    ).opts(logx=True, logy=True, show_legend=False, ylabel='Req. sample size to obtain\n' + r'$\geq$ 90% power & $\leq$ 10%' + '\nerror in other metrics')
    + panels_req_n_per_variable.opts(show_legend=False)
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(sublabel_position=(-.3, .95), hspace=.6)
)

# Write the figure, then the data behind it as collected in `plotted_data`
hv.save(fig, 'fig/fig_required_sample_size.pdf')
save_source_data(plotted_data, 'fig7')

fig
[16]:
[17]:
# Supplementary figure: the two heatmaps followed by the two per-metric
# panels, laid out with up to four panels per row
figS = (fig_required_sample_size_heatmap + fig_req_sample_size_by_metric).cols(4)
figS = figS.opts(*fig_opts)
figS = figS.opts(
    opts.Layout(fig_inches=(7, .7), hspace=.75)
)

hv.save(figS, 'fig/figS_required_sample_size.pdf')

figS
[17]: