Cross-validated estimation

What are the properties of cross-validated CCA and PLS estimations?

Setup

[1]:
import itertools

import numpy as np
import xarray as xr
import pandas as pd
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, spearmanr, zscore
from scipy.spatial.distance import pdist, cdist, squareform

from sklearn.decomposition import PCA, SparsePCA
from sklearn.utils import check_random_state
from sklearn.model_selection import KFold, ShuffleSplit

from gemmr.estimators import SVDCCA, SVDPLS
from gemmr.generative_model import *
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.sample_analysis import *
from gemmr.sample_size.interpolation import *
from gemmr.metrics import *

import matplotlib
from matplotlib.lines import Line2D
import holoviews as hv
from holoviews import opts
# Use the matplotlib plotting backend for holoviews and render at 120 dpi.
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)

from my_config import *

import warnings
from matplotlib import MatplotlibDeprecationWarning
# Hide matplotlib deprecation notices in the notebook output.
warnings.simplefilter(action='ignore', category=MatplotlibDeprecationWarning)

# holoviews emits this UserWarning for log-linear plots
warnings.filterwarnings(
    action='ignore',
    message='aspect is not supported for Axes with xscale=log, yscale=linear',
    category=UserWarning,
)
/Users/mdh56/Projects/gemmr/gemmr/sample_analysis/analyzers.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
[2]:
# Three colors sampled from `cmap_r` (presumably provided by my_config — confirm);
# the plotting functions below index this list once per plotted true correlation.
r_clrs = hv.Palette(cmap_r, samples=3).values
[3]:
# Load the stored outcome datasets for both estimators, selecting mode=0 in each.
res = {
    'cca': load_outcomes('cca', tag='cv').sel(mode=0),
    'pls': load_outcomes('pls', tag='axPlusay-2_cv').sel(mode=0),
}

What’s in the outcome data files?

[4]:
# Print an overview of the parameter grid covered by the CCA outcomes.
print_ds_stats(res['cca'])
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [ 2  4  8 16 32]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 5)>
array([[10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32

power           not calculated
[5]:
# Print an overview of the parameter grid covered by the PLS outcomes.
print_ds_stats(res['pls'])
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 5)>
array([[10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32

power           not calculated

Dependence of association strength on samples per feature

[6]:
def plot_between_assocs_cv(ds, cv_assoc, y_label='between-set assoc strength'):
    """Overlay in-sample and cross-validated association strengths.

    For each true correlation r in (0.1, 0.3, 0.5) three curves are drawn in
    the color ``r_clrs[i]``: ``ds.between_assocs`` (default line style),
    ``ds[cv_assoc]`` for ``cv='kfold5'`` (dash-dotted) and ``ds[cv_assoc]``
    for ``cv='shuffle20x20%test'`` (dotted). Every curve is averaged over the
    'rep', 'Sigma_id' and 'px' dimensions, in that order.

    Parameters
    ----------
    ds : dataset
        Outcome dataset providing ``between_assocs`` and ``ds[cv_assoc]``
        (presumably an xarray.Dataset as returned by ``load_outcomes``).
    cv_assoc : str
        Name of the variable in ``ds`` holding the cross-validated
        association strengths.
    y_label : str
        Label the ``between_assocs`` dimension is renamed to.

    Returns
    -------
    hv.Overlay
        Overlay of all curves with logarithmic x- and y-axes.
    """

    def _grand_mean(da):
        # order of the reductions is kept fixed: 'rep' -> 'Sigma_id' -> 'px'
        return da.mean('rep').mean('Sigma_id').mean('px')

    # (value of the 'cv' coordinate, line style) for the cross-validated curves
    cv_styles = (
        ('kfold5', '-.'),
        ('shuffle20x20%test', ':'),
    )

    overlay = hv.Overlay()
    for idx, r_true in enumerate([.1, .3, .5]):
        clr = r_clrs[idx]

        # in-sample curve first, then one curve per cross-validation scheme
        curves = hv.Curve(
            _grand_mean(ds.between_assocs.sel(r=r_true))
        ).opts(color=clr)
        for cv_name, ls in cv_styles:
            curves = curves * hv.Curve(
                _grand_mean(ds[cv_assoc].sel(r=r_true, cv=cv_name))
            ).opts(color=clr, linestyle=ls)

        overlay *= curves

    relabelled = overlay.redim(
        n_per_ftr='Samples / feature',
        between_assocs=y_label
    )
    return relabelled.opts(logx=True, logy=True)
[7]:
# Association strength vs. samples per feature: CCA panel + PLS panel.
# The hv.Text annotations in the PLS panel act as a color key for the true
# correlations: r_clrs[-1], r_clrs[-2], r_clrs[-3] are the colors that
# plot_between_assocs_cv uses for r = 0.5, 0.3 and 0.1, respectively.
panels_assoc_strength = (
    plot_between_assocs_cv(res['cca'], 'between_corrs_cv', y_label='Association strength')
    + (
        plot_between_assocs_cv(res['pls'], 'between_covs_cv', y_label='y')
        # raw string: '\m' is not a valid escape sequence in a plain string literal
        * hv.Text(90, .5, r'$r_\mathrm{true}=$', halign='right', valign='top', fontsize=7)
        * hv.Text(100*(4)**1, .5, '0.5', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-1])
        * hv.Text(100*(4)**2, .5, '0.3', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-2])
        * hv.Text(100*(4)**3, .5, '0.1', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-3])
    ).redim(
        # the PLS curves were labelled 'y' above so that they can be given a
        # dimension name distinct from the CCA panel here
        y=hv.Dimension('assoc_strength2', label='Association strength')
    ).opts(logx=True, logy=True, ylabel='', ylim=(.005, None))
).opts(*fig_opts)

panels_assoc_strength
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[7]:

Estimation errors of association strengths

[8]:
def plot_abs_errors(ds, cv_assoc, y_label='error'):
    """Overlay absolute in-sample and cross-validated association-strength errors.

    The errors are the absolute values of what ``mk_betweenAssocRelError`` and
    ``mk_betweenAssocRelError_cv`` return. For each true correlation r in
    (0.1, 0.3, 0.5) three curves are drawn in the color ``r_clrs[i]``:
    in-sample (default line style), ``cv='kfold5'`` (dash-dotted) and
    ``cv='shuffle20x20%test'`` (dotted). Every curve is averaged over the
    'Sigma_id', 'rep' and 'px' dimensions, in that order.

    Parameters
    ----------
    ds : dataset
        Outcome dataset (presumably an xarray.Dataset as returned by
        ``load_outcomes``).
    cv_assoc : str
        Name of the cross-validated association-strength variable, passed on
        to ``mk_betweenAssocRelError_cv``.
    y_label : str
        Label the ``y`` dimension is renamed to.

    Returns
    -------
    hv.Overlay
        Overlay of all curves with logarithmic axes and an upper y-limit of 1.
    """

    e = np.abs(mk_betweenAssocRelError(ds))
    ecv = np.abs(mk_betweenAssocRelError_cv(ds, cv_assoc))

    panel = hv.Overlay()

    for ri, r in enumerate([.1, .3, .5]):
        panel *= (
            hv.Curve(
                e.sel(r=r).mean('Sigma_id').mean('rep').mean('px')
            ).opts(color=r_clrs[ri])
            # '-.' (not '--') so that the 5-fold CV curves use the same line
            # style as in plot_between_assocs_cv and in the legend drawn by
            # hook_ls_legend
            * hv.Curve(
                ecv.sel(r=r, cv='kfold5').mean('Sigma_id').mean('rep').mean('px')
            ).opts(color=r_clrs[ri], linestyle='-.')
            * hv.Curve(
                ecv.sel(r=r, cv='shuffle20x20%test').mean('Sigma_id').mean('rep').mean('px')
            ).opts(color=r_clrs[ri], linestyle=':')
        )

    return panel.redim(
        n_per_ftr='Samples / feature',
        y=y_label,
    ).opts(
        logx=True, logy=True, ylim=(None, 1)
    )
[9]:
# Absolute association-strength errors: CCA panel + PLS panel. Both panels get
# the same y-label text; the PLS panel's axis label is then blanked with
# ylabel='' so that the label is only shown once.
panels_assoc_strength_error = (
    plot_abs_errors(res['cca'], 'between_corrs_cv', y_label='| Relative association\nstrength error |')
    + plot_abs_errors(res['pls'], 'between_covs_cv', y_label='| Relative association\nstrength error |').opts(ylabel='')
).opts(*fig_opts)

panels_assoc_strength_error
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[9]:

Required sample size

[10]:
def scatter_req_sample_size(ds, cv_assoc, error_level=0.1, lg=None):
    """Scatter required samples per feature against the true correlation.

    Five point series are overlaid: one based on the in-sample error
    (``mk_betweenAssocRelError``), two based on the cross-validated error
    (``mk_betweenAssocRelError_cv``, for ``cv='kfold5'`` and
    ``cv='shuffle20x20%test'``) and two based on
    ``mk_meanBetweenAssocRelError`` for the same two cross-validation schemes.
    The first two error measures are taken in absolute value; all three are
    averaged over 'rep' before being passed to ``calc_n_required`` with the
    bounds ``(-error_level, error_level)``. The result is then averaged over
    'Sigma_id' and 'px'.

    Parameters
    ----------
    ds : dataset
        Outcome dataset (presumably an xarray.Dataset as returned by
        ``load_outcomes``).
    cv_assoc : str
        Name of the cross-validated association-strength variable.
    error_level : float
        Error bound handed to ``calc_n_required`` (presumably the tolerated
        relative error — confirm against ``calc_n_required``).
    lg : sequence of 5 truthy/falsy values, optional
        ``show_legend`` flag for each of the five series, in the order
        in-sample, 5-fold CV, 20x5-fold CV, mean 5-fold CV, mean 20x5-fold CV.
        Defaults to all ``False``.

    Returns
    -------
    hv.Overlay
        Overlay of the five scatter series with logarithmic axes.
    """

    e = np.abs(mk_betweenAssocRelError(ds))
    ecv = np.abs(mk_betweenAssocRelError_cv(ds, cv_assoc))
    relAssocE = mk_meanBetweenAssocRelError(ds, cv_assoc)

    n_req = calc_n_required(e.mean('rep'), -error_level, error_level)
    n_req_cv = calc_n_required(ecv.mean('rep'), -error_level, error_level)
    n_req_mean = calc_n_required(relAssocE.mean('rep'), -error_level, error_level)

    n_req = n_req.mean('Sigma_id').mean('px')
    n_req_cv = n_req_cv.mean('Sigma_id').mean('px')
    n_req_mean = n_req_mean.mean('Sigma_id').mean('px')

    if lg is None:
        lg = [False]*5
    else:
        lg = np.asarray(lg).astype(bool).tolist()

    # The x-positions of the CV series are scaled by 0.975 / 1.025 so that
    # their markers are drawn slightly left / right of the in-sample marker.
    return (
        hv.Scatter((n_req.r, n_req), label='in-sample').opts(color=clr_inSample, marker='s', show_legend=lg[0])
        * hv.Scatter((n_req_cv.r*.975, n_req_cv.sel(cv='kfold5')), label='5-fold CV').opts(color=clr_5cv, marker='o', show_legend=lg[1])
        * hv.Scatter((n_req_cv.r*1.025, n_req_cv.sel(cv='shuffle20x20%test')), label=r'20$\times$5-fold CV').opts(color=clr_20x5cv, marker='o', show_legend=lg[2])
        * hv.Scatter((n_req_mean.r*.975, n_req_mean.sel(cv='kfold5')), label='5-fold CV').opts(color=clr_5cvMean, marker='d', show_legend=lg[3])
        * hv.Scatter((n_req_mean.r*1.025, n_req_mean.sel(cv='shuffle20x20%test')), label=r'20$\times$5-fold CV').opts(color=clr_20x5cvMean, marker='d', show_legend=lg[4])
    ).redim(
        # raw string: '\m' is not a valid escape sequence in a plain string literal
        x=r'$r_\mathrm{true}$',
        y='Req samples / feature'
    ).opts(
        opts.Scatter(alpha=1),
        opts.Overlay(logx=True, logy=True, padding=.02, xlim=(.08, 1))
    )
[11]:
# Error bound passed to scatter_req_sample_size for both panels.
error_level = 0.1
# Required samples per feature: CCA panel + PLS panel.
# The overlay options at the bottom set show_legend=False, so no legend boxes
# are drawn; the colored hv.Text annotations serve as the key instead. The CCA
# panel names the in-sample / CV marker colors, the PLS panel names the colors
# of the "avg with in-sample" (diamond-marker) series.
panels_req_samples = (
    (
        scatter_req_sample_size(res['cca'], 'between_corrs_cv', error_level=error_level, lg=[1, 1, 1, 0, 0])
        * hv.Text(.09, 10, 'in-sample', halign='left', valign='top', fontsize=7).opts(color=clr_inSample)
        * hv.Text(.09, 5, '5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_5cv)
        * hv.Text(.09, 2.5, '20 x 5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_20x5cv)
    )
    + (
        scatter_req_sample_size(res['pls'], 'between_covs_cv', error_level=error_level, lg=[0, 0, 0, 1, 1])
        * hv.Text(.09, 10, 'avg with in-sample:', halign='left', valign='top', fontsize=7).opts(color='black')
        * hv.Text(.09, 5, '5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_5cvMean)
        * hv.Text(.09, 2.5, '20 x 5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_20x5cvMean)
    ).opts(ylabel='')
).opts(*fig_opts).opts(
    opts.Overlay(xlim=(.08, 1), ylim=(1, None), logx=True, logy=True, show_legend=False)
)

panels_req_samples
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
  return np.nanmean(a, axis=axis, dtype=dtype)
[11]:

Figure

[12]:
import matplotlib
[13]:
# NOTE(review): scratch expression — the LogLocator instance is not assigned
# and is only echoed as the cell output; looks like leftover exploration.
matplotlib.ticker.LogLocator()
[13]:
<matplotlib.ticker.LogLocator at 0x11e04e6a0>
[14]:
def hook_ls_legend(plot, element):
    """Plot hook that adds a line-style legend to the plot's axis.

    Three black proxy lines (solid, dash-dotted, dotted) are created and
    passed to ``ax.legend`` with the labels 'in sample', '5-fold CV' and
    '20x5-fold CV'. The legend is anchored with ``loc='lower right'`` at
    ``bbox_to_anchor=(1.4, 0)``.

    Parameters
    ----------
    plot
        Plot object whose ``handles['axis']`` entry is the axis to decorate.
    element
        Unused; present because the hook is called with two arguments.
    """
    ax = plot.handles['axis']

    # one proxy artist per line style, in the same order as the labels below
    handles = [
        Line2D([10], [0.1], color='black', linestyle=ls, linewidth=1)
        for ls in ('-', '-.', ':')
    ]
    ax.legend(handles, ['in sample', '5-fold CV', r'20$\times$5-fold CV'],
              handlelength=1.5, fontsize=7, frameon=False,
              loc='lower right', bbox_to_anchor=(1.4, 0)
             )


# Assemble the final figure in a 2-column layout: the two association-strength
# panels (first one also gets the line-style legend), the two error panels and
# the two required-sample-size panels. All panels plotted against samples per
# feature share xlim=(3, 9000) and the same log-axis tick formatting.
# `Format_log_axis`, `suptitle_cca` and `suptitle_pls` are not defined in this
# notebook (presumably provided by my_config — confirm).
fig = (
    panels_assoc_strength.Overlay.I.opts(xlim=(3, 9000), hooks=[hook_ls_legend, Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_cca])
    + panels_assoc_strength.Overlay.II.opts(xlim=(3, 9000), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_pls])
    + panels_assoc_strength_error.opts(opts.Overlay(xlim=(3, 9000), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5)]))
    + panels_req_samples
).cols(
    2
).opts(*fig_opts).opts(
    opts.Layout(hspace=.6, vspace=0.6, sublabel_position=(-.35, .95))
)

# Write the figure to disk; the path is relative to the working directory.
hv.save(fig, 'fig/figS_cv.pdf')

fig
[14]: