Samples per feature is a key parameter

We have used “samples per feature” as an effective sample size parameter to account for the fact that datasets can have widely varying numbers of features. To check whether this is justified, we plot the dependence of several metrics on samples per feature separately for different numbers of features. If “samples per feature” is a “good” parameterization, the curves for different feature numbers should overlap.

Setup

[1]:
import warnings
from matplotlib import MatplotlibDeprecationWarning
# Keep matplotlib deprecation notices out of the cell outputs.
warnings.simplefilter('ignore', MatplotlibDeprecationWarning)
warnings.filterwarnings(
    'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
)  # holoviews emits this for log-linear plots
from tqdm import TqdmExperimentalWarning
warnings.filterwarnings('ignore', category=TqdmExperimentalWarning)
# NOTE(review): this ignores *every* UserWarning, which makes the message-specific
# UserWarning filter above redundant and may hide unrelated problems — confirm
# that the blanket suppression is intended.
warnings.simplefilter('ignore', UserWarning)

import numpy as np
import pandas as pd
import xarray as xr
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.generative_model import *
from gemmr.sample_analysis import *
from gemmr.estimators import SVDCCA
from gemmr.sample_size import *
from gemmr.metrics import *
from gemmr.util import subset_ds

import holoviews as hv
from holoviews import opts
# Render with the matplotlib backend at 120 dpi.
hv.renderer('matplotlib').param.set_param(dpi=120)
hv.extension('matplotlib')

from my_config import *
[3]:
# Four colors sampled from the palette `cmap_r` (presumably provided by
# `my_config` — verify); plot_pxs_rs picks r_clrs[i] for the i-th entry of `rs`.
r_clrs = hv.Palette(cmap_r, samples=4).values
# Dash patterns in matplotlib's (offset, (on, off)) form. The plotting helpers
# below index this list by the rank of p_X among the sorted feature numbers,
# so larger p_X gets longer dashes.
linestyles = [
    (0, (1, 1)),
    (0, (3, 2)),
    (0, (5, 3)),
    (0, (7, 4)),
    (0, (9, 5)),
    (0, (11, 6)),
]
[4]:
# Maps panel label -> DataFrame of the curves drawn in that panel. Filled by
# plot_pxs_rs (via mk_fig) below and written to file with save_source_data
# when the figure is assembled.
plotted_data = {}

Auxiliary plotting functions

[5]:
def plot_pxs_rs(ds, metric, r_clrs, y_label=None, rs=(.1, .3, .5, .7), reparameterize=False, eff_n_per_ftr=False, plotted_data=None, panel_label=None):
    """Overlay one curve of ``metric`` against ``n_per_ftr`` per (r, px) pair.

    Color encodes the true correlation (``r_clrs[i]`` for the i-th entry of
    ``rs``). The dash pattern encodes the rank of ``px`` among the sorted
    ``metric.px`` values, taken from the module-level ``linestyles`` list
    (cycled if there are more px values than patterns).

    Parameters
    ----------
    ds
        Dataset whose ``py`` variable is selected with ``Sigma_id=0``, ``r``
        and ``px``. A NaN ``py`` causes that (r, px) pair to be skipped.
    metric
        Data array with ``r``, ``px`` and ``n_per_ftr`` coordinates.
        Presumably already reduced over all other dimensions, as the callers
        in this notebook do — TODO confirm for other callers.
    r_clrs
        Sequence of colors, indexed by position in ``rs``.
    y_label
        Label given to the value dimension.
    rs
        Values of r to plot; values not contained in ``metric.r`` are skipped.
    reparameterize
        If True, the x coordinate becomes
        ``n_per_ftr * (px + py) / (px + py)**1.5`` and the x label changes
        accordingly.
    eff_n_per_ftr
        If True, ``n_per_ftr`` is first replaced by
        ``(n_per_ftr * (px + py) - 1) / (px + py - 1)``. This is applied
        before ``reparameterize``.
    plotted_data
        Optional dict. If given, a DataFrame of the drawn curves (column
        levels ``'r_true'`` and ``'pX'``) is stored under ``panel_label``.
        Nothing is stored when no curve was drawn.
    panel_label
        Key used for ``plotted_data``.

    Returns
    -------
    The overlay of all curves with relabeled dimensions and logarithmic
    x and y axes.
    """
    # Fixed name so the value dimension can be relabeled through redim below.
    metric = metric.rename('y')

    panel = hv.Overlay()
    curves = {}
    for ri, r in enumerate(rs):
        if r not in metric.r:
            continue
        for pxi, px in enumerate(np.sort(metric.px.values)):
            py = ds.py.sel(Sigma_id=0, r=r, px=px)
            if np.isnan(py):
                # No data available for this (r, px) combination.
                continue
            py = int(py)

            _metric = metric.sel(r=r, px=px)
            if np.isnan(_metric.values).all():
                continue

            if eff_n_per_ftr:
                n_per_ftr_eff = (_metric.n_per_ftr*(px+py) - 1) / (px+py - 1)
                _metric = _metric.assign_coords(n_per_ftr=n_per_ftr_eff)

            if reparameterize:
                _metric = _metric.assign_coords(n_per_ftr=(_metric.n_per_ftr*(px+py))/(px+py)**1.5)

            # Cycle through the dash patterns instead of raising IndexError
            # when there are more px values than patterns.
            linestyle = linestyles[pxi % len(linestyles)]
            panel *= hv.Curve(_metric).opts(linewidth=1, color=r_clrs[ri], linestyle=linestyle)

            curves[(r, px)] = _metric.squeeze().to_series()

    # pd.concat raises ValueError on an empty mapping, so only record the
    # panel when at least one curve was drawn.
    if plotted_data is not None and curves:
        df = pd.concat(curves, axis=1)
        df.columns.names = ['r_true', 'pX']
        plotted_data[panel_label] = df

    if reparameterize:
        x_label = r'Samples / (# features)$^{1.5}$'
    else:
        x_label = 'Samples per feature'
    return panel.redim(
        n_per_ftr=x_label,
        y=y_label
    ).opts(logx=True, logy=True)
[6]:
def mk_fig(ds_cca, ds_pls, plotted_data=None, panel_label_prefix=''):
    """Build the 15-panel (5 rows x 3 columns) parameterization figure.

    Rows are power, relative association strength error, weight error, score
    error and loading error. Columns are CCA, PLS, and PLS with the
    reparameterized x axis (``reparameterize=True`` in ``plot_pxs_rs``).

    Parameters
    ----------
    ds_cca, ds_pls
        Outcome datasets for CCA and PLS. They must provide ``power``, ``px``
        and whatever the ``mk_*Error`` helpers and ``plot_pxs_rs`` read.
    plotted_data
        Optional dict passed on to ``plot_pxs_rs``. Each panel's curves are
        stored under ``f'{panel_label_prefix}_<panel name>'``.
    panel_label_prefix
        Prefix for the keys written to ``plotted_data``.

    Returns
    -------
    A holoviews layout with 3 columns, styled with ``fig_opts``.
    """
    def _panel(ds, metric, y_label, name, **kwargs):
        # One panel; its curves are recorded under '<prefix>_<name>'.
        return plot_pxs_rs(
            ds, metric, r_clrs, y_label,
            plotted_data=plotted_data,
            panel_label=f'{panel_label_prefix}_{name}',
            **kwargs
        )

    def _avg(err):
        # Average over repetitions, then over covariance matrices.
        return err.mean('rep').mean('Sigma_id')

    # Legend for the p_X dash patterns: one short sample line plus the p_X
    # value per row, stacked geometrically (factor delta_y) because the y axis
    # is logarithmic.
    delta_y = 1.5
    pxs = np.sort(ds_cca.px.values)
    linestyle_overlay = hv.Overlay()
    for pxi, px in enumerate(pxs):
        y = .05*(delta_y)**pxi
        linestyle_overlay *= (
            hv.Curve([(90, y), (300, y)]).opts(
                logx=True, logy=True, color='black', linewidth=.5,
                linestyle=linestyles[pxi % len(linestyles)]
            )
            * hv.Text(1000, y, f'{px}', fontsize=7, halign='right', valign='center')
        )
    linestyle_overlay *= hv.Text(1000, .05*(delta_y)**len(pxs), '$p_X$', fontsize=7, halign='right', valign='center')

    # Each PLS metric feeds two panels (plain and reparameterized), so compute
    # every metric once up front.
    power_cca = ds_cca.power.mean('Sigma_id')
    power_pls = ds_pls.power.mean('Sigma_id')
    assoc_cca = _avg(mk_betweenAssocRelError(ds_cca))
    assoc_pls = _avg(mk_betweenAssocRelError(ds_pls))
    weight_cca = _avg(mk_weightError(ds_cca))
    weight_pls = _avg(mk_weightError(ds_pls))
    score_cca = _avg(mk_scoreError(ds_cca))
    score_pls = _avg(mk_scoreError(ds_pls))
    # Loading errors are only evaluated where px > 2.
    loading_cca = _avg(mk_loadingError(ds_cca.where(ds_cca.px>2)))
    loading_pls = _avg(mk_loadingError(ds_pls.where(ds_pls.px>2)))

    # NOTE(review): the PLS association-error panels use a different y label
    # ('... error2') from the CCA panel. It is hidden by ylabel='' and is
    # possibly deliberate to keep the axis ranges of those panels unlinked —
    # confirm before unifying.
    assoc_label_pls = 'Relative Association strength error2'
    assoc_opts = dict(ylim=(None, 1), yticks=4)

    return (
        # Row 1: power, with the r_true color legend in the CCA panel and the
        # p_X dash legend in the reparameterized PLS panel.
        (
            _panel(ds_cca, power_cca, 'Power', 'CCA_power')
            * hv.Text(300, 0.05*(1.75)**2, r'$r_\mathrm{true}=$', fontsize=7, halign='right', valign='bottom')
            * hv.Text(1000, 0.05*(1.75)**2, r'$0.1$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[0])
            * hv.Text(200, 0.05*1.75, r'$0.3$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[1])
            * hv.Text(1000, 0.05*1.75, r'$0.5$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[2])
            * hv.Text(200, 0.05, r'$0.7$', fontsize=7, halign='right', valign='bottom').opts(color=r_clrs[3])
        ).relabel('CCA').opts(logx=True, logy=True, xlabel='')
        + _panel(ds_pls, power_pls, 'Power', 'PLS_power').relabel('PLS').opts(xlabel='', ylabel='')
        + (
            _panel(ds_pls, power_pls, 'Power', 'PLSreparam_power', reparameterize=True)
            * linestyle_overlay
        ).relabel('PLS (reparameterized)').opts(xlabel='', ylabel='')
        # Row 2: relative association strength error
        + _panel(ds_cca, assoc_cca, 'Relative association\nstrength error', 'CCA_relAssocStrengthError').opts(xlabel='', **assoc_opts)
        + _panel(ds_pls, assoc_pls, assoc_label_pls, 'PLS_relAssocStrengthError').opts(xlabel='', ylabel='', **assoc_opts)
        + _panel(ds_pls, assoc_pls, assoc_label_pls, 'PLSreparam_relAssocStrengthError', reparameterize=True).opts(xlabel='', ylabel='', **assoc_opts)
        # Row 3: weight error
        + _panel(ds_cca, weight_cca, 'Weight error', 'CCA_weightError').opts(xlabel='')
        + _panel(ds_pls, weight_pls, 'Weight error', 'PLS_weightError').opts(xlabel='', ylabel='')
        + _panel(ds_pls, weight_pls, 'Weight error', 'PLSreparam_weightError', reparameterize=True).opts(xlabel='', ylabel='')
        # Row 4: score error
        + _panel(ds_cca, score_cca, 'Score error', 'CCA_scoreError').opts(xlabel='')
        + _panel(ds_pls, score_pls, 'Score error', 'PLS_scoreError').opts(xlabel='', ylabel='')
        + _panel(ds_pls, score_pls, 'Score error', 'PLSreparam_scoreError', reparameterize=True).opts(xlabel='', ylabel='')
        # Row 5: loading error (bottom row keeps its x labels)
        + _panel(ds_cca, loading_cca, 'Loading error', 'CCA_loadingError')
        + _panel(ds_pls, loading_pls, 'Loading error', 'PLS_loadingError').opts(ylabel='')
        + _panel(ds_pls, loading_pls, 'Loading error', 'PLSreparam_loadingError', reparameterize=True).opts(ylabel='')
    ).cols(
        3
    ).opts(*fig_opts).opts(
        opts.Overlay(xlim=(3, 1000), xticks=3),
        opts.Layout(hspace=.45, sublabel_position=(-.55, .85), fig_inches=(7./4*3, None))
    )

p_tot not fixed, p_X = p_Y

First, we investigate sample-per-feature dependence for different total numbers of features, but assume that the number of features in datasets \(X\) and \(Y\) is the same.

[7]:
# Load the CCA and PLS sweep outcomes and select mode=0
# (presumably the first association mode — verify against gemmr.data).
ds_cca = load_outcomes('sweep_cca_cca_random_sum+-2+-2').sel(mode=0)
ds_pls = load_outcomes('sweep_pls_pls_random_sum+-2+-2').sel(mode=0)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[8]:
# Keep only feature numbers px < 128 (boolean mask along the px coordinate).
ds_cca = ds_cca.sel(px=ds_cca.px<128)
ds_pls = ds_pls.sel(px=ds_pls.px<128)
[9]:
# NOTE(review): presumably keeps 25 covariance matrices (Sigma_id) per (px, r)
# combination — the n_Sigmas tables printed below show at most 25; confirm in
# gemmr.util.subset_ds.
ds_cca = subset_ds(ds_cca, n_keep=25)
ds_pls = subset_ds(ds_pls, n_keep=25)

What’s in the outcome data files?

[10]:
# Summarize the CCA outcomes: sweep coordinates and number of covariance
# matrices per (px, r).
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[11]:
# Summarize the PLS outcomes: sweep coordinates and number of covariance
# matrices per (px, r).
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.00, -2.00)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 10]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[12]:
# Build the p_X = p_Y panel grid; the plotted curves are recorded in
# plotted_data under keys prefixed with 'pX=pY'.
fig_XeqY = mk_fig(ds_cca, ds_pls, plotted_data=plotted_data, panel_label_prefix='pX=pY')
fig_XeqY
[12]:

p_tot=64, p_X != p_Y

Second, we investigate sample-per-feature dependence assuming that the total number of features is 64, but that the numbers of features in datasets \(X\) and \(Y\) differ.

[13]:
# Load the p_tot = 64 sweeps and select mode=0.
# Note: this rebinds ds_cca / ds_pls, replacing the p_X = p_Y datasets loaded
# earlier in the notebook.
ds_cca = load_outcomes('sweep_cca_cca_random_sum+-2+-2_ptot64').sel(mode=0)
ds_pls = load_outcomes('sweep_pls_pls_random_sum+-2+-2_ptot64').sel(mode=0)
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[14]:
# NOTE(review): presumably keeps 40 covariance matrices (Sigma_id) per (px, r)
# combination — the n_Sigmas tables printed below show at most 40; confirm in
# gemmr.util.subset_ds.
ds_cca = subset_ds(ds_cca, n_keep=40)
ds_pls = subset_ds(ds_pls, n_keep=40)

What’s in the outcome data files?

[15]:
# Summarize the CCA outcomes for p_tot = 64: sweep coordinates and number of
# covariance matrices per (px, r).
print_ds_stats(ds_cca)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              != px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[40, 40, 40, 40],
       [40, 40, 40, 40],
       [40, 40, 40, 40],
       [40, 40, 40, 40],
       [40, 40, 40, 40]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32

power           calculated
[16]:
# Summarize the PLS outcomes for p_tot = 64: sweep coordinates and number of
# covariance matrices per (px, r).
print_ds_stats(ds_pls)
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32]
ax+ay range     (-2.00, -2.00)
py              != px

<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[40, 40, 40, 40],
       [40, 40, 40, 40],
       [40, 40, 40, 40],
       [ 0, 40, 40, 40],
       [ 0, 40, 40, 40]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32

power           calculated
[17]:
# Build the p_X + p_Y = 64 panel grid; the plotted curves are recorded in
# plotted_data under keys prefixed with 'pX+pY=64'.
fig_ptot64 = mk_fig(ds_cca, ds_pls, plotted_data=plotted_data, panel_label_prefix='pX+pY=64')
fig_ptot64
[17]:

Assemble figure

[18]:
class Bigtitle:
    """Plot hook that writes a centered title above a panel's axes."""

    def __init__(self, title):
        # Text drawn whenever the hook is invoked.
        self.title = title

    def __call__(self, plot, element):
        """Draw ``self.title`` on the axes stored in ``plot.handles['axis']``.

        ``element`` is accepted for the hook signature but is not used.
        """
        axes = plot.handles['axis']
        text_kwargs = dict(ha='center', transform=axes.transAxes, fontsize=10)
        # (0.5, 1.3) is given in axes coordinates: horizontally centered and
        # above the top edge of the axes.
        axes.text(0.5, 1.3, self.title, **text_kwargs)


# Each mk_fig layout holds N_ROWS x N_COLS panels in row-major order.
N_ROWS, N_COLS = 5, 3

# (source layout, big title over its middle column, blank y-label of its first column?)
figure_halves = (
    (fig_XeqY, '$p_x=p_y$', False),
    (fig_ptot64, '$p_x+p_y=64$', True),
)

# Interleave the two layouts row by row, so each combined row shows the three
# p_X = p_Y panels followed by the three p_X + p_Y = 64 panels.
fig = hv.Layout()
for row in range(N_ROWS):
    for source_fig, big_title, blank_first_ylabel in figure_halves:
        for col in range(N_COLS):
            panel = source_fig.Overlay[row*N_COLS + col]
            if blank_first_ylabel and col == 0:
                panel = panel.opts(ylabel='')
            if row == 0 and col == 1:
                panel = panel.opts(hooks=[Bigtitle(big_title)])
            fig += panel


default_fontsizes = dict(title=6, labels=6, ticks=6, minor_ticks=6, legend=6,)
layout_opts = opts.Layout(fig_inches=(7, None), sublabel_size=7, fontsize=6, sublabel_position=(-.35, .95))
overlay_opts = opts.Overlay(fontsize=default_fontsizes)
fig = fig.cols(len(figure_halves) * N_COLS).opts(*fig_opts).opts(layout_opts, overlay_opts)

hv.save(fig, 'fig/figS_n_per_ftr_parameterization_combined.pdf')
# NOTE(review): the source data is saved under the label 'fig3' while the
# figure file is named 'figS_...' — confirm the label is the intended one.
save_source_data(plotted_data, 'fig3')

fig
[18]: