Sample size calculator

We develop a sample size calculator for CCA and PLS. Previously, we generated synthetic datasets for a range of parameter values and determined, for each, the sample size required to bound statistical power and error metrics. Repeating these simulations for every new parameter combination would be computationally expensive, so we instead interpolate the parameter dependencies by fitting a suitable linear model to the simulated outcomes.
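
The idea can be illustrated with a minimal sketch (hypothetical numbers and only a single predictor; the actual fits below assemble more predictors via do_fit_lm): regress log(required sample size) on the simulation parameters, then evaluate the fit at parameter values that were never simulated.

import numpy as np
from sklearn.linear_model import LinearRegression

r = np.array([0.1, 0.3, 0.5, 0.7])           # simulated true correlations
n_req = np.array([20000, 2200, 700, 300])    # required sample sizes (hypothetical)

X = (-np.log(r)).reshape(-1, 1)              # predictor: -log(r)
lm = LinearRegression().fit(X, np.log(n_req))

r_new = 0.2                                  # a correlation we did not simulate
print(np.exp(lm.predict([[-np.log(r_new)]])))  # interpolated sample size estimate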

Setup

[1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate, KFold, ShuffleSplit

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size import *
from gemmr.sample_size.linear_model import prep_data_for_lm, do_fit_lm
from gemmr.util import subset_ds

import matplotlib
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
[2]:
data_home = None
ds_cca = load_outcomes('sweep_cca_cca_random_sum+-3+0_wOtherModel', model='cca', add_prefix='cca_', data_home=data_home).sel(mode=0)
ds_pls = load_outcomes('sweep_pls_pls_random_sum+-3+0_wOtherModel', model='pls', add_prefix='pls_', data_home=data_home).sel(mode=0)

Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[3]:
ds_cca = ds_cca.sel(px=ds_cca.px < 128)
ds_pls = ds_pls.sel(px=ds_pls.px < 128)
[4]:
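# keep (up to) 25 synthetic datasets per parameter combination (cf. n_Sigmas below)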
ds_cca = subset_ds(ds_cca, n_keep=25, keyvar='cca_between_assocs')
ds_pls = subset_ds(ds_pls, n_keep=25, keyvar='pls_between_assocs')

What’s in the outcome data files?

[5]:
print_ds_stats(ds_cca, prefix='cca_')
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.97, -0.10)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[6]:
print_ds_stats(ds_pls, prefix='pls_')
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.97, -0.10)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 21]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated

Fit a linear model to log(required sample size) to predict sample size

[7]:
def run_linear_model(
    _ds,
    cv=ShuffleSplit(100, test_size=0.5, random_state=42),
    include_pdiff=False,
    include_pc_var_decay_constants=True,
    include_latent_explained_vars=True,
    prefix=''
):

    r_squareds, error_bars, true_ys, predicted_ys = {}, {}, {}, {}

    n_reqs_per_ftr = calc_n_required_all_metrics(_ds, target_power=0.9, target_error=0.1, search_dim='n_per_ftr', prefix=prefix)

    for metric_label, n_req_per_ftr in n_reqs_per_ftr.items():

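        # convert samples-per-feature into an absolute sample size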
        n_req = n_req_per_ftr * (_ds.px + _ds.py)

        lm, X, y, coeff_names = do_fit_lm(
            _ds, n_req,
            include_pdiff=include_pdiff,
            include_pc_var_decay_constants=include_pc_var_decay_constants,
            include_latent_explained_vars=include_latent_explained_vars,
            prefix=prefix
        )

        r_squareds[metric_label] = lm.score(X, y)

        # make error_bar estimates through CV
        cres = cross_validate(lm, X, y,
                              cv=cv,
                              return_estimator=True)

        coefs = np.r_[[lm.intercept_], lm.coef_].reshape(1, -1)
        qs = (.025, .975)
        cv_cis = np.quantile(
            np.asarray([np.r_[[estr.intercept_], estr.coef_] for estr in cres['estimator']]),
            qs,
            axis=0
        )

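        # columns (after the transpose): predictor index, full-data coefficient,
        # and the 2.5%/97.5% CV quantiles as offsets from that coefficient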
        error_bars[metric_label] = np.concatenate([
            np.arange(1 + X.shape[1]).reshape(1, -1),
            coefs,
            cv_cis - coefs
        ], axis=0).T

        median_minus_log_r_true = np.median(X[:, 0])
        masks = [
            X[:, 0] <= median_minus_log_r_true,
            X[:, 0] > median_minus_log_r_true,
        ]

        assert np.all(masks[0] | masks[1])

        _true_ys, _predicted_ys = [], []
        for mask in masks:

            train, test = mask, ~mask

            # make sure there's no overlap in the -log(r) values of train and test set
            assert len(set(X[train, 0].tolist()).intersection(X[test, 0].tolist())) == 0

            # fit on one half, evaluate on the held-out half
            _lm = clone(lm).fit(X[train], y[train])
            _true_ys.append(y[test])
            _predicted_ys.append(_lm.predict(X[test]))

        true_ys[metric_label] = np.hstack(_true_ys)
        predicted_ys[metric_label] = np.hstack(_predicted_ys)

    return r_squareds, error_bars, true_ys, predicted_ys, coeff_names
[8]:
# due to high loading-error for px < 5:
# make linear model only for px >= 5
r_squareds_cca, error_bars_cca, true_ys_cca, predicted_ys_cca, coeff_names_cca = run_linear_model(
    ds_cca.sel(px=ds_cca.px>4),
    include_pdiff=False,
    include_pc_var_decay_constants=True,
    include_latent_explained_vars=False,  # <- deprecated, doesn't exist any more
    prefix='cca_'
)
[9]:
# due to high loading-error for px < 5:
# make linear model only for px >= 5
r_squareds_pls, error_bars_pls, true_ys_pls, predicted_ys_pls, coeff_names_pls = run_linear_model(
    ds_pls.sel(px=ds_pls.px>4),
    include_pdiff=False,
    include_pc_var_decay_constants=True,
    include_latent_explained_vars=False,  # <- deprecated, doesn't exist any more
    prefix='pls_'
)

Linear model coefficients

[10]:
def plot_lm_coeffs(
    error_bars,
    r_squareds,
    coeff_names,
    metrics=None,
    plot_error_bars=True,
    label='',
    color=None,
):

    def x_ticks_format_variable_name(x):
        int_x = int(x)
        if (x == int_x) and (x < len(coeff_names)):
            return coeff_names[int(x)]
        else:
            return ''

    predictor_variable_dim = hv.Dimension('predictor_variable', label='Predictor variable',
                                          value_format=x_ticks_format_variable_name
                                         )

    markers = ['s', 'o', 'd', 'v', 'x', '.']

    if metrics is None:
        metrics = list(error_bars.keys())

    panel_lm_coeffs = hv.Overlay()
    for metric_i, metric_label in enumerate(metrics):

        if plot_error_bars:
            panel_lm_coeffs *= hv.ErrorBars(
                error_bars[metric_label],
                kdims=predictor_variable_dim,
            )

        panel_lm_coeffs *= hv.Scatter(
            (error_bars[metric_label][:, 0], error_bars[metric_label][:, 1]),
            kdims=predictor_variable_dim,
            label=label,
        ).opts(marker=markers[metric_i])

    if color is None:
        color = hv.Cycle(metric_clrs_li)

    panel_lm_coeffs = panel_lm_coeffs.redim(
            y='Coefficient value'
        ).opts(
            opts.ErrorBars(ylim=(None, None), padding=0.05),
            opts.Scatter(color=color, s=40),
            opts.Text(color=color),
            opts.HLine(color='black', linewidth=.1, linestyle='--'),
            opts.Overlay(xrotation=45, xticks=np.arange(len(coeff_names), dtype=int).tolist()),
        )
    return panel_lm_coeffs
[11]:
panel_lm_coeffs_max = (
    plot_lm_coeffs(error_bars_pls, r_squareds_pls, coeff_names_pls, metrics=['combined'], plot_error_bars=False, label='PLS', color=clr_pls)
    * plot_lm_coeffs(error_bars_cca, r_squareds_cca, coeff_names_cca, metrics=['combined'], plot_error_bars=False, label='CCA', color=clr_cca)
    * hv.Text(3.1, 2.5, 'CCA', halign='right', valign='top', fontsize=7).opts(color=clr_cca)
    * hv.Text(3.1, 2.2, 'PLS', halign='right', valign='top', fontsize=7).opts(color=clr_pls)
).opts(*fig_opts).opts(
    opts.Overlay(show_legend=False, padding=.1, ylim=(-.3, None), fig_inches=(1.7, None)),
)
panel_lm_coeffs_max
[11]:
(figure output: linear model coefficients for CCA and PLS)

Linear model prediction errors

[12]:
def scatter_lm_predictions(
    true_ys,
    predicted_ys,
    metric_label
):

    true_ys = np.exp(true_ys[metric_label])
    predicted_ys = np.exp(predicted_ys[metric_label])

    mn = min(true_ys.min(), predicted_ys.min())
    mx = max(true_ys.max(), predicted_ys.max())

    panel_n_crit_lm_predictions = (
        hv.Scatter(
            (true_ys, predicted_ys)
        ) * hv.Curve(
            ([mn, mx], [mn, mx])
        )
    ).redim(
        x='True req. sample size',
        y='Out-of-sample prediction'
    ).opts(
        opts.Curve(color='black', linestyle='--', linewidth=1),
        opts.Scatter(s=3),
    )
    return panel_n_crit_lm_predictions
[13]:
panel_cca_n_req_lm_predictions = scatter_lm_predictions(true_ys_cca, predicted_ys_cca, 'combined').opts(
    opts.Scatter(color=clr_cca),
    opts.Overlay(logx=True, logy=True)
).opts(*fig_opts).opts(fig_inches=(1.7, None))
panel_cca_n_req_lm_predictions
[13]:
(figure output: true vs. out-of-sample predicted required sample sizes, CCA)
[14]:
panel_pls_n_req_lm_predictions = scatter_lm_predictions(true_ys_pls, predicted_ys_pls, 'combined'
).opts(
    opts.Scatter(color=clr_pls),
    opts.Overlay(logx=True, logy=True)
).opts(*fig_opts).opts(fig_inches=(1.7, None))

panel_pls_n_req_lm_predictions
[14]:
(figure output: true vs. out-of-sample predicted required sample sizes, PLS)

Predict correlations with linear model

As a further check, we invert the fitted model: we fit it on a subset of the simulated correlations r, solve it for r given the required sample sizes of the held-out correlations, and compare these out-of-sample predictions to the true values.

[15]:
def predict_req_corrs(
    ds, metric='combined',
    include_pc_var_decay_constants=False,
    prefix=''
):

    n_req_per_ftr = calc_n_required_all_metrics(ds, search_dim='n_per_ftr', prefix=prefix)[metric]
    n_req = n_req_per_ftr * (ds.px + ds.py)

    r_lists = [[.1, .3], [.5, .7]]
    true_rs, predicted_rs = [], []
    for r_list in r_lists:
        lm, X, y, coeff_names = do_fit_lm(ds.sel(r=r_list), n_req.sel(r=r_list), include_pdiff=False,
                                                 include_pc_var_decay_constants=include_pc_var_decay_constants,
                                                 include_latent_explained_vars=False,
                                                 prefix=prefix,
                                                )

        intercept = lm.intercept_
        coefs = lm.coef_

        other_rs = [r for r in ds.r.values if r not in r_list]
        Xtest, ytest, coeff_names_test = prep_data_for_lm(ds.sel(r=other_rs), n_req.sel(r=other_rs), include_pdiff=False,
                                                          include_pc_var_decay_constants=include_pc_var_decay_constants,
                                                          include_latent_explained_vars=False,
                                                          prefix=prefix,
                                                         )

        # double check that first column of X contains -np.log(r)
        for rr in -np.log(np.asarray(other_rs)):
            assert rr in Xtest[:, 0]

        true_rs.append(np.exp(-Xtest[:, 0]))

        if not include_pc_var_decay_constants:
            predicted_rs.append(
                np.exp((- ytest + intercept + Xtest[:, 1] * coefs[1]) / coefs[0])
            )
        else:
            predicted_rs.append(
                np.exp((- ytest + intercept + Xtest[:, 1] * coefs[1] + Xtest[:, 2] * coefs[2]) / coefs[0])
            )
    true_rs = np.hstack(true_rs)
    predicted_rs = np.hstack(predicted_rs)


    return true_rs, predicted_rs
[16]:
cca_rs_true, cca_rs_predicted = predict_req_corrs(ds_cca, prefix='cca_')
pls_rs_true, pls_rs_predicted = predict_req_corrs(ds_pls, prefix='pls_', include_pc_var_decay_constants=True)
[17]:
fig_predicted_req_corrs = (
    hv.Scatter(
        (cca_rs_true, cca_rs_predicted),
        kdims='true',
        vdims='predicted',
        label='CCA'
    ).opts(sublabel_format='e')
    +  hv.Scatter(
        (pls_rs_true, pls_rs_predicted),
        kdims='true',
        vdims='predicted',
        label='PLS'
    ).opts(ylabel='', sublabel_format='f')
).redim(
    true='True correlation',
    predicted='Out-of-sample prediction\nwith linear model'
).opts(*fig_opts).opts(
    opts.Scatter(s=10, xlim=(0, 1), ylim=(0, None)),
    opts.Layout(sublabel_position=(-.45, .95))
)

fig_predicted_req_corrs
[17]:
(figure output: true vs. predicted correlations for CCA and PLS)

Figure

[18]:
fig = (
    hv.Overlay()
    + panel_lm_coeffs_max.relabel('Linear model coefficients').opts(xticks=4, xrotation=25, sublabel_format='b')
    + panel_cca_n_req_lm_predictions.relabel('CCA').opts(sublabel_format='c')
    + panel_pls_n_req_lm_predictions.relabel('PLS').opts(sublabel_format='d')
    # +  panel_pls_n_req_lm_predictions_wLatentExplVar.relabel('PLS w/ additional\nunobservable predictor')
    + fig_predicted_req_corrs
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(vspace=.75, sublabel_position=(-.4, 1))
)

hv.save(fig, 'fig/figS_sample_size_calculator.pdf')

fig
WARNING:param.LayoutPlot06878: :Overlay is empty, skipping subplot.
WARNING:param.LayoutPlot07136: :Overlay is empty, skipping subplot.
[18]:
(figure output: assembled figure; also saved to fig/figS_sample_size_calculator.pdf)

Usage of the linear models

Finally, the fitted linear models are exposed through the convenience functions cca_sample_size and pls_sample_size:

[19]:
from gemmr import *
[20]:
cca_sample_size(100, 50, -1, -1)
[20]:
{0.1: 189701, 0.3: 17211, 0.5: 5638}

The result is a dictionary whose keys indicate the assumed ground-truth correlations and whose values give the corresponding sample size estimates. For PLS, in addition to the number of features in each set, the powerlaw decay constants of the within-set principal component spectra need to be specified:

[21]:
pls_sample_size(10, 10, -1, -1)
[21]:
{0.1: 18459, 0.3: 2949, 0.5: 1257}
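
Since the returned dictionary is keyed by the assumed true correlation, a single estimate can be picked out directly; a minimal usage sketch for the PLS call above:

est = pls_sample_size(10, 10, -1, -1)
print(est[0.3])  # -> 2949: estimate assuming a true correlation of 0.3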