CCA vs PLS errors
[12]:
import os

import numpy as np
import xarray as xr
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.metrics import *
from gemmr.sample_size.interpolation import calc_n_required_all_metrics
from gemmr.util import subset_ds
import colorcet
import holoviews as hv
hv.extension('matplotlib')
hv.renderer('matplotlib').param.set_param(dpi=120)
from holoviews import opts
from my_config import *
Setup
[3]:
# Location handed to load_outcomes; presumably None selects the library's
# default data directory -- TODO confirm.
data_home = None

# Load the CCA- and PLS-generated parameter sweeps and keep only mode 0.
ds_cca = load_outcomes(
    'sweep_cca_cca_random_sum+-3+0_wOtherModel',
    model='cca',
    add_prefix='cca_',
    data_home=data_home,
).sel(mode=0)
ds_pls = load_outcomes(
    'sweep_pls_pls_random_sum+-3+0_wOtherModel',
    model='pls',
    add_prefix='pls_',
    data_home=data_home,
).sel(mode=0)

# Stack both sweeps along 'Sigma_id'; the analysis cells operate on `ds`.
# NOTE(review): `ds` is built here, before ds_cca / ds_pls are filtered and
# subset further down, so those later steps are not reflected in `ds` --
# confirm that this is intended.
ds = xr.concat([ds_cca, ds_pls], dim='Sigma_id')
Loading data from subfolder 'gemmr_latest'
Loading data from subfolder 'gemmr_latest'
[4]:
# Keep only the entries along 'px' whose value is below 128 (boolean mask).
ds_cca = ds_cca.sel({'px': ds_cca['px'] < 128})
ds_pls = ds_pls.sel({'px': ds_pls['px'] < 128})
[5]:
# Reduce each sweep with subset_ds (n_keep=25), keyed on the respective
# model's between-set association variable.
ds_cca, ds_pls = (
    subset_ds(ds_cca, n_keep=25, keyvar='cca_between_assocs'),
    subset_ds(ds_pls, n_keep=25, keyvar='pls_between_assocs'),
)
[6]:
# Print a summary of the (filtered and subset) CCA sweep, reading the 'cca_'-prefixed variables.
print_ds_stats(ds_cca, prefix='cca_')
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7]
px [ 2 4 8 16 32 64]
ax+ay range (-2.97, -0.10)
py == px
<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[ 0, 25, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7
* px (px) int64 2 4 8 16 32 64
power calculated
[7]:
# Print a summary of the (filtered and subset) PLS sweep, reading the 'pls_'-prefixed variables.
print_ds_stats(ds_pls, prefix='pls_')
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024 2048 4096 8192]
r [0.1 0.3 0.5 0.7]
px [ 2 4 8 16 32 64]
ax+ay range (-2.97, -0.10)
py == px
<xarray.DataArray 'n_Sigmas' (px: 6, r: 4)>
array([[25, 25, 25, 25],
[25, 25, 25, 25],
[25, 25, 25, 25],
[ 0, 25, 25, 25],
[ 0, 25, 25, 25],
[ 0, 0, 0, 21]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7
* px (px) int64 2 4 8 16 32 64
power calculated
Analysis
[8]:
# Selections applied to every panel; slice(None) keeps all entries of a dimension.
r = slice(None)
px = slice(None)
n_per_ftr = slice(None)

# One hexbin panel per error metric: CCA error (x) against PLS error (y),
# both evaluated on the same combined dataset `ds`.
fig_ccaVsPls = hv.Layout()
for metric_lbl, metric in [
    ('combined', lambda *args, **kwargs: mk_combinedError(*args, assoc_metric='corr', abs_assoc_error=True, **kwargs)),
    ('power', mk_fnr),
    ('association strength', mk_absBetweenAssocRelError),
    ('correlation', mk_absBetweenCorrRelError),
    ('weight', mk_weightError),
    ('score', mk_scoreError),
    ('loading', mk_loadingError)
]:
    # Only the last panel carries the (shared) colorbar.
    colorbar = metric_lbl == 'loading'

    # With 4 columns, 'combined' and 'weight' start the two rows, and
    # 'weight', 'score', 'loading', 'correlation' have no panel beneath them.
    # Presumably None keeps the default axis label while '' blanks it -- confirm.
    xlabel = None if metric_lbl in ['weight', 'score', 'loading', 'correlation'] else ''
    ylabel = None if metric_lbl in ['combined', 'weight'] else ''

    e_cca = metric(ds, prefix='cca_').dropna('n_per_ftr', how='all')
    e_pls = metric(ds, prefix='pls_').dropna('n_per_ftr', how='all')
    # The two dropna calls are independent and may remove different
    # n_per_ftr entries. The arrays are paired purely by position after
    # .ravel(), so restrict both to their shared coordinates first.
    e_cca, e_pls = xr.align(e_cca, e_pls, join='inner')

    fig_ccaVsPls += (
        hv.HexTiles(
            # Errors >= 0.5 are masked (NaN) so they fall outside the plotted range.
            (e_cca.where(e_cca < 0.5).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel(),
             e_pls.where(e_pls < 0.5).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel())
        )
        # Identity line: points above it have larger PLS than CCA error.
        * hv.Curve(([0, 1], [0, 1])).opts(color='black')
    ).redim(
        x='Error CCA',
        y='Error PLS'
    ).opts(
        opts.HexTiles(
            cmap='magma', colorbar=colorbar, logz=True, clim=(1, 10000), gridsize=(10, 10), xlim=(0, 0.5), ylim=(0, 0.5),
            clabel='# analyzed datasets', xlabel=xlabel, ylabel=ylabel
        )
    ).relabel(
        metric_lbl
    )

fig_ccaVsPls = fig_ccaVsPls.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), sublabel_position=(-.45, .85)
)
fig_ccaVsPls
[8]:
[9]:
# Element-wise larger of the absolute true PLS X- and Y-weight entries for
# feature 0. Plotted later as 'PLS weight overlap w/ PC1 axis'; presumably
# feature 0 corresponds to the first principal-component axis -- confirm.
pc1_proj = np.maximum(
    np.abs(ds['pls_x_weights_true'].sel({'x_feature': 0})),
    np.abs(ds['pls_y_weights_true'].sel({'y_feature': 0})),
)
# Show the dimension order of the result.
pc1_proj.dims
[9]:
('px', 'r', 'Sigma_id')
[10]:
# Selections applied to every panel; slice(None) keeps all entries of a dimension.
px = slice(None)
r = slice(None)
n_per_ftr = slice(None)

# One hexbin panel per error metric: PC1 overlap of the true PLS weights (x)
# against the difference between PLS and CCA error (y).
fig_deltaE = hv.Layout()
for metric_lbl, metric in [
    ('combined', lambda *args, **kwargs: mk_combinedError(*args, assoc_metric='corr', abs_assoc_error=True, **kwargs)),
    ('power', mk_fnr),
    ('association strength', mk_absBetweenAssocRelError),
    ('correlation', mk_absBetweenCorrRelError),
    ('weight', mk_weightError),
    ('score', mk_scoreError),
    ('loading', mk_loadingError)
]:
    # Only the last panel carries the (shared) colorbar.
    colorbar = metric_lbl == 'loading'

    # With 4 columns, 'combined' and 'weight' start the two rows, and
    # 'weight', 'score', 'loading', 'correlation' have no panel beneath them.
    # Presumably None keeps the default axis label while '' blanks it -- confirm.
    xlabel = None if metric_lbl in ['weight', 'score', 'loading', 'correlation'] else ''
    ylabel = None if metric_lbl in ['combined', 'weight'] else ''

    e_cca = metric(ds, prefix='cca_')
    e_pls = metric(ds, prefix='pls_')
    de = e_pls - e_cca

    # Give pc1_proj the dimensions it lacks relative to `de`, so that after
    # transposing to de's dimension order both ravel to matching positions.
    if 'rep' in de.dims:
        pc1_proj_ = pc1_proj.expand_dims(rep=de.rep, n_per_ftr=de.n_per_ftr)
    else:
        pc1_proj_ = pc1_proj.expand_dims(n_per_ftr=de.n_per_ftr)

    fig_deltaE += hv.HexTiles(
        (pc1_proj_.transpose(*de.dims).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel(),
         # Keep differences with |de| < 0.5 where at least one of the two
         # errors is itself below 0.5; everything else (incl. NaN) is masked.
         de.where((np.abs(de) < .5) & ((e_cca < 0.5) | (e_pls < 0.5))).sel(r=r, px=px, n_per_ftr=n_per_ftr).values.ravel()),
    ).redim(
        x='PLS weight overlap w/ PC1 axis',
        y='Error PLS - error CCA',
    ).opts(
        opts.HexTiles(gridsize=(10, 10), colorbar=colorbar, logz=True, clim=(1, 10000), cmap='magma', clabel='# analyzed datasets', xlabel=xlabel, ylabel=ylabel)
    ).relabel(
        metric_lbl
    )

fig_deltaE = fig_deltaE.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), sublabel_position=(-.45, .85)
)
fig_deltaE
[10]:
Assemble figure
[11]:
# Combine both panel groups into one 4-column figure. fig_ccaVsPls has 7
# panels, so one placeholder is inserted to make fig_deltaE start on a new
# row. The empty Overlay is what triggers the "Overlay is empty, skipping
# subplot" warnings in the cell output.
fig = (
    fig_ccaVsPls
    + hv.Overlay()  # empty plot to fill space
    + fig_deltaE
)
fig = fig.cols(
    4
).opts(*fig_opts).opts(
    fig_inches=(7, None), vspace=.5, sublabel_position=(-.45, .9)
)

# Make sure the output directory exists so saving also works on a fresh checkout.
os.makedirs('fig', exist_ok=True)
hv.save(fig, 'fig/figS_cca_vs_pls.pdf')
fig
WARNING:param.LayoutPlot09928: :Overlay is empty, skipping subplot.
WARNING:param.LayoutPlot10666: :Overlay is empty, skipping subplot.
[11]: