Cross-validated estimation¶
What are the properties of cross-validated CCA and PLS estimates?
Setup¶
[1]:
import itertools
import numpy as np
import xarray as xr
import pandas as pd
import scipy.linalg
import scipy.stats
from scipy.stats import pearsonr, spearmanr, zscore
from scipy.spatial.distance import pdist, cdist, squareform
from sklearn.decomposition import PCA, SparsePCA
from sklearn.utils import check_random_state
from sklearn.model_selection import KFold, ShuffleSplit
from gemmr.estimators import SVDCCA, SVDPLS
from gemmr.generative_model import *
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.sample_analysis import *
from gemmr.sample_size.interpolation import *
from gemmr.metrics import *
import matplotlib
from matplotlib.lines import Line2D
import holoviews as hv
from holoviews import opts
# Use the matplotlib backend for holoviews and render at 120 dpi.
hv.extension('matplotlib')
hv.renderer('matplotlib').set_param(dpi=120)
from my_config import *
import warnings
from matplotlib import MatplotlibDeprecationWarning
# Suppress matplotlib deprecation warnings so they do not clutter cell outputs.
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
# Suppress the specific UserWarning about `aspect` on log-linear axes.
warnings.filterwarnings(
'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
) # holoviews emits this for log-linear plots
/Users/mdh56/Projects/gemmr/gemmr/sample_analysis/analyzers.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
from tqdm.autonotebook import tqdm
[2]:
# Three colours sampled from `cmap_r`; the plotting functions below index this
# list by the position of r_true in [.1, .3, .5].
# NOTE(review): `cmap_r` is not defined in this notebook — presumably it comes
# from `from my_config import *`; confirm.
r_clrs = hv.Palette(cmap_r, samples=3).values
[3]:
# load data
# Cross-validation outcome datasets, keyed by estimator name.
# Only the first mode (mode=0) of each is kept.
res = {
    'cca': load_outcomes('cca', tag='cv').sel(mode=0),
    'pls': load_outcomes('pls', tag='axPlusay-2_cv').sel(mode=0),
}
What’s in the outcome data files?
[4]:
# Summarise the parameter grid covered by the CCA outcome dataset.
print_ds_stats(res['cca'])
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32]
ax+ay range (0.00, 0.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 5, r: 5)>
array([[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32
power not calculated
[5]:
# Summarise the parameter grid covered by the PLS outcome dataset.
print_ds_stats(res['pls'])
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32]
ax+ay range (-2.00, -2.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 5, r: 5)>
array([[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10],
[10, 10, 10, 10, 10]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32
power not calculated
Dependence of association strength on samples per feature¶
[6]:
def plot_between_assocs_cv(ds, cv_assoc, y_label='between-set assoc strength'):
    """Overlay in-sample and cross-validated association strengths.

    For each true correlation in (0.1, 0.3, 0.5) three curves are drawn in
    colour ``r_clrs[i]``:

    * ``ds.between_assocs`` (in-sample) with the default line style,
    * ``ds[cv_assoc]`` for ``cv='kfold5'`` with a dash-dotted line,
    * ``ds[cv_assoc]`` for ``cv='shuffle20x20%test'`` with a dotted line.

    Every curve is averaged over ``rep``, then ``Sigma_id``, then ``px``.

    Parameters
    ----------
    ds
        Outcome dataset providing ``between_assocs`` and ``cv_assoc``
        (presumably the xarray dataset returned by ``load_outcomes``).
    cv_assoc : str
        Name of the cross-validated association variable in ``ds``.
    y_label : str
        New name for the ``between_assocs`` dimension.

    Returns
    -------
    The combined holoviews overlay with logarithmic x and y axes.
    """
    def _average(values):
        # Same reduction order for all curves: rep -> Sigma_id -> px
        return values.mean('rep').mean('Sigma_id').mean('px')

    in_sample = ds.between_assocs
    cross_validated = ds[cv_assoc]

    overlay = hv.Overlay()
    for colour_idx, r_true in enumerate((.1, .3, .5)):
        colour = r_clrs[colour_idx]

        curve_in_sample = hv.Curve(
            _average(in_sample.sel(r=r_true))
        ).opts(color=colour)
        curve_kfold = hv.Curve(
            _average(cross_validated.sel(r=r_true, cv='kfold5'))
        ).opts(color=colour, linestyle='-.')
        curve_shuffle = hv.Curve(
            _average(cross_validated.sel(r=r_true, cv='shuffle20x20%test'))
        ).opts(color=colour, linestyle=':')

        overlay *= (curve_in_sample * curve_kfold * curve_shuffle)

    renamed = overlay.redim(
        n_per_ftr='Samples / feature',
        between_assocs=y_label,
    )
    return renamed.opts(logx=True, logy=True)
[7]:
# Two-panel layout: CCA (first panel) and PLS (second panel) association strengths.
# The hv.Text elements on the PLS panel form a hand-placed colour key for r_true;
# r_clrs[-1], [-2], [-3] correspond to r = 0.5, 0.3, 0.1 respectively.
panels_assoc_strength = (
    plot_between_assocs_cv(res['cca'], 'between_corrs_cv', y_label='Association strength')
    + (
        plot_between_assocs_cv(res['pls'], 'between_covs_cv', y_label='y')
        # raw string: '\m' is not a valid escape sequence in a normal string literal
        * hv.Text(90, .5, r'$r_\mathrm{true}=$', halign='right', valign='top', fontsize=7)
        * hv.Text(100*(4)**1, .5, '0.5', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-1])
        * hv.Text(100*(4)**2, .5, '0.3', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-2])
        * hv.Text(100*(4)**3, .5, '0.1', halign='right', valign='top', fontsize=7).opts(color=r_clrs[-3])
    ).redim(
        # the PLS panel was built with the placeholder label 'y'; give it a
        # distinct dimension name but the same displayed label as the CCA panel
        y=hv.Dimension('assoc_strength2', label='Association strength')
    ).opts(logx=True, logy=True, ylabel='', ylim=(.005, None))
).opts(*fig_opts)
panels_assoc_strength
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[7]:
Estimation errors of association strengths¶
[8]:
def plot_abs_errors(ds, cv_assoc, y_label='error'):
    """Overlay absolute relative errors of in-sample and cross-validated estimates.

    For each true correlation in (0.1, 0.3, 0.5) three curves are drawn in
    colour ``r_clrs[i]``:

    * in-sample error (default line style),
    * ``cv='kfold5'`` error (dash-dotted),
    * ``cv='shuffle20x20%test'`` error (dotted).

    The line styles match ``plot_between_assocs_cv`` and the line-style legend
    built in ``hook_ls_legend``, so one legend is valid for all panels.

    Errors are the absolute values of ``mk_betweenAssocRelError(ds)`` and
    ``mk_betweenAssocRelError_cv(ds, cv_assoc)``, averaged over ``Sigma_id``,
    then ``rep``, then ``px``.

    Parameters
    ----------
    ds
        Outcome dataset (presumably the xarray dataset returned by
        ``load_outcomes``).
    cv_assoc : str
        Name of the cross-validated association variable in ``ds``.
    y_label : str
        New name for the ``y`` dimension.

    Returns
    -------
    The combined holoviews overlay with logarithmic axes and the y axis
    capped at 1.
    """
    e = np.abs(mk_betweenAssocRelError(ds))
    ecv = np.abs(mk_betweenAssocRelError_cv(ds, cv_assoc))

    panel = hv.Overlay()
    for ri, r in enumerate([.1, .3, .5]):
        panel *= (
            hv.Curve(
                e.sel(r=r).mean('Sigma_id').mean('rep').mean('px')
            ).opts(color=r_clrs[ri])
            * hv.Curve(
                ecv.sel(r=r, cv='kfold5').mean('Sigma_id').mean('rep').mean('px')
            # '-.' (not '--') so that 5-fold CV looks the same in every panel
            ).opts(color=r_clrs[ri], linestyle='-.')
            * hv.Curve(
                ecv.sel(r=r, cv='shuffle20x20%test').mean('Sigma_id').mean('rep').mean('px')
            ).opts(color=r_clrs[ri], linestyle=':')
        )
    return panel.redim(
        n_per_ftr='Samples / feature',
        y=y_label,
    ).opts(
        logx=True, logy=True, ylim=(None, 1)
    )
[9]:
# Two-panel layout of absolute relative errors: CCA first, PLS second.
# The y-label is hidden on the second panel because both share the same label.
panels_assoc_strength_error = (
plot_abs_errors(res['cca'], 'between_corrs_cv', y_label='| Relative association\nstrength error |')
+ plot_abs_errors(res['pls'], 'between_covs_cv', y_label='| Relative association\nstrength error |').opts(ylabel='')
).opts(*fig_opts)
panels_assoc_strength_error
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[9]:
Required sample size¶
[10]:
def scatter_req_sample_size(ds, cv_assoc, error_level=0.1, lg=None):
    """Scatter the required samples per feature against the true correlation.

    Five point sets are overlaid:

    0. in-sample error (square markers),
    1. 5-fold CV error (circles, x shifted by factor 0.975),
    2. 20x shuffle-split CV error (circles, x shifted by factor 1.025),
    3. ``mk_meanBetweenAssocRelError`` for 5-fold CV (diamonds, factor 0.975),
    4. ``mk_meanBetweenAssocRelError`` for shuffle-split CV (diamonds, factor 1.025).

    The x shifts keep the markers of different estimators from overlapping.
    Each required sample size comes from ``calc_n_required`` applied to the
    ``rep``-averaged error with bounds ``(-error_level, error_level)``, and is
    then averaged over ``Sigma_id`` and ``px``.

    Parameters
    ----------
    ds
        Outcome dataset (presumably the xarray dataset returned by
        ``load_outcomes``).
    cv_assoc : str
        Name of the cross-validated association variable in ``ds``.
    error_level : float
        Error tolerance passed as symmetric bounds to ``calc_n_required``.
    lg : sequence of 5 truthy/falsy values, optional
        Per-point-set ``show_legend`` flags in the order listed above.
        ``None`` switches all legends off.

    Returns
    -------
    The combined holoviews overlay with logarithmic axes.
    """
    e = np.abs(mk_betweenAssocRelError(ds))
    ecv = np.abs(mk_betweenAssocRelError_cv(ds, cv_assoc))
    relAssocE = mk_meanBetweenAssocRelError(ds, cv_assoc)

    n_req = calc_n_required(e.mean('rep'), -error_level, error_level)
    n_req_cv = calc_n_required(ecv.mean('rep'), -error_level, error_level)
    n_req_mean = calc_n_required(relAssocE.mean('rep'), -error_level, error_level)

    n_req = n_req.mean('Sigma_id').mean('px')
    n_req_cv = n_req_cv.mean('Sigma_id').mean('px')
    n_req_mean = n_req_mean.mean('Sigma_id').mean('px')

    if lg is None:
        lg = [False]*5
    else:
        lg = np.asarray(lg).astype(bool).tolist()

    return (
        hv.Scatter((n_req.r, n_req), label='in-sample').opts(color=clr_inSample, marker='s', show_legend=lg[0])
        * hv.Scatter((n_req_cv.r*.975, n_req_cv.sel(cv='kfold5')), label='5-fold CV').opts(color=clr_5cv, marker='o', show_legend=lg[1])
        * hv.Scatter((n_req_cv.r*1.025, n_req_cv.sel(cv='shuffle20x20%test')), label=r'20$\times$5-fold CV').opts(color=clr_20x5cv, marker='o', show_legend=lg[2])
        * hv.Scatter((n_req_mean.r*.975, n_req_mean.sel(cv='kfold5')), label='5-fold CV').opts(color=clr_5cvMean, marker='d', show_legend=lg[3])
        * hv.Scatter((n_req_mean.r*1.025, n_req_mean.sel(cv='shuffle20x20%test')), label=r'20$\times$5-fold CV').opts(color=clr_20x5cvMean, marker='d', show_legend=lg[4])
    ).redim(
        # raw string: '\m' is not a valid escape sequence in a normal string literal
        x=r'$r_\mathrm{true}$',
        y='Req samples / feature'
    ).opts(
        opts.Scatter(alpha=1),
        opts.Overlay(logx=True, logy=True, padding=.02, xlim=(.08, 1))
    )
[11]:
# Error tolerance handed to scatter_req_sample_size for both panels.
error_level = 0.1
# First panel: CCA, second panel: PLS. The hv.Text annotations serve as a
# hand-placed colour key, since show_legend=False is set on the overlays below.
panels_req_samples = (
(
scatter_req_sample_size(res['cca'], 'between_corrs_cv', error_level=error_level, lg=[1, 1, 1, 0, 0])
* hv.Text(.09, 10, 'in-sample', halign='left', valign='top', fontsize=7).opts(color=clr_inSample)
* hv.Text(.09, 5, '5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_5cv)
* hv.Text(.09, 2.5, '20 x 5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_20x5cv)
)
+ (
scatter_req_sample_size(res['pls'], 'between_covs_cv', error_level=error_level, lg=[0, 0, 0, 1, 1])
* hv.Text(.09, 10, 'avg with in-sample:', halign='left', valign='top', fontsize=7).opts(color='black')
* hv.Text(.09, 5, '5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_5cvMean)
* hv.Text(.09, 2.5, '20 x 5-fold CV', halign='left', valign='top', fontsize=7).opts(color=clr_20x5cvMean)
).opts(ylabel='')
).opts(*fig_opts).opts(
opts.Overlay(xlim=(.08, 1), ylim=(1, None), logx=True, logy=True, show_legend=False)
)
panels_req_samples
/anaconda3/envs/gemmrtest/lib/python3.8/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice
return np.nanmean(a, axis=axis, dtype=dtype)
[11]:
Figure¶
[12]:
import matplotlib
[13]:
# NOTE(review): the locator created here is neither assigned nor used later in
# this notebook — looks like leftover exploration; confirm and remove the cell.
matplotlib.ticker.LogLocator()
[13]:
<matplotlib.ticker.LogLocator at 0x11e04e6a0>
[14]:
def hook_ls_legend(plot, element):
    """Holoviews plot hook that adds a line-style legend to the axis.

    The legend explains the three line styles used for the curves: solid for
    in-sample, dash-dotted for 5-fold CV and dotted for 20x5-fold CV. It uses
    black proxy lines so it is independent of the per-``r`` curve colours.

    Parameters
    ----------
    plot
        Holoviews plot object; the matplotlib axis is read from
        ``plot.handles['axis']``.
    element
        Unused; present because hooks are called with ``(plot, element)``.
    """
    ax = plot.handles['axis']
    # Proxy artists: only their style matters, the single data point is arbitrary.
    handles = [
        Line2D([10], [0.1], color='black', linestyle=linestyle, linewidth=1)
        for linestyle in ('-', '-.', ':')
    ]
    ax.legend(
        handles, ['in sample', '5-fold CV', r'20$\times$5-fold CV'],
        handlelength=1.5, fontsize=7, frameon=False,
        # anchor beyond the right edge of the axis (x=1.4 in axes coordinates)
        loc='lower right', bbox_to_anchor=(1.4, 0)
    )
# Assemble all panels into a two-column layout.
# - The first association-strength panel gets the line-style legend hook.
# - Association-strength and error panels share the x range (3, 9000) and the
#   log-axis tick formatting.
# NOTE(review): Format_log_axis, suptitle_cca, suptitle_pls and fig_opts are not
# defined in this notebook — presumably from `from my_config import *`; confirm.
fig = (
panels_assoc_strength.Overlay.I.opts(xlim=(3, 9000), hooks=[hook_ls_legend, Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_cca])
+ panels_assoc_strength.Overlay.II.opts(xlim=(3, 9000), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5), suptitle_pls])
+ panels_assoc_strength_error.opts(opts.Overlay(xlim=(3, 9000), hooks=[Format_log_axis('x', major_numticks=4, minor_numticks=5)]))
+ panels_req_samples
).cols(
2
).opts(*fig_opts).opts(
opts.Layout(hspace=.6, vspace=0.6, sublabel_position=(-.35, .95))
)
# Save the figure as a PDF; the path is relative to the current working directory.
hv.save(fig, 'fig/figS_cv.pdf')
fig
[14]: