Cross-validated estimation

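How do cross-validated (5-fold) estimates of CCA and PLS association strengths compare with in-sample estimates, in terms of bias, estimation error and required sample size?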

Setup

[1]:
# Silence selected warnings emitted by matplotlib, holoviews and tqdm.
import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
warnings.filterwarnings(
    'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
)  # holoviews emits this for log-linear plots
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)

import numpy as np
import xarray as xr
import pandas as pd
import scipy.stats
from scipy.stats import pearsonr, spearmanr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn.decomposition import PCA, SparsePCA
from sklearn.utils import check_random_state
from sklearn.model_selection import KFold, ShuffleSplit

from gemmr.estimators import SVDCCA, SVDPLS
# NOTE(review): the star imports below presumably supply names that are used later
# but not defined in this notebook (e.g. calc_n_required, mk_betweenAssocRelError,
# mk_betweenAssocRelError_cv, mk_meanBetweenAssocRelError) -- confirm which module
# each one comes from.
from gemmr.generative_model import *
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.sample_analysis import *
from gemmr.sample_size.interpolation import *
from gemmr.metrics import *
from gemmr.util import subset_ds

import matplotlib
from matplotlib.lines import Line2D
import holoviews as hv
from holoviews import opts
# Render holoviews objects with the matplotlib backend at 120 dpi.
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)

# NOTE(review): local configuration module; presumably defines cmap_r, clr_inSample,
# clr_5cv, fig_opts, Format_log_axis, suptitle_cca and suptitle_pls, which are used
# below but not defined in this notebook -- confirm.
from my_config import *
[2]:
# Colours indexed by position in (0.1, 0.3, 0.5), the true correlations the plotting
# functions iterate over. In-sample curves sample 3 colours from cmap_r (presumably
# defined in my_config); CV curves use fixed named colours.
r_clrs = hv.Palette(
    cmap_r,
    samples=3,
).values
r_clrs_cv = [
    'mediumseagreen',  # r = 0.1
    None,              # r = 0.3: not in the plotting functions' default rs=(.1, .5)
    'steelblue',       # r = 0.5
]
[3]:
# Load the cross-validation outcome datasets (one per model) and select mode=0.
data_home = None
res = {
    model: load_outcomes(fname, data_home=data_home).sel(mode=0)
    for model, fname in (
        ('cca', 'cv_cca_cca_random_sum+-2+-2'),
        ('pls', 'cv_pls_pls_random_sum+-2+-2'),
    )
}
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'

What’s in the outcome data files?

[4]:
# Subset both datasets with n_keep=10 (presumably 10 covariance matrices per (px, r);
# consistent with n_Sigmas == 10 in the stats printed below).
res.update({name: subset_ds(res[name], n_keep=10) for name in ('cca', 'pls')})
[5]:
# Summarise the (subsetted) CCA outcomes: parameter grid and n_Sigmas per (px, r).
print_ds_stats(res['cca'])
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32

power           not calculated
[6]:
# Summarise the (subsetted) PLS outcomes: parameter grid and n_Sigmas per (px, r).
print_ds_stats(res['pls'])
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10],
       [10, 10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32

power           not calculated

Dependence of association strength on samples per feature

[7]:
def plot_between_assocs_cv(ds, *, insample_assoc='between_assocs', cv_assoc=None, rs=(.1, .5), px=8, qs=(.025, .975), y_label='between-set assoc strength'):
    """Overlay in-sample and 5-fold cross-validated association strengths.

    For each true correlation ``r`` that is both in ``rs`` and in
    ``(0.1, 0.3, 0.5)`` (the positions of ``r_clrs`` / ``r_clrs_cv``), draws
    against samples per feature:

    * the mean (solid curve) and ``qs`` quantile band of the in-sample
      association ``ds[insample_assoc]``;
    * the mean (dashed curve) and quantile band of the cross-validated
      association ``ds[cv_assoc]`` for ``cv='kfold5'``.

    Means and quantiles are taken jointly over ``rep`` and ``Sigma_id``;
    ``n_per_ftr`` entries that are all-NaN are dropped first.

    Parameters
    ----------
    ds : dataset indexable by variable name (presumably ``xarray.Dataset``)
        Variables need dims ``r``, ``px``, ``n_per_ftr``, ``rep``, ``Sigma_id``
        (plus ``cv`` for ``cv_assoc``).
    insample_assoc : str
        Name of the in-sample association variable.
    cv_assoc : str
        Name of the cross-validated association variable. Required.
    rs : tuple of float
        True correlations to plot (values outside 0.1/0.3/0.5 are ignored).
    px : int
        Number of features to select.
    qs : (float, float)
        Lower/upper quantiles of the shaded bands.
    y_label : str
        New name for the ``insample_assoc`` dimension.

    Returns
    -------
    hv.Overlay

    Raises
    ------
    ValueError
        If ``cv_assoc`` is not given (``ds[None]`` would otherwise fail with an
        unhelpful lookup error).
    """
    if cv_assoc is None:
        raise ValueError("cv_assoc must name the cross-validated data variable in ds")

    panel = hv.Overlay()
    # ri indexes r_clrs / r_clrs_cv, so this list's order must match those lists.
    for ri, r in enumerate([.1, .3, .5]):

        if r not in rs:
            continue

        _between_assocs = ds[insample_assoc].sel(r=r, px=px).dropna('n_per_ftr', how='all').stack(stacked=['rep', 'Sigma_id'])
        _between_assocs_cv = ds[cv_assoc].sel(r=r, px=px, cv='kfold5').dropna('n_per_ftr', how='all').stack(stacked=['rep', 'Sigma_id'])
        panel *= (
            hv.Curve(
                _between_assocs.mean('stacked')
            ).opts(color=r_clrs[ri])
            * hv.Area(
                (_between_assocs.n_per_ftr,
                 _between_assocs.quantile(qs[0], 'stacked'),
                 _between_assocs.quantile(qs[1], 'stacked')),
                vdims=['y', 'y2']
            ).opts(color=r_clrs[ri])
            * hv.Curve(
                _between_assocs_cv.mean('stacked')
            ).opts(color=r_clrs_cv[ri], linestyle='--')
            * hv.Area(
                (_between_assocs_cv.n_per_ftr,
                 _between_assocs_cv.quantile(qs[0], 'stacked'),
                 _between_assocs_cv.quantile(qs[1], 'stacked')),
                vdims=['y', 'y2']
            ).opts(color=r_clrs_cv[ri], linestyle='--')
        )

    return panel.redim(**{
        'n_per_ftr': 'Samples / feature',
        insample_assoc: y_label
    }).opts(
        opts.Overlay(logx=True, logy=False),
        opts.Area(linewidth=0.5, alpha=.2)
    )
[8]:
# Left: CCA, right: PLS. Solid curves/bands are in-sample, dashed ones 5-fold CV.
panels_assoc_strength = (
    plot_between_assocs_cv(
        res['cca'], insample_assoc='between_assocs', cv_assoc='between_corrs_cv',
        y_label='Association strength',
    )
    + (
        # y_label='y' renames the in-sample dimension to 'y' (the same name as the
        # quantile-band vdims); the .redim(y=...) below then renames all of them to
        # 'assoc_strength2'. NOTE(review): presumably so this panel's y axis is not
        # linked to the CCA panel -- confirm.
        plot_between_assocs_cv(
            res['pls'], insample_assoc='between_assocs', cv_assoc='between_covs_cv',
            y_label='y',
        )
        # Raw string: '\m' is not a valid escape sequence in a normal string literal.
        * hv.Text(90, .8, r'$r_\mathrm{true}=$', halign='right', valign='top', fontsize=7)
        * hv.Text(100 * 4, .8, '0.5', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-1])
        # No label for r_true = 0.3: it is not in plot_between_assocs_cv's default rs=(.1, .5).
        * hv.Text(100 * 4**2, .8, '0.1', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-3])
    ).redim(
        y=hv.Dimension('assoc_strength2', label='Association strength')
    ).opts(logx=True, logy=False, ylabel='', ylim=(-.4, .9))
).opts(*fig_opts)

panels_assoc_strength
[8]:

Estimation errors of association strengths

[9]:
def plot_abs_errors(ds, *, insample_errorfun=mk_betweenAssocRelError, cv_assoc=None, px=8, rs=(.1, .5), qs=(.025, .975), y_label='error'):
    """Overlay absolute in-sample and 5-fold CV estimation errors.

    ``insample_errorfun(ds)`` gives the in-sample errors and
    ``insample_errorfun(ds, datavar=cv_assoc)`` the cross-validated ones; the
    absolute values are plotted. For each true correlation ``r`` in both ``rs``
    and ``(0.1, 0.3, 0.5)``, draws against samples per feature the mean (solid
    for in-sample, dashed for ``cv='kfold5'``) and the ``qs`` quantile band over
    ``Sigma_id`` x ``rep``.

    Parameters
    ----------
    ds : dataset accepted by ``insample_errorfun``
    insample_errorfun : callable
        Called as ``f(ds)`` and ``f(ds, datavar=cv_assoc)``; must return arrays
        with dims ``r``, ``px``, ``n_per_ftr``, ``Sigma_id``, ``rep`` (the CV one
        also ``cv``).
    cv_assoc : str
        Name of the cross-validated association variable. Required.
    px : int
        Number of features to select.
    rs : tuple of float
        True correlations to plot (values outside 0.1/0.3/0.5 are ignored).
    qs : (float, float)
        Lower/upper quantiles of the shaded bands.
    y_label : str
        New name for the ``'y'`` dimension (the lower bound of the bands).

    Returns
    -------
    hv.Overlay

    Raises
    ------
    ValueError
        If ``cv_assoc`` is not given; the CV error could not be selected with
        ``cv='kfold5'`` otherwise.
    """
    if cv_assoc is None:
        raise ValueError("cv_assoc must name the cross-validated data variable in ds")

    e = np.abs(insample_errorfun(ds))
    ecv = np.abs(insample_errorfun(ds, datavar=cv_assoc))

    panel = hv.Overlay()

    # ri indexes r_clrs / r_clrs_cv, so this list's order must match those lists.
    for ri, r in enumerate([.1, .3, .5]):

        if r not in rs:
            continue

        _e = e.sel(r=r, px=px).dropna('n_per_ftr', how='all').stack(stacked=['Sigma_id', 'rep'])
        _ecv = ecv.sel(r=r, px=px, cv='kfold5').dropna('n_per_ftr', how='all').stack(stacked=['Sigma_id', 'rep'])

        panel *= (
            hv.Curve(
                _e.mean('stacked')
            ).opts(color=r_clrs[ri])
            * hv.Area(
                (_e.n_per_ftr,
                 _e.quantile(qs[0], 'stacked'),
                 _e.quantile(qs[1], 'stacked')),
                vdims=['y', 'y2']
            ).opts(color=r_clrs[ri])
            * hv.Curve(
                _ecv.mean('stacked')
            ).opts(color=r_clrs_cv[ri], linestyle='--')
            * hv.Area(
                (_ecv.n_per_ftr,
                 _ecv.quantile(qs[0], 'stacked'),
                 _ecv.quantile(qs[1], 'stacked')),
                vdims=['y', 'y2']
            ).opts(color=r_clrs_cv[ri], linestyle='--')
        )

    return panel.redim(
        n_per_ftr='Samples / feature',
        y=y_label,
    ).opts(
        opts.Overlay(logx=True, logy=True, ylim=(None, 1)),
        opts.Area(linewidth=.5, alpha=.2)
    )
[10]:
# Absolute relative association-strength error, in-sample (solid) vs. 5-fold CV
# (dashed): CCA on the left, PLS on the right (y label shown on the left only).
panels_assoc_strength_error = (
    plot_abs_errors(
        res['cca'],
        insample_errorfun=mk_betweenAssocRelError,
        cv_assoc='between_corrs_cv',
        y_label='| Relative association\nstrength error |',
    )
    + plot_abs_errors(
        res['pls'],
        insample_errorfun=mk_betweenAssocRelError,
        cv_assoc='between_covs_cv',
        y_label='| Relative association\nstrength error |',
    ).opts(ylabel='')
).opts(*fig_opts).opts(opts.Overlay(ylim=(.001, 10)))

panels_assoc_strength_error
[10]:

Required sample size

[11]:
def scatter_req_sample_size(ds, cv_assoc, error_level=0.1, lg=None):
    """Scatter required samples per feature vs. true correlation (in-sample vs. 5-fold CV).

    The repetition-averaged absolute relative association-strength error (in-sample
    via ``mk_betweenAssocRelError``, cross-validated via
    ``mk_betweenAssocRelError_cv(ds, cv_assoc)``) is passed to ``calc_n_required``
    with bounds ``(-error_level, error_level)``. The result is averaged over
    ``Sigma_id`` and ``px`` and plotted against ``r``. For CV, only ``cv='kfold5'``
    is shown.

    Parameters
    ----------
    ds : outcome dataset accepted by the gemmr metric helpers
    cv_assoc : str
        Name of the cross-validated association variable in ``ds``.
    error_level : float
        Target bound on the relative error.
    lg : sequence, optional
        Truthy/falsy legend flags. Only the first two entries are used: whether to
        show the in-sample and the 5-fold-CV scatter in the legend. Default: neither.

    Returns
    -------
    hv.Overlay
    """
    e = np.abs(mk_betweenAssocRelError(ds))
    ecv = np.abs(mk_betweenAssocRelError_cv(ds, cv_assoc))

    n_req = calc_n_required(e.mean('rep'), -error_level, error_level)
    n_req_cv = calc_n_required(ecv.mean('rep'), -error_level, error_level)

    n_req = n_req.mean('Sigma_id').mean('px')
    n_req_cv = n_req_cv.mean('Sigma_id').mean('px')

    if lg is None:
        lg = [False]*5
    else:
        lg = np.asarray(lg).astype(bool).tolist()

    return (
        hv.Scatter((n_req.r, n_req), label='in-sample').opts(color=clr_inSample, marker='s', show_legend=lg[0])
        * hv.Scatter((n_req_cv.r, n_req_cv.sel(cv='kfold5')), label='5-fold CV').opts(color=clr_5cv, marker='o', show_legend=lg[1])
    ).redim(
        # Raw string: '\m' is not a valid escape sequence in a normal string literal.
        x=r'$r_\mathrm{true}$',
        y='Req samples / feature'
    ).opts(
        opts.Scatter(alpha=1),
        opts.Overlay(logx=True, logy=True, padding=.02, xlim=(.08, 1))
    )
[12]:
# Required samples per feature for a relative error of error_level: CCA (left), PLS (right).
error_level = 0.1
panels_req_samples = (
    (
        scatter_req_sample_size(
            res['cca'], 'between_corrs_cv',
            error_level=error_level, lg=[1, 1, 1, 0, 0],
        )
        # Hand-placed legend text; the built-in legends are switched off below (show_legend=False).
        * hv.Text(.09, 10, 'in-sample', halign='left', valign='top', fontsize=7).opts(color=clr_inSample)
        * hv.Text(.09, 5, '5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_5cv)
    )
    + scatter_req_sample_size(
        res['pls'], 'between_covs_cv',
        error_level=error_level, lg=[0, 0, 0, 1, 1],
    ).opts(ylabel='')
).opts(*fig_opts).opts(
    opts.Overlay(xlim=(.08, 1), ylim=(1, 2000), logx=True, logy=True, show_legend=False)
)

panels_req_samples
[12]:

Figure

[13]:
def hook_ls_legend(plot, element):
    """Holoviews (matplotlib backend) hook that adds a line-style legend.

    The legend explains the two line styles used by ``plot_between_assocs_cv`` and
    ``plot_abs_errors``: solid lines for in-sample estimates and dashed
    (``linestyle='--'``) lines for 5-fold cross-validated estimates.

    Parameters
    ----------
    plot : holoviews plot
        Its matplotlib axis is taken from ``plot.handles['axis']``.
    element : holoviews element
        Unused; part of the holoviews hook signature.
    """
    ax = plot.handles['axis']
    # Proxy artists: they are passed only to ax.legend and never added to the axes,
    # so only their styling matters.
    handles = [
        Line2D([10], [0.1], color='black', linestyle='-', linewidth=1),
        # Must match the linestyle='--' used for the CV curves (previously '-.').
        Line2D([10], [0.1], color='black', linestyle='--', linewidth=1),
    ]
    ax.legend(handles,
              [
                  'in sample',
                  '5-fold CV',
              ],
              handlelength=1.5, fontsize=7, frameon=False,
              loc='lower right', bbox_to_anchor=(1, 0)
             )


from pathlib import Path

# 2x2 layout: association strength (CCA, PLS), its estimation error (CCA, PLS) and
# the required sample sizes (CCA, PLS). The line-style legend goes on the first panel.
fig = (
    panels_assoc_strength.Overlay.I.opts(xlim=(3, 9000), ylim=(-.4, .9), hooks=[hook_ls_legend, Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_cca])
    + panels_assoc_strength.Overlay.II.opts(xlim=(3, 9000), ylim=(-.4, .9), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_pls])
    + panels_assoc_strength_error.opts(opts.Overlay(xlim=(3, 9000), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5)]))
    + panels_req_samples
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(hspace=.6, vspace=0.6, sublabel_position=(-.35, .95))
)

# Make sure the output directory exists so saving works on a fresh checkout.
Path('fig').mkdir(parents=True, exist_ok=True)
hv.save(fig, 'fig/figS_cv.pdf')

fig
[13]: