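CCA vs PLS errors

Compares the errors of CCA and PLS, evaluated on the same datasets for several error metrics, and saves the resulting supplementary figure to `fig/figS_cca_vs_pls.pdf`.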

[12]:
import os

import numpy as np
import xarray as xr

from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import calc_n_required_all_metrics
from gemmr.util import subset_ds

import colorcet
import holoviews as hv
# Render HoloViews figures with the matplotlib backend at 120 dpi.
# NOTE(review): newer `param` releases deprecate `.param.set_param` in favour of
# `.param.update` — check against the pinned param/holoviews versions.
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)
from holoviews import opts

from my_config import *

Setup

[3]:
data_home = None

# Outcomes of the CCA sweep and of the PLS sweep, restricted to the first mode.
ds_cca = load_outcomes(
    'sweep_cca_cca_random_sum+-3+0_wOtherModel',
    model='cca',
    add_prefix='cca_',
    data_home=data_home,
).sel(mode=0)
ds_pls = load_outcomes(
    'sweep_pls_pls_random_sum+-3+0_wOtherModel',
    model='pls',
    add_prefix='pls_',
    data_home=data_home,
).sel(mode=0)

# Both sweeps stacked along Sigma_id; this is the dataset the analysis cells use.
# NOTE(review): `ds` is built here, *before* the px < 128 filter and `subset_ds`
# applied to ds_cca / ds_pls further down, so those steps (and the stats printed
# for them) do not affect `ds` — confirm this is intended.
ds = xr.concat([ds_cca, ds_pls], dim='Sigma_id')
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[4]:
# Keep only px values below 128 in both models' outcomes.
# NOTE(review): `ds` was concatenated earlier, so this filter does not reach it.
ds_cca, ds_pls = [d.sel(px=d.px < 128) for d in (ds_cca, ds_pls)]
[5]:
# Thin each model's outcomes with gemmr's subset_ds (n_keep=25, keyed on the
# model's `*_between_assocs` variable); the stats printed below show at most
# 25 Sigmas per (px, r).
ds_cca, ds_pls = (
    subset_ds(ds_cca, n_keep=25, keyvar='cca_between_assocs'),
    subset_ds(ds_pls, n_keep=25, keyvar='pls_between_assocs'),
)
[6]:
# Summary of the px-filtered, subsetted CCA outcomes.
# NOTE(review): these stats describe `ds_cca`, not the concatenated `ds` used in the analysis.
print_ds_stats(ds_cca, prefix='cca_')
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.97, -0.10)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated
[7]:
# Summary of the px-filtered, subsetted PLS outcomes.
# NOTE(review): these stats describe `ds_pls`, not the concatenated `ds` used in the analysis.
print_ds_stats(ds_pls, prefix='pls_')
n_rep            100
n_per_ftr        [   3    4    8   16   32   64  128  256  512 1024 2048 4096 8192]
r                [0.1 0.3 0.5 0.7]
px               [ 2  4  8 16 32 64]
ax+ay range     (-2.97, -0.10)
py              == px

<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
       [25, 25, 25, 25],
       [25, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0, 25, 25, 25],
       [ 0,  0,  0, 21]])
Coordinates:
  * r        (r) float64 0.1 0.3 0.5 0.7
  * px       (px) int64 2 4 8 16 32 64

power           calculated

Analysis

[8]:
# Coordinate selections applied before plotting; slice(None) keeps everything.
r = slice(None)
px = slice(None)
n_per_ftr = slice(None)

# Errors >= err_max are masked out; err_max also sets both axis limits.
err_max = 0.5

fig_ccaVsPls = hv.Layout()

for metric_lbl, metric in [
        ('combined', lambda *args, **kwargs: mk_combinedError(*args, assoc_metric='corr', abs_assoc_error=True, **kwargs)),
        ('power', mk_fnr),
        ('association strength', mk_absBetweenAssocRelError),
        ('correlation', mk_absBetweenCorrRelError),
        ('weight', mk_weightError),
        ('score', mk_scoreError),
        ('loading', mk_loadingError)
    ]:

    # The 7 panels are arranged in 4 columns (see .cols(4) below): only the last
    # panel ('loading') gets the colorbar, x-labels are kept only on panels with
    # nothing below them, y-labels only in the first column.
    # xlabel/ylabel: None keeps the dimension label, '' blanks it.
    colorbar = metric_lbl == 'loading'
    xlabel = None if metric_lbl in ['weight', 'score', 'loading', 'correlation'] else ''
    ylabel = None if metric_lbl in ['combined', 'weight'] else ''

    e_cca = metric(ds, prefix='cca_').dropna('n_per_ftr', how='all')
    e_pls = metric(ds, prefix='pls_').dropna('n_per_ftr', how='all')
    # The two dropna() calls are independent, so the surviving n_per_ftr values
    # can differ between models. The scatter pairs points element-wise via
    # .values.ravel(), so align coordinates and dimension order first.
    e_cca, e_pls = xr.align(e_cca, e_pls, join='inner')
    e_pls = e_pls.transpose(*e_cca.dims)

    err_cca = e_cca.where(e_cca < err_max).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel()
    err_pls = e_pls.where(e_pls < err_max).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel()

    fig_ccaVsPls += (
        hv.HexTiles((err_cca, err_pls))
        * hv.Curve(([0, 1], [0, 1])).opts(color='black')  # identity line: equal errors
    ).redim(
        x='Error CCA',
        y='Error PLS'
    ).opts(
        opts.HexTiles(
            cmap='magma', colorbar=colorbar, logz=True, clim=(1, 10000), gridsize=(10, 10),
            xlim=(0, err_max), ylim=(0, err_max),
            clabel='# analyzed datasets', xlabel=xlabel, ylabel=ylabel
        )
    ).relabel(
        metric_lbl
    )

fig_ccaVsPls = fig_ccaVsPls.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), sublabel_position=(-.45, .85)
)

fig_ccaVsPls
[8]:
[9]:
# For each true PLS solution: the larger of |x-weight on feature 0| and
# |y-weight on feature 0|. The plots below label this quantity
# "PLS weight overlap w/ PC1 axis".
pc1_proj = np.maximum(
    np.abs(ds['pls_x_weights_true'].sel(x_feature=0)),
    np.abs(ds['pls_y_weights_true'].sel(y_feature=0)),
)
pc1_proj.dims
[9]:
('px', 'r', 'Sigma_id')
[10]:
# Coordinate selections applied before plotting; slice(None) keeps everything.
px = slice(None)
r = slice(None)
n_per_ftr = slice(None)

# A point is shown only if |error PLS - error CCA| < err_max and at least one
# of the two errors is below err_max.
err_max = .5

fig_deltaE = hv.Layout()
for metric_lbl, metric in [
        ('combined', lambda *args, **kwargs: mk_combinedError(*args, assoc_metric='corr', abs_assoc_error=True, **kwargs)),
        ('power', mk_fnr),
        ('association strength', mk_absBetweenAssocRelError),
        ('correlation', mk_absBetweenCorrRelError),
        ('weight', mk_weightError),
        ('score', mk_scoreError),
        ('loading', mk_loadingError)
    ]:

    # The 7 panels are arranged in 4 columns (see .cols(4) below): only the last
    # panel ('loading') gets the colorbar, x-labels are kept only on panels with
    # nothing below them, y-labels only in the first column.
    # xlabel/ylabel: None keeps the dimension label, '' blanks it.
    colorbar = metric_lbl == 'loading'
    xlabel = None if metric_lbl in ['weight', 'score', 'loading', 'correlation'] else ''
    ylabel = None if metric_lbl in ['combined', 'weight'] else ''

    e_cca = metric(ds, prefix='cca_')
    e_pls = metric(ds, prefix='pls_')
    de = e_pls - e_cca

    # Broadcast the per-Sigma PC1 overlap over every dimension of `de` it lacks
    # (e.g. n_per_ftr, and rep when present) and order its dims like `de`, so
    # both .values.ravel() calls below yield element-wise matching arrays.
    missing_dims = {dim: de[dim] for dim in de.dims if dim not in pc1_proj.dims}
    pc1_proj_ = pc1_proj.expand_dims(missing_dims).transpose(*de.dims)

    # NaN entries of `de` already fail the `< err_max` test, so no separate
    # NaN check is needed.
    keep = (np.abs(de) < err_max) & ((e_cca < err_max) | (e_pls < err_max))

    fig_deltaE += hv.HexTiles(
        (pc1_proj_.sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel(),
         de.where(keep).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel()),
    ).redim(
        x='PLS weight overlap w/ PC1 axis',
        y='Error PLS - error CCA',
    ).opts(
        opts.HexTiles(gridsize=(10, 10), colorbar=colorbar, logz=True, clim=(1, 10000), cmap='magma', clabel='# analyzed datasets', xlabel=xlabel, ylabel=ylabel)
    ).relabel(
        metric_lbl
    )

fig_deltaE = fig_deltaE.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), sublabel_position=(-.45, .85)
)

fig_deltaE
[10]:

Assemble figure

[11]:
# Combine both panel groups in one 4-column grid. fig_ccaVsPls has 7 panels, so
# one placeholder fills the 8th slot and the delta-error panels start on a new
# row. hv.Empty is HoloViews' layout placeholder; an empty hv.Overlay() here
# logs "Overlay is empty, skipping subplot" warnings instead.
fig = (
    fig_ccaVsPls
    + hv.Empty()
    + fig_deltaE
)

fig = fig.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), vspace=.5, sublabel_position=(-.45, .9)
)

out_path = 'fig/figS_cca_vs_pls.pdf'
# Make sure the output directory exists so saving works from a fresh checkout.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
hv.save(fig, out_path)

fig
WARNING:param.LayoutPlot09928: :Overlay is empty, skipping subplot.
WARNING:param.LayoutPlot10666: :Overlay is empty, skipping subplot.
[11]: