Samples per feature is a key parameter

We have used “samples per feature” as an effective sample size parameter to account for the fact that datasets can have hugely varying numbers of features. To check whether this is justified, we here plot the samples-per-feature dependence of several metrics separately for different numbers of features. If “samples per feature” is a “good” parameterization, the curves for different feature numbers should overlap.
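
Throughout, “samples per feature” is the number of samples divided by the total number of features in \(X\) and \(Y\) combined; the “reparameterized” panels below instead divide the number of samples by (total number of features)\(^{1.5}\). A minimal sketch of both quantities (plain numpy; the helper names are ours, not part of gemmr):

import numpy as np

def samples_per_feature(n_samples, p_x, p_y):
    # effective sample size used throughout: samples / (total number of features)
    return n_samples / (p_x + p_y)

def samples_per_feature_15(n_samples, p_x, p_y):
    # alternative x-axis used in the "reparameterized" panels:
    # samples / (total number of features)**1.5
    return n_samples / (p_x + p_y)**1.5

samples_per_feature(640, 16, 16)     # 20.0
samples_per_feature_15(640, 16, 16)  # ~3.5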

Setup

[1]:
import numpy as np
import xarray as xr
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.generative_model import *
from gemmr.sample_analysis import *
from gemmr.estimators import SVDCCA
from gemmr.sample_size import *
from gemmr.metrics import *

import holoviews as hv
from holoviews import opts
hv.renderer('matplotlib').set_param(dpi=120)
hv.extension('matplotlib')

from my_config import *

import warnings
warnings.simplefilter('ignore', UserWarning)
[2]:
r_clrs = hv.Palette(cmap_r, samples=5).values

Auxiliary plotting functions

[3]:
def plot_pxs_rs(ds, metric, r_clrs, y_label=None, rs=(.1, .3, .5, .7, .9), reparameterize=False):
    """Overlay curves of ``metric`` vs. samples per feature, one curve per (r, px) combination."""
    metric = metric.rename('y')  # for redim later
    n_pxs = len(metric.px)

    panel = hv.Overlay()
    for ri, r in enumerate(rs):
        if r not in metric.r:
            continue
        for pxi, px in enumerate(np.sort(metric.px.values)):
            py = ds.py.sel(Sigma_id=0, r=r, px=px)
            if np.isnan(py):
                continue
            py = int(py)

            _metric = metric.sel(r=r, px=px)
            if np.isnan(_metric.values).all():
                continue

            if reparameterize:
                # n_per_ftr * (px + py) is the absolute sample size, so the new
                # x-coordinate is samples / (total number of features)**1.5
                _metric = _metric.assign_coords(n_per_ftr=(_metric.n_per_ftr*(px+py))/(px+py)**1.5)

            # color encodes r_true, transparency encodes the number of features
            panel *= hv.Curve(_metric).opts(linewidth=1, color=r_clrs[ri], alpha=1 - pxi/n_pxs)

    if reparameterize:
        x_label = r'Samples / (# features)$^{1.5}$'
    else:
        x_label = 'Samples per feature'
    return panel.redim(
        n_per_ftr=x_label,
        y=y_label
    ).opts(logx=True, logy=True)
[4]:
def mk_fig(ds_cca, ds_pls):
    return (
        (
            plot_pxs_rs(ds_cca, ds_cca.power.mean('Sigma_id'), r_clrs, 'Power')
            * hv.Text(300, 0.05*(1.5)**2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='bottom')
            * hv.Text(1000, 0.05*(1.5)**2, r'$0.1$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[0])
            * hv.Text(200, 0.05*1.5, r'$0.3$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[1])
            * hv.Text(1000, 0.05*1.5, r'$0.5$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[2])
            * hv.Text(200, 0.05, r'$0.7$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[3])
            * hv.Text(1000, 0.05, r'$0.9$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[4])
        ).relabel('CCA').opts(logx=True, logy=True, xlabel='')
        + plot_pxs_rs(ds_pls, ds_pls.power.mean('Sigma_id'), r_clrs, 'Power').relabel('PLS').opts(xlabel='', ylabel='')
        + plot_pxs_rs(ds_pls, ds_pls.power.mean('Sigma_id'), r_clrs, 'Power', reparameterize=True).relabel('PLS (reparameterized)').opts(xlabel='', ylabel='')
        #
        + plot_pxs_rs(ds_cca, mk_betweenAssocRelError(ds_cca).mean('rep').mean('Sigma_id'), r_clrs, 'Relative association\nstrength error').opts(xlabel='', ylim=(None, 1), yticks=4)
        + plot_pxs_rs(ds_pls, mk_betweenAssocRelError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Relative association\nstrength error').opts(xlabel='', ylabel='', ylim=(None, 1), yticks=4)
        + plot_pxs_rs(ds_pls, mk_betweenAssocRelError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Relative association\nstrength error', reparameterize=True).opts(xlabel='', ylabel='', ylim=(None, 1), yticks=4)
        #
        + plot_pxs_rs(ds_cca, mk_weightError(ds_cca).mean('rep').mean('Sigma_id'), r_clrs, 'Weight error').opts(xlabel='')
        + plot_pxs_rs(ds_pls, mk_weightError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Weight error').opts(xlabel='', ylabel='')
        + plot_pxs_rs(ds_pls, mk_weightError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Weight error', reparameterize=True).opts(xlabel='', ylabel='')
        #
        + plot_pxs_rs(ds_cca, mk_scoreError(ds_cca).mean('rep').mean('Sigma_id'), r_clrs, 'Score error').opts(xlabel='')
        + plot_pxs_rs(ds_pls, mk_scoreError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Score error').opts(xlabel='', ylabel='')
        + plot_pxs_rs(ds_pls, mk_scoreError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Score error', reparameterize=True).opts(xlabel='', ylabel='')
        #
        + plot_pxs_rs(ds_cca, mk_loadingError(ds_cca).mean('rep').mean('Sigma_id'), r_clrs, 'Loading error')
        + plot_pxs_rs(ds_pls, mk_loadingError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Loading error').opts(ylabel='')
        + plot_pxs_rs(ds_pls, mk_loadingError(ds_pls).mean('rep').mean('Sigma_id'), r_clrs, 'Loading error', reparameterize=True).opts(ylabel='')
    ).cols(
        3
    ).opts(*fig_opts).opts(
        opts.Overlay(xlim=(3, 1000), xticks=3),
        opts.Layout(hspace=.45, sublabel_position=(-.55, .85), fig_inches=(7./4*3, None))
    )

p_tot not fixed, p_X = p_Y

First, we investigate the samples-per-feature dependence for different total numbers of features, assuming that datasets \(X\) and \(Y\) have the same number of features.
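
For orientation, a small sketch (plain numpy, separate from the analysis code) of the absolute sample sizes implied by the samples-per-feature and feature-number grids used here, using n = n_per_ftr * (p_X + p_Y):

import numpy as np

n_per_ftr = np.array([3, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
px = np.array([2, 4, 8, 16, 32, 64, 128])    # p_Y == p_X in this setting

# absolute sample size for every (n_per_ftr, p_X) combination
n = n_per_ftr[:, None] * (2 * px[None, :])
n.min(), n.max()    # (12, 262144)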

[5]:
ds_cca = load_outcomes('cca').sel(mode=0)
ds_pls = load_outcomes('pls', tag='axPlusay-2').sel(mode=0)

What’s in the outcome data files?

[6]:
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (0.00, 0.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[7]:
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7 0.9]
px               [  2   4   8  16  32  64 128]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [25, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0, 25, 25, 25, 25],
       [ 0,  0,  0, 25, 25],
       [ 0,  0,  0, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32 64 128

power           calculated
[8]:
fig = mk_fig(ds_cca, ds_pls)

hv.save(fig, 'fig/figS_n_per_ftr_parameterization.pdf')

fig
[8]:
[Figure: samples-per-feature dependence of power, relative association strength error, weight error, score error and loading error, for CCA (left column), PLS (middle column) and PLS with the x-axis reparameterized as samples / (# features)^1.5 (right column).]
p_tot=64, p_X != p_Y

Second, we investigate the samples-per-feature dependence assuming that the total number of features is 64, but that datasets \(X\) and \(Y\) have different numbers of features.
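
Since \(p_X + p_Y\) is now fixed at 64, presumably \(p_Y = 64 - p_X\) for the \(p_X\) grid below, and a given samples-per-feature value implies the same absolute sample size for every \(p_X\). A small sketch under that assumption (plain numpy, separate from the analysis code):

import numpy as np

p_tot = 64
px = np.array([2, 4, 8, 16, 32])
py = p_tot - px           # [62, 60, 56, 48, 32], assuming p_X + p_Y = p_tot

n_per_ftr = np.array([3, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
n = n_per_ftr * p_tot     # identical absolute sample sizes for all p_X
n[0], n[-1]               # (192, 65536)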

[9]:
ds_cca = load_outcomes('cca', tag='ptot64').sel(mode=0)
ds_pls = load_outcomes('pls', tag='axPlusay-2_ptot64').sel(mode=0)

What’s in the outcome data files?

[10]:
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.1 0.3 0.5 0.7 0.9]
px               [ 2  4  8 16 32]
ax+ay range     (0.00, 0.00)
py              != px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 5)>
array([[10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32

power           calculated
[11]:
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024]
r                [0.3 0.5 0.7 0.9]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              != px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.3 0.5 0.7 0.9
  * px       (px) int64 2 4 8 16 32

power           calculated
[12]:
fig_ptot64 = mk_fig(ds_cca, ds_pls)

hv.save(fig_ptot64, 'fig/figS_n_per_ftr_parameterization_ptot64.pdf')

fig_ptot64
[12]:
[Figure: same panels as above, for the p_tot=64, p_X != p_Y setting.]