Sparse CCA¶
How many samples are required to bound power and error metrics with sparse CCA?
There are several slightly different sparse CCA methods. Here, we use PMA from the R package of the same name.
Setup¶
[1]:
# Plotting / data stack: xarray holds the simulation-outcome datasets,
# holoviews (matplotlib backend) renders the figures.
import xarray as xr
import holoviews as hv
from holoviews import opts
hv.extension('matplotlib')
# NOTE(review): .set_param is deprecated in newer param/holoviews releases in
# favor of .param.update(dpi=120) -- confirm the pinned version before changing.
hv.renderer('matplotlib').set_param(dpi=120)
# gemmr supplies the precomputed outcome data and the sample-size calculators.
from gemmr.data import load_outcomes, print_ds_stats
from gemmr.sample_size import calc_n_required_all_metrics
from gemmr.util import nPerFtr2n
from gemmr.plot import heatmap_n_req
# NOTE(review): star import -- fig_opts (used below) comes from here; an
# explicit import list would make the dependencies visible. Verify what else
# my_config provides before narrowing.
from my_config import *
import warnings
# Suppress a benign holoviews/matplotlib warning triggered by the
# log-x / linear-y panels used below.
warnings.filterwarnings(
'ignore', 'aspect is not supported for Axes with xscale=log, yscale=linear', category=UserWarning
) # holoviews emits this for log-linear plots
[2]:
# Load precomputed synthetic-data outcomes: plain CCA, and the sparse-CCA
# variant (estr='sparsecca' selects the sparse estimator's results).
ds_cca = load_outcomes('cca')
ds_sparsecca = load_outcomes('cca', estr='sparsecca')
[3]:
# Overview of the plain-CCA dataset: parameter grid (n_per_ftr, r, px) and
# how many covariance matrices (n_Sigmas) / repetitions back each cell.
print_ds_stats(ds_cca)
n_modes 1
n_rep 100
n_per_ftr [ 3 4 8 16 32 64 128 256 512 1024]
r [0.1 0.3 0.5 0.7 0.9]
px [ 2 4 8 16 32 64 128]
ax+ay range (0.00, 0.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 7, r: 5)>
array([[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[25, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25],
[ 0, 25, 25, 25, 25]])
Coordinates:
* r (r) float64 0.1 0.3 0.5 0.7 0.9
* px (px) int64 2 4 8 16 32 64 128
power calculated
[4]:
# Same overview for the sparse-CCA dataset; note the smaller grid
# (4 r-values, px up to 64, 6 covariance matrices per cell).
print_ds_stats(ds_sparsecca)
n_modes 1
n_rep 100
n_per_ftr [ 4 8 16 32 64 128 256 512 1024]
r [0.3 0.5 0.7 0.9]
px [ 4 8 16 32 64]
ax+ay range (0.00, 0.00)
py == px
<xarray.DataArray 'n_Sigmas' (px: 5, r: 4)>
array([[6, 6, 6, 6],
[6, 6, 6, 6],
[6, 6, 6, 6],
[6, 6, 6, 6],
[6, 6, 6, 6]])
Coordinates:
* px (px) int64 4 8 16 32 64
* r (r) float64 0.3 0.5 0.7 0.9
power calculated
Determine required sample size¶
[5]:
# Required samples *per feature* for each error/power metric, searched along
# the n_per_ftr dimension. mode=0 selects the first CCA mode (the only one
# here, per the stats above). The 'combined' entry used below presumably
# aggregates across metrics -- confirm against gemmr's documentation.
cca_n_req_per_ftr = calc_n_required_all_metrics(ds_cca.sel(mode=0), search_dim='n_per_ftr')
sparsecca_n_req_per_ftr = calc_n_required_all_metrics(ds_sparsecca.sel(mode=0), search_dim='n_per_ftr')
[6]:
# Convert samples-per-feature into absolute sample sizes by scaling with the
# number of features py (py == px in both datasets, see the stats above).
cca_n_req = nPerFtr2n(cca_n_req_per_ftr['combined'], ds_cca.py)
sparsecca_n_req = nPerFtr2n(sparsecca_n_req_per_ftr['combined'], ds_sparsecca.py)
[7]:
# Heatmap panel: required sample size for sparse CCA
sparsecca_heatmap = heatmap_n_req(
    sparsecca_n_req,
    clabel='Required sample size',
)
panel_sparse_required = sparsecca_heatmap.relabel(
    'Required sample size:\nsparse CCA'
).opts(
    opts.QuadMesh(logx=True, logz=True, cmap='Inferno',
                  colorbar=True, sublabel_position=(-.45, .95))
)
[8]:
# Cell: relative difference in required sample size, sparse CCA vs CCA
def hook_arrow(plot, element):
    """Matplotlib hook: draw 'CCA better' / 'sparse CCA better' arrows
    next to the panel, in axes coordinates (clipping disabled so they can
    sit outside the plot area)."""
    ax = plot.handles['axis']
    # (arrow start y, arrow dy, label y, label text); positive dy points up
    annotations = [
        (.6, .4, .8, 'CCA better'),
        (.4, -.4, .2, 'sparse CCA better'),
    ]
    for y_start, dy, y_label, label in annotations:
        ax.arrow(1.4, y_start, 0, dy, width=.001, length_includes_head=True,
                 head_width=.03, color='black', clip_on=False,
                 transform=ax.transAxes)
        ax.text(1.525, y_label, label, rotation=90, ha='center', va='center',
                transform=ax.transAxes, fontdict={'fontsize': 8})
# Heatmap panel: relative difference in required sample size between sparse
# CCA and CCA, each first averaged over covariance-matrix draws (Sigma_id)
cca_mean_n_req = cca_n_req.mean('Sigma_id')
sparsecca_mean_n_req = sparsecca_n_req.mean('Sigma_id')
panel_fractional_dif_fig_required_sample_size_heatmap = heatmap_n_req(
    (sparsecca_mean_n_req - cca_mean_n_req) / cca_mean_n_req,
    clabel=' ',
).relabel(
    'Relative difference in\nrequired sample size:\n(sparse CCA - CCA) / CCA'
).opts(
    opts.QuadMesh(logx=True, logz=False, sublabel_position=(-.45, .95),
                  colorbar=True, cmap='RdBu_r', symmetric=True,
                  hooks=[hook_arrow])
)
[10]:
# Assemble the two panels into the supplementary figure, save to PDF,
# and display it (bare final expression triggers the rich repr)
layout_sparse_vs_cca = (
    panel_sparse_required.opts(logz=True)
    + panel_fractional_dif_fig_required_sample_size_heatmap
)
combined_fig_dif = layout_sparse_vs_cca.opts(*fig_opts).opts(
    opts.Layout(fig_inches=(4, None), hspace=1),
    opts.QuadMesh(sublabel_position=(-.45, .95)),
)
hv.save(combined_fig_dif, 'fig/figS_sparseCCA.pdf')
combined_fig_dif
[10]:
How many samples per feature are required for \(r_\mathrm{true}=0.3\)?
[26]:
# Required samples per feature at true correlation r = 0.3, averaged over
# covariance draws, for CCA vs sparse CCA (overlay of the two curves)
curve_cca = hv.Curve(
    cca_n_req_per_ftr['combined'].sel(r=0.3).mean('Sigma_id'), label='CCA')
curve_sparsecca = hv.Curve(
    sparsecca_n_req_per_ftr['combined'].sel(r=0.3).mean('Sigma_id'),
    label='sparse CCA')
(curve_cca * curve_sparsecca).opts(*fig_opts).opts(fig_inches=(2, None), logx=True)
[26]: