-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
1,566 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.ipynb filter=strip-notebook-output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,337 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Sparse and orthogonal low-rank Collective Matrix Factorization (solrCMF)\n", | ||
"\n", | ||
"This is a package describing the data integration methodology from [\"Sparse and orthogonal low-rank Collective Matrix Factorization (solrCMF): Efficient data integration in flexible layouts\" (Held et al., 2024, arXiv:2405.10067)](https://arxiv.org/abs/2405.10067).\n", | ||
"\n", | ||
"To install the package run\n", | ||
"```sh\n", | ||
"pip install git+https://github.com/cyianor/solrcmf.git\n", | ||
"```\n", | ||
"\n", | ||
"A simple usage example is shown below:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import solrcmf\n", | ||
"import numpy as np\n", | ||
"from numpy.random import default_rng\n", | ||
"\n", | ||
"# Control randomness\n", | ||
"rng = default_rng(42)\n", | ||
"\n", | ||
"# Simulate some data\n", | ||
"# - `viewdims`: Dimensions of each view\n", | ||
"# - `factor_scales`: The strength/singular value of each factor.\n", | ||
"# The diagonal of the D matrices in the paper.\n", | ||
"# Tuples are used to name data matrices. The first two\n", | ||
"# entries describe the relationship between views observed\n", | ||
"# in the data matrix. The third and following entries\n", | ||
"# are used to make the index unique which is relevant\n", | ||
"# in case of repeated layers of an observed relationship.\n", | ||
"# - `snr`: Signal-to-noise ratio of the noise added to each true signal\n", | ||
"# (can be different for each data matrix)\n", | ||
"# - `factor_sparsity`: Controls how sparse factors are generated in each\n", | ||
"# view V_i.\n", | ||
"# Example: 0.25 means 25% non-zero entries for each factor\n", | ||
"# in a view.\n", | ||
"#\n", | ||
"# The function below generates sparse orthogonal matrices V_i and uses the\n", | ||
"# supplied D_ij to form signal matrices V_i D_ij V_j^T. Noise with\n", | ||
"# residual variance controlled by the signal-to-noise ratio is added.\n", | ||
"xs_sim = solrcmf.simulate(\n", | ||
" viewdims={0: 100, 1: 50, 2: 50},\n", | ||
" factor_scales={\n", | ||
" (0, 1): [7.0, 4.5, 3.9, 0.0, 0.0],\n", | ||
" (0, 2): [8.3, 0.0, 0.0, 5.5, 0.0],\n", | ||
" (1, 2, 0): [6.3, 0.0, 4.7, 0.0, 5.1],\n", | ||
" (1, 2, 1): [0.0, 8.6, 4.9, 0.0, 0.0],\n", | ||
" },\n", | ||
" factor_sparsity={0: 0.25, 1: 0.25, 2: 0.25},\n", | ||
" snr=0.5,\n", | ||
" rng=rng,\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"# `xs_sim` is a dictionary containing\n", | ||
"# - \"xs_truth\", the true signal matrices\n", | ||
"# - \"xs\", the noisy data\n", | ||
"# - \"vs\", the simulated orthogonal factors" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Estimation via multi-block ADMM is encapsulated in the class `SolrCMF` which has a convenient scikit-learn interface." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# It is recommended to center input matrices along rows and columns as\n", | ||
"# well as to scale their Frobenius norm to 1.\n", | ||
"xs_centered = {k: solrcmf.bicenter(x)[0] for k, x in xs_sim[\"xs\"].items()}\n", | ||
"xs_scaled = {k: x / np.sqrt((x**2).sum()) for k, x in xs_centered.items()}\n", | ||
"\n", | ||
"# To determine good starting values, different strategies can be employed.\n", | ||
"# In the paper, the algorithm was run repeatedly on random starting values\n", | ||
"# without penalization and the best result is used as a starting point\n", | ||
"# for hyperparameter selection.\n", | ||
"# The data needs to be provided and `max_rank` sets a maximum rank for the\n", | ||
"# low-rank matrix factorization.\n", | ||
"# - `n_inits` controls how often new random starting values are selected.\n", | ||
"# - `rng` controls the random number generation\n", | ||
"# - `n_jobs` allows to parallize the search for initial values and is used\n", | ||
"# like in joblib https://joblib.readthedocs.io/en/stable/generated/joblib.Parallel.html\n", | ||
"est_init = solrcmf.best_random_init(\n", | ||
" xs_scaled, max_rank=10, n_inits=50, n_jobs=4, rng=rng\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a SolrCMF estimator\n", | ||
"# - `max_rank` is the maximum rank of the low-rank matrix factorization\n", | ||
"# - `structure_penalty` controls the integration penalty on the\n", | ||
"# diagonal entries of D_ij\n", | ||
"# - `factor_penalty` controls the factor sparsity penalty on the\n", | ||
"# entries of V_i (indirectly through U_i in the ADMM algorithm)\n", | ||
"# - `mu` controls how similar factors in U_i and V_i should be. A larger value\n", | ||
"# forces them together more closely\n", | ||
"# - `init` can be set to \"random\" which constructs a random starting state\n", | ||
"# or \"custom\". In the latter case, a starting state for\n", | ||
"# * `vs`: the V_i matrices\n", | ||
"# * `ds`: the D_ij matrices\n", | ||
"# * `us`: the U_i matrices\n", | ||
"# needs to be supplied when calling the `fit` method. See the example below.\n", | ||
"# - `factor_pruning` whether or not factors without any contribution should be\n", | ||
"# removed during estimation.\n", | ||
"# - `max_iter`: Maximum number of iterations\n", | ||
"est = solrcmf.SolrCMF(\n", | ||
" max_rank=10,\n", | ||
" structure_penalty=0.05,\n", | ||
" factor_penalty=0.08,\n", | ||
" mu=10,\n", | ||
" init=\"custom\",\n", | ||
" factor_pruning=False,\n", | ||
" max_iter=100000,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The estimation is then performed by fitting the model to data. Use the\n", | ||
"final values of the initial runs as starting values. Penalty parameters are not chosen optimally here." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"est.fit(xs_scaled, vs=est_init.vs_, ds=est_init.ds_, us=est_init.vs_)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Estimates for $D_{ij}$ are then in `est.ds_` and estimates for $V_i$ in `est.vs_`.\n", | ||
"\n", | ||
"Scale back to original scale." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for k, d in est.ds_.items():\n", | ||
" rescaled_d = d * np.sqrt((xs_centered[k] ** 2).sum())\n", | ||
" print(\n", | ||
" f\"{str(k):10s}: \"\n", | ||
" f\"{np.array2string(rescaled_d, precision=2, floatmode='fixed')}\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Shrinkage can be clearly seen in the singular value estimates compared to the groundtruth.\n", | ||
"\n", | ||
"Setting the right hyperparameters is non-trivial and\n", | ||
"more rigorous method is necessary. The class `SolrCMFCV` is provided for this\n", | ||
"purpose to perform cross-validation automatically.\n", | ||
"\n", | ||
"Cross-validation performs a two-step procedure:\n", | ||
"\n", | ||
"1. Possible model structures are determined by estimating the model for all supplied pairs of hyperparameters. Zero patterns in singular values and factors are recorded.\n", | ||
"2. Cross-validation is then performed by fixing each zero pattern obtained in Step 1 and estimating model parameters on all $K$ combinations of training folds. Test errors are computed on the respective left-out test fold.\n", | ||
"\n", | ||
"The final solution is found by determining the pair of hyperparameters that leads to the minimal CV error and to pick those parameters that are within one standard error of the minimal CV error with most sparsity in the singular values." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Lists of structure and factor penalties are supplied containing the\n", | ||
"# parameter combinations to be tested. Lists need to be of the same length\n", | ||
"# or one needs to be a scalar.\n", | ||
"# - `cv` number of folds\n", | ||
"est_cv = solrcmf.SolrCMFCV(\n", | ||
" max_rank=10,\n", | ||
" structure_penalty=np.exp(rng.uniform(np.log(5e-2), np.log(1.0), 100)),\n", | ||
" factor_penalty=np.exp(rng.uniform(np.log(5e-2), np.log(1.0), 100)),\n", | ||
" mu=10,\n", | ||
" cv=10,\n", | ||
" init=\"custom\",\n", | ||
" max_iter=100000,\n", | ||
" n_jobs=4,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Perform hyperparameter selection. This step can be time-intensive." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Initial values are supplied as lists. If length 1 then they are reused.\n", | ||
"# If same length as hyperparameters then different initial values can be used\n", | ||
"# for each pair of hyperparameters.\n", | ||
"est_cv.fit(xs_scaled, vs=[est_init.vs_], ds=[est_init.ds_], us=[est_init.vs_])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"CV results can be found in the attribute `est_cv.cv_results_` and can be easily converted to a Pandas `DataFrame`. The best result corresponds to the row with index `est_cv.best_index_`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"\n", | ||
"cv_res = pd.DataFrame(est_cv.cv_results_)\n", | ||
"cv_res.loc[est_cv.best_index_, :]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for k, d in est_cv.best_estimator_.ds_.items():\n", | ||
" rescaled_d = d * np.sqrt((xs_centered[k] ** 2).sum())\n", | ||
" print(\n", | ||
" f\"{str(k):10s}: \"\n", | ||
" f\"{np.array2string(rescaled_d, precision=2, floatmode='fixed')}\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Due to the small size of the data sources and signal-to-noise ratio of 0.5, it is not possible to recover singular values perfectly. However, thanks to unpenalized re-estimation, the strong shrinkage seen in the manual solution above is not present here.\n", | ||
"\n", | ||
"The factor estimates are in `est_cv.best_estimator_.vs_`, however, sparse factors can be found in `est_cv.best_estimator_.us_`. In this particular run, factor 0 of view 0 in the groundtruth corresponds to factor 0 in view 0 of the estimate. Note that in general factor order is arbitrary." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"np.sum(xs_sim[\"vs\"][0][:, 0] * est_cv.best_estimator_.us_[0][:, 0])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The correctness of the estimated sparsity pattern can be analysed by looking at true positive and false positive rate." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def true_positive_rate(estimate, truth):\n", | ||
" return sum(np.logical_and(estimate != 0.0, truth != 0.0)) / sum(\n", | ||
" truth != 0.0\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"def false_positive_rate(estimate, truth):\n", | ||
" return sum(np.logical_and(estimate != 0.0, truth == 0.0)) / sum(\n", | ||
" truth == 0.0\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"(\n", | ||
" true_positive_rate(\n", | ||
" xs_sim[\"vs\"][0][:, 0], est_cv.best_estimator_.us_[0][:, 0]\n", | ||
" ),\n", | ||
" false_positive_rate(\n", | ||
" xs_sim[\"vs\"][0][:, 0], est_cv.best_estimator_.us_[0][:, 0]\n", | ||
" ),\n", | ||
")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.