From 2ce772814c6e19912ca3ed5ca1b1bb764a892ecc Mon Sep 17 00:00:00 2001
From: adamgayoso
Date: Wed, 15 Feb 2023 22:16:52 -0800
Subject: [PATCH 1/3] initial

---
 docs/references.bib                 |  10 ++
 src/scib_metrics/__init__.py        |   2 +
 src/scib_metrics/_cms.py            |  77 ++++++++++
 src/scib_metrics/utils/__init__.py  |   3 +
 src/scib_metrics/utils/_anderson.py | 218 ++++++++++++++++++++++++++++
 tests/test_metrics.py               |   5 +
 tests/utils/test_anderson.py        |  16 ++
 7 files changed, 331 insertions(+)
 create mode 100644 src/scib_metrics/_cms.py
 create mode 100644 src/scib_metrics/utils/_anderson.py
 create mode 100644 tests/utils/test_anderson.py

diff --git a/docs/references.bib b/docs/references.bib
index 310e658..60bac48 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -36,3 +36,13 @@ @article{buttner2018
   pages = {43--49},
   publisher = {Springer Science and Business Media {LLC}}
 }
+
+@article{lutge2021cellmixs,
+  title = {CellMixS: quantifying and visualizing batch effects in single-cell RNA-seq data},
+  author = {L{\"u}tge, Almut and Zyprych-Walczak, Joanna and Kunzmann, Urszula Brykczynska and Crowell, Helena L and Calini, Daniela and Malhotra, Dheeraj and Soneson, Charlotte and Robinson, Mark D},
+  journal = {Life Science Alliance},
+  volume = {4},
+  number = {6},
+  year = {2021},
+  publisher = {Life Science Alliance}
+}
diff --git a/src/scib_metrics/__init__.py b/src/scib_metrics/__init__.py
index 58bcace..29391f5 100644
--- a/src/scib_metrics/__init__.py
+++ b/src/scib_metrics/__init__.py
@@ -2,6 +2,7 @@ from importlib.metadata import version
 
 from . import nearest_neighbors, utils
+from ._cms import cell_mixing_score
 from ._graph_connectivity import graph_connectivity
 from ._isolated_labels import isolated_labels
 from ._kbet import kbet, kbet_per_label
@@ -26,6 +27,7 @@
     "kbet",
     "kbet_per_label",
     "graph_connectivity",
+    "cell_mixing_score",
 ]
 
 __version__ = version("scib-metrics")
diff --git a/src/scib_metrics/_cms.py b/src/scib_metrics/_cms.py
new file mode 100644
index 0000000..92d18ed
--- /dev/null
+++ b/src/scib_metrics/_cms.py
@@ -0,0 +1,77 @@
+import warnings
+from functools import partial
+
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+from scipy.sparse import csr_matrix
+
+from scib_metrics.utils import anderson_ksamp, convert_knn_graph_to_idx
+
+
+def _cms_one_cell(
+    knn_dists: jnp.ndarray, knn_cats: jnp.ndarray, n_categories: int, cell_min: int = 4, unbalanced: bool = False
+):
+    # filter categories with too few cells (cell_min)
+    cat_counts = jnp.bincount(knn_cats, length=n_categories)
+    cat_values = jnp.arange(n_categories)
+    cats_to_use = jnp.where(cat_counts >= cell_min)[0]
+    cat_values = cat_values[cats_to_use]
+    mask = jnp.isin(knn_cats, cat_values)
+    knn_cats = knn_cats[mask]
+    knn_dists = knn_dists[mask]
+
+    # do not perform AD test if only one group with enough cells is in knn.
+    if len(cats_to_use) <= 1:
+        p = jnp.nan if unbalanced else 0.0
+    else:
+        # filter cells with the same representation
+        if jnp.any(knn_dists == 0):
+            warnings.warn("Distances equal to 0 - cells with identical representations detected. NaN assigned!")
+            p = jnp.nan
+        else:
+            # perform AD test with the remaining cells
+            res = anderson_ksamp([knn_dists[knn_cats == cat] for cat in cat_values])
+            p = res.significance_level
+
+    return p
+
+
+def cell_mixing_score(X: csr_matrix, batches: np.ndarray, cell_min: int = 10, unbalanced: bool = False) -> np.ndarray:
+    """Compute the cell-specific mixing score (cms) :cite:p:`lutge2021cellmixs`.
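+
+    The score for a cell is the significance level of a k-sample
+    Anderson-Darling test comparing the kNN distance distributions of the
+    batches present in that cell's neighborhood; low values indicate
+    batch-specific neighborhoods, i.e. poor mixing.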
+
+    Parameters
+    ----------
+    X
+        Array of shape (n_cells, n_cells) with non-zero values
+        representing the distances from each cell to exactly its k nearest neighbors.
+    batches
+        Array of shape (n_cells,) representing the batch value
+        for each cell.
+    cell_min
+        Minimum number of cells from each group to be included in the Anderson-Darling test.
+    unbalanced
+        If True, neighborhoods with only one batch present will be set to NaN. This way they are not included in
+        any summaries or smoothing.
+
+    Returns
+    -------
+    cms
+        Array of shape (n_cells,) with the cms score for each cell.
+    """
+    categorical_type_batches = pd.Categorical(batches)
+    batches = np.asarray(categorical_type_batches.codes)
+    n_categories = len(categorical_type_batches.categories)
+    knn_dists, knn_idx = convert_knn_graph_to_idx(X)
+    knn_cats = jnp.asarray(batches[knn_idx])
+    knn_dists = jnp.asarray(knn_dists)
+
+    cms_fn = partial(_cms_one_cell, n_categories=n_categories, cell_min=cell_min, unbalanced=unbalanced)
+
+    ps = []
+    for dists, cats in zip(knn_dists, knn_cats):
+        ps.append(cms_fn(dists, cats))
+
+    # TODO: add smoothing
+
+    return np.array(ps)
diff --git a/src/scib_metrics/utils/__init__.py b/src/scib_metrics/utils/__init__.py
index f555528..6130c9c 100644
--- a/src/scib_metrics/utils/__init__.py
+++ b/src/scib_metrics/utils/__init__.py
@@ -1,3 +1,4 @@
+from ._anderson import Anderson_ksampResult, anderson_ksamp
 from ._diffusion_nn import diffusion_nn
 from ._dist import cdist, pdist_squareform
 from ._kmeans import KMeansJax
@@ -20,4 +21,6 @@
     "convert_knn_graph_to_idx",
     "check_square",
     "diffusion_nn",
+    "Anderson_ksampResult",
+    "anderson_ksamp",
 ]
diff --git a/src/scib_metrics/utils/_anderson.py b/src/scib_metrics/utils/_anderson.py
new file mode 100644
index 0000000..0986664
--- /dev/null
+++ b/src/scib_metrics/utils/_anderson.py
@@ -0,0 +1,218 @@
+import warnings
+from dataclasses import dataclass
+from typing import Sequence
+
+import jax.numpy as jnp
+from jax.tree_util import tree_map
+
+from .._types import NdArray
+
+
+def _anderson_ksamp_midrank(
+    samples: Sequence[jnp.ndarray], Z: jnp.ndarray, Zstar: jnp.ndarray, k: int, n: jnp.ndarray, N: int
+):
+    """Compute A2akN equation 7 of Scholz and Stephens.
+
+    Parameters
+    ----------
+    samples
+        Array of sample arrays.
+    Z
+        Sorted array of all observations.
+    Zstar
+        Sorted array of unique observations.
+    k
+        Number of samples.
+    n
+        Number of observations in each sample.
+    N
+        Total number of observations.
+
+    Returns
+    -------
+    A2aKN
+        The A2aKN statistics of Scholz and Stephens 1987.
+    """
+    A2akN = 0.0
+    Z_ssorted_left = Z.searchsorted(Zstar, "left")
+    if N == Zstar.size:
+        lj = 1.0
+    else:
+        lj = Z.searchsorted(Zstar, "right") - Z_ssorted_left
+    Bj = Z_ssorted_left + lj / 2.0
+    for i in jnp.arange(0, k):
+        s = jnp.sort(samples[i])
+        s_ssorted_right = s.searchsorted(Zstar, side="right")
+        Mij = s_ssorted_right.astype(float)
+        fij = s_ssorted_right - s.searchsorted(Zstar, "left")
+        Mij -= fij / 2.0
+        inner = lj / float(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj) - N * lj / 4.0)
+        A2akN += inner.sum() / n[i]
+    A2akN *= (N - 1.0) / N
+    return A2akN
+
+
+def _anderson_ksamp_right(
+    samples: Sequence[jnp.ndarray], Z: jnp.ndarray, Zstar: jnp.ndarray, k: int, n: jnp.ndarray, N: int
+):
+    """Compute A2kN equation 6 of Scholz & Stephens.
+
+    Parameters
+    ----------
+    samples
+        Array of sample arrays.
+    Z
+        Sorted array of all observations.
+    Zstar
+        Sorted array of unique observations.
+    k
+        Number of samples.
+    n
+        Number of observations in each sample.
+    N
+        Total number of observations.
+
+    Returns
+    -------
+    A2KN
+        The A2KN statistics of Scholz and Stephens 1987.
+    """
+    A2kN = 0.0
+    lj = Z.searchsorted(Zstar[:-1], "right") - Z.searchsorted(Zstar[:-1], "left")
+    Bj = lj.cumsum()
+    for i in jnp.arange(0, k):
+        s = jnp.sort(samples[i])
+        Mij = s.searchsorted(Zstar[:-1], side="right")
+        inner = lj / float(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj))
+        A2kN += inner.sum() / n[i]
+    return A2kN
+
+
+@dataclass
+class Anderson_ksampResult:
+    """Result of `anderson_ksamp`.
+
+    Attributes
+    ----------
+    statistic : float
+        Normalized k-sample Anderson-Darling test statistic.
+    critical_values : array
+        The critical values for significance levels 25%, 10%, 5%, 2.5%, 1%,
+        0.5%, 0.1%.
+    significance_level : float
+        The approximate p-value of the test. The value is floored / capped
+        at 0.1% / 25%.
+    """
+
+    statistic: float
+    critical_values: jnp.ndarray
+    significance_level: float
+
+
+def anderson_ksamp(samples: Sequence[NdArray], midrank: bool = True) -> Anderson_ksampResult:
+    """JAX implementation of :func:`scipy.stats.anderson_ksamp`.
+
+    The k-sample Anderson-Darling test is a modification of the
+    one-sample Anderson-Darling test. It tests the null hypothesis
+    that k-samples are drawn from the same population without having
+    to specify the distribution function of that population. The
+    critical values depend on the number of samples.
+
+    Parameters
+    ----------
+    samples
+        Array of sample data in arrays.
+    midrank
+        Type of Anderson-Darling test which is computed. Default
+        (True) is the midrank test applicable to continuous and
+        discrete populations. If False, the right side empirical
+        distribution is used.
+
+    Returns
+    -------
+    result
+
+    Raises
+    ------
+    ValueError
+        If less than 2 samples are provided, a sample is empty, or no
+        distinct observations are in the samples.
+
+    Notes
+    -----
+    [1]_ defines three versions of the k-sample Anderson-Darling test:
+    one for continuous distributions and two for discrete
+    distributions, in which ties between samples may occur. The
+    default of this routine is to compute the version based on the
+    midrank empirical distribution function. This test is applicable
+    to continuous and discrete data. If midrank is set to False, the
+    right side empirical distribution is used for a test for discrete
+    data. According to [1]_, the two discrete test statistics differ
+    only slightly if a few collisions due to round-off errors occur in
+    the test not adjusted for ties between samples.
+    The critical values corresponding to the significance levels from 0.01
+    to 0.25 are taken from [1]_. p-values are floored / capped
+    at 0.1% / 25%. Since the range of critical values might be extended in
+    future releases, it is recommended not to test ``p == 0.25``, but rather
+    ``p >= 0.25`` (analogously for the lower bound).
+
+    References
+    ----------
+    .. [1] Scholz, F. W and Stephens, M. A. (1987), K-Sample
+           Anderson-Darling Tests, Journal of the American Statistical
+           Association, Vol. 82, pp. 918-924.
+ """ + k = len(samples) + if k < 2: + raise ValueError("anderson_ksamp needs at least two samples") + + samples = tree_map(jnp.asarray, samples) + Z = jnp.sort(jnp.hstack(samples)) + N = jnp.array(Z.size) + Zstar = jnp.unique(Z) + if Zstar.size < 2: + raise ValueError("anderson_ksamp needs more than one distinct " "observation") + + n = jnp.array([sample.size for sample in samples]) + if jnp.any(n == 0): + raise ValueError("anderson_ksamp encountered sample without " "observations") + + if midrank: + A2kN = _anderson_ksamp_midrank(samples, Z, Zstar, k, n, N) + else: + A2kN = _anderson_ksamp_right(samples, Z, Zstar, k, n, N) + + H = (1.0 / n).sum() + hs_cs = (1.0 / jnp.arange(N - 1, 1, -1)).cumsum() + h = hs_cs[-1] + 1 + g = (hs_cs / jnp.arange(2, N)).sum() + + a = (4 * g - 6) * (k - 1) + (10 - 6 * g) * H + b = (2 * g - 4) * k**2 + 8 * h * k + (2 * g - 14 * h - 4) * H - 8 * h + 4 * g - 6 + c = (6 * h + 2 * g - 2) * k**2 + (4 * h - 4 * g + 6) * k + (2 * h - 6) * H + 4 * h + d = (2 * h + 6) * k**2 - 4 * h * k + sigmasq = (a * N**3 + b * N**2 + c * N + d) / ((N - 1.0) * (N - 2.0) * (N - 3.0)) + m = k - 1 + A2 = (A2kN - m) / jnp.sqrt(sigmasq) + + # The b_i values are the interpolation coefficients from Table 2 + # of Scholz and Stephens 1987 + b0 = jnp.array([0.675, 1.281, 1.645, 1.96, 2.326, 2.573, 3.085]) + b1 = jnp.array([-0.245, 0.25, 0.678, 1.149, 1.822, 2.364, 3.615]) + b2 = jnp.array([-0.105, -0.305, -0.362, -0.391, -0.396, -0.345, -0.154]) + critical = b0 + b1 / jnp.sqrt(m) + b2 / m + + sig = jnp.array([0.25, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001]) + if A2 < critical.min(): + p = sig.max() + warnings.warn(f"p-value capped: true value larger than {p}", stacklevel=2) + elif A2 > critical.max(): + p = sig.min() + warnings.warn(f"p-value floored: true value smaller than {p}", stacklevel=2) + else: + # interpolation of probit of significance level + pf = jnp.polyfit(critical, jnp.log(sig), 2) + p = jnp.exp(jnp.polyval(pf, A2)) + + res = Anderson_ksampResult(A2, critical, p) + return res diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fa53e30..3fbea13 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -77,6 +77,11 @@ def test_ilisi_clisi_knn(): scib_metrics.clisi_knn(X, labels, perplexity=10) +def test_cms(): + X, _, batches = dummy_x_labels_batch(x_is_neighbors_graph=True) + scib_metrics.cell_mixing_score(X, batches) + + def test_nmi_ari_cluster_labels_kmeans(): X, labels = dummy_x_labels() out = scib_metrics.nmi_ari_cluster_labels_kmeans(X, labels) diff --git a/tests/utils/test_anderson.py b/tests/utils/test_anderson.py new file mode 100644 index 0000000..59a2e9e --- /dev/null +++ b/tests/utils/test_anderson.py @@ -0,0 +1,16 @@ +import jax.numpy as jnp +import numpy as np +from scipy import stats + +from scib_metrics.utils import anderson_ksamp + + +def test_anderson_vs_scipy(): + """Test that the Anderson-Darling test gives the same results as scipy.stats""" + rng = np.random.default_rng() + data = [rng.normal(size=50), rng.normal(loc=0.5, size=30)] + orig_res = stats.anderson_ksamp(data) + jax_res = anderson_ksamp([jnp.asarray(d) for d in data]) + assert np.isclose(orig_res.statistic, jax_res.statistic) + assert np.allclose(orig_res.critical_values, jax_res.critical_values) + assert np.isclose(orig_res.significance_level, jax_res.significance_level) From 33732bb0fa16a2f88a25480397f27c0474aa7627 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Thu, 16 Feb 2023 12:53:53 -0800 Subject: [PATCH 2/3] use basic numpy --- docs/api.md | 1 + pyproject.toml | 1 + 
From 33732bb0fa16a2f88a25480397f27c0474aa7627 Mon Sep 17 00:00:00 2001
From: adamgayoso
Date: Thu, 16 Feb 2023 12:53:53 -0800
Subject: [PATCH 2/3] use basic numpy

---
 docs/api.md                         |   1 +
 pyproject.toml                      |   1 +
 src/scib_metrics/_cms.py            |  29 ++---
 src/scib_metrics/utils/__init__.py  |   3 +-
 src/scib_metrics/utils/_anderson.py | 114 ++++++++++++----------------
 tests/test_metrics.py               |   3 +-
 6 files changed, 67 insertions(+), 84 deletions(-)

diff --git a/docs/api.md b/docs/api.md
index 055fd88..7288172 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -69,6 +69,7 @@ scib_metrics.ilisi_knn(...)
     utils.convert_knn_graph_to_idx
     utils.check_square
     utils.diffusion_nn
+    utils.anderson_ksamp
 ```
 
 ### Nearest neighbors
diff --git a/pyproject.toml b/pyproject.toml
index 419c529..e397619 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
     "matplotlib",
     "plottable",
     "tqdm",
+    "numba",
 ]
 
 [project.optional-dependencies]
diff --git a/src/scib_metrics/_cms.py b/src/scib_metrics/_cms.py
index 92d18ed..6d97423 100644
--- a/src/scib_metrics/_cms.py
+++ b/src/scib_metrics/_cms.py
@@ -1,34 +1,33 @@
 import warnings
 from functools import partial
 
-import jax.numpy as jnp
 import numpy as np
 import pandas as pd
 from scipy.sparse import csr_matrix
+from scipy.stats import anderson_ksamp
 
-from scib_metrics.utils import anderson_ksamp, convert_knn_graph_to_idx
+from scib_metrics.utils import convert_knn_graph_to_idx
 
 
 def _cms_one_cell(
-    knn_dists: jnp.ndarray, knn_cats: jnp.ndarray, n_categories: int, cell_min: int = 4, unbalanced: bool = False
+    knn_dists: np.ndarray, knn_cats: np.ndarray, n_categories: int, cell_min: int = 4, unbalanced: bool = False
 ):
     # filter categories with too few cells (cell_min)
-    cat_counts = jnp.bincount(knn_cats, length=n_categories)
-    cat_values = jnp.arange(n_categories)
-    cats_to_use = jnp.where(cat_counts >= cell_min)[0]
+    cat_values, cat_counts = np.unique(knn_cats, return_counts=True)
+    cats_to_use = np.where(cat_counts >= cell_min)[0]
     cat_values = cat_values[cats_to_use]
-    mask = jnp.isin(knn_cats, cat_values)
+    mask = np.isin(knn_cats, cat_values)
     knn_cats = knn_cats[mask]
     knn_dists = knn_dists[mask]
 
     # do not perform AD test if only one group with enough cells is in knn.
     if len(cats_to_use) <= 1:
-        p = jnp.nan if unbalanced else 0.0
+        p = np.nan if unbalanced else 0.0
     else:
         # filter cells with the same representation
-        if jnp.any(knn_dists == 0):
+        if np.any(knn_dists == 0):
             warnings.warn("Distances equal to 0 - cells with identical representations detected. NaN assigned!")
-            p = jnp.nan
+            p = np.nan
         else:
             # perform AD test with the remaining cells
             res = anderson_ksamp([knn_dists[knn_cats == cat] for cat in cat_values])
             p = res.significance_level
 
@@ -63,14 +62,12 @@ def cell_mixing_score(X: csr_matrix, batches: np.ndarray, cell_min: int = 10, unbalanced: bool = False) -> np.ndarray:
     batches = np.asarray(categorical_type_batches.codes)
     n_categories = len(categorical_type_batches.categories)
     knn_dists, knn_idx = convert_knn_graph_to_idx(X)
-    knn_cats = jnp.asarray(batches[knn_idx])
-    knn_dists = jnp.asarray(knn_dists)
+    knn_cats = np.asarray(batches[knn_idx])
+    knn_dists = np.asarray(knn_dists)
 
     cms_fn = partial(_cms_one_cell, n_categories=n_categories, cell_min=cell_min, unbalanced=unbalanced)
-
-    ps = []
-    for dists, cats in zip(knn_dists, knn_cats):
-        ps.append(cms_fn(dists, cats))
+    vectorized_fn = np.vectorize(cms_fn, signature="(n),(n)->()")
+    ps = vectorized_fn(knn_dists, knn_cats)
 
     # TODO: add smoothing
 
     return np.array(ps)
diff --git a/src/scib_metrics/utils/__init__.py b/src/scib_metrics/utils/__init__.py
index 6130c9c..4d5367b 100644
--- a/src/scib_metrics/utils/__init__.py
+++ b/src/scib_metrics/utils/__init__.py
@@ -1,4 +1,4 @@
-from ._anderson import Anderson_ksampResult, anderson_ksamp
+from ._anderson import anderson_ksamp
 from ._diffusion_nn import diffusion_nn
 from ._dist import cdist, pdist_squareform
 from ._kmeans import KMeansJax
@@ -21,6 +21,5 @@
     "convert_knn_graph_to_idx",
     "check_square",
     "diffusion_nn",
-    "Anderson_ksampResult",
     "anderson_ksamp",
 ]
diff --git a/src/scib_metrics/utils/_anderson.py b/src/scib_metrics/utils/_anderson.py
index 0986664..e720dd5 100644
--- a/src/scib_metrics/utils/_anderson.py
+++ b/src/scib_metrics/utils/_anderson.py
@@ -1,15 +1,12 @@
-import warnings
-from dataclasses import dataclass
-from typing import Sequence
+from typing import Sequence, Tuple
 
-import jax.numpy as jnp
-from jax.tree_util import tree_map
-
-from .._types import NdArray
+import numba
+import numpy as np
 
 
+@numba.njit
 def _anderson_ksamp_midrank(
-    samples: Sequence[jnp.ndarray], Z: jnp.ndarray, Zstar: jnp.ndarray, k: int, n: jnp.ndarray, N: int
+    samples: Sequence[np.ndarray], Z: np.ndarray, Zstar: np.ndarray, k: int, n: np.ndarray, N: int
 ):
     """Compute A2akN equation 7 of Scholz and Stephens.
 
@@ -34,26 +31,27 @@ def _anderson_ksamp_midrank(
         The A2aKN statistics of Scholz and Stephens 1987.
""" A2akN = 0.0 - Z_ssorted_left = Z.searchsorted(Zstar, "left") + Z_ssorted_left = np.searchsorted(Z, Zstar, "left").astype(np.float32) if N == Zstar.size: - lj = 1.0 + lj = np.ones_like(Z_ssorted_left) else: - lj = Z.searchsorted(Zstar, "right") - Z_ssorted_left + lj = np.searchsorted(Z, Zstar, "right").astype(np.float32) - Z_ssorted_left Bj = Z_ssorted_left + lj / 2.0 - for i in jnp.arange(0, k): - s = jnp.sort(samples[i]) - s_ssorted_right = s.searchsorted(Zstar, side="right") - Mij = s_ssorted_right.astype(float) - fij = s_ssorted_right - s.searchsorted(Zstar, "left") + for i in np.arange(0, k): + s = np.sort(samples[i]) + s_ssorted_right = np.searchsorted(s, Zstar, side="right").astype(np.float32) + Mij = s_ssorted_right + fij = s_ssorted_right - np.searchsorted(s, Zstar, "left").astype(np.float32) Mij -= fij / 2.0 - inner = lj / float(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj) - N * lj / 4.0) + inner = lj / numba.float32(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj) - N * lj / 4.0) A2akN += inner.sum() / n[i] A2akN *= (N - 1.0) / N return A2akN +@numba.njit def _anderson_ksamp_right( - samples: Sequence[jnp.ndarray], Z: jnp.ndarray, Zstar: jnp.ndarray, k: int, n: jnp.ndarray, N: int + samples: Sequence[np.ndarray], Z: np.ndarray, Zstar: np.ndarray, k: int, n: np.ndarray, N: int ): """Compute A2akN equation 6 of Scholz & Stephens. @@ -78,39 +76,19 @@ def _anderson_ksamp_right( The A2KN statistics of Scholz and Stephens 1987. """ A2kN = 0.0 - lj = Z.searchsorted(Zstar[:-1], "right") - Z.searchsorted(Zstar[:-1], "left") + lj = np.searchsorted(Z, Zstar[:-1], "right") - np.searchsorted(Z, Zstar[:-1], "left") Bj = lj.cumsum() - for i in jnp.arange(0, k): - s = jnp.sort(samples[i]) - Mij = s.searchsorted(Zstar[:-1], side="right") + for i in np.arange(0, k): + s = np.sort(samples[i]) + Mij = np.searchsorted(s, Zstar[:-1], side="right") inner = lj / float(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj)) A2kN += inner.sum() / n[i] return A2kN -@dataclass -class Anderson_ksampResult: - """Result of `anderson_ksamp`. - - Attributes - ---------- - statistic : float - Normalized k-sample Anderson-Darling test statistic. - critical_values : array - The critical values for significance levels 25%, 10%, 5%, 2.5%, 1%, - 0.5%, 0.1%. - significance_level : float - The approximate p-value of the test. The value is floored / capped - at 0.1% / 25%. - """ - - statistic: float - critical_values: jnp.ndarray - significance_level: float - - -def anderson_ksamp(samples: Sequence[NdArray], midrank: bool = True) -> Anderson_ksampResult: - """Jax implementation of :func:`scipy.stats.anderson_ksamp`. +@numba.njit +def anderson_ksamp(samples: Sequence[np.ndarray], midrank: bool = True) -> Tuple[float, np.ndarray, float]: + """Numba-friendly implementation of :func:`scipy.stats.anderson_ksamp`. The k-sample Anderson-Darling test is a modification of the one-sample Anderson-Darling test. 
     one-sample Anderson-Darling test. It tests the null hypothesis
@@ -130,7 +108,7 @@ def anderson_ksamp(samples: Sequence[NdArray], midrank: bool = True) -> Anderson_ksampResult:
 
     Returns
     -------
-    result
+    result, tuple of (statistic, critical_values, significance_level)
 
     Raises
     ------
@@ -166,15 +144,23 @@ def anderson_ksamp(samples: Sequence[NdArray], midrank: bool = True) -> Anderson_ksampResult:
     if k < 2:
         raise ValueError("anderson_ksamp needs at least two samples")
 
-    samples = tree_map(jnp.asarray, samples)
-    Z = jnp.sort(jnp.hstack(samples))
-    N = jnp.array(Z.size)
-    Zstar = jnp.unique(Z)
+    # join all samples into one long sample
+    samples = list(map(np.asarray, samples))
+    n_samples = sum(list(map(len, samples)))
+    long_samples = np.empty(n_samples)
+    i = 0
+    for sample in samples:
+        for s in sample:
+            long_samples[i] = s
+            i += 1
+    Z = np.sort(long_samples)
+    N = np.array(Z.size)
+    Zstar = np.unique(Z)
     if Zstar.size < 2:
         raise ValueError("anderson_ksamp needs more than one distinct " "observation")
 
-    n = jnp.array([sample.size for sample in samples])
-    if jnp.any(n == 0):
+    n = np.array([sample.size for sample in samples])
+    if np.any(n == 0):
         raise ValueError("anderson_ksamp encountered sample without " "observations")
 
     if midrank:
@@ -183,9 +169,9 @@ def anderson_ksamp(samples: Sequence[NdArray], midrank: bool = True) -> Anderson_ksampResult:
         A2kN = _anderson_ksamp_right(samples, Z, Zstar, k, n, N)
 
     H = (1.0 / n).sum()
-    hs_cs = (1.0 / jnp.arange(N - 1, 1, -1)).cumsum()
+    hs_cs = (1.0 / np.arange(N - 1, 1, -1)).cumsum()
     h = hs_cs[-1] + 1
-    g = (hs_cs / jnp.arange(2, N)).sum()
+    g = (hs_cs / np.arange(2, N)).sum()
 
     a = (4 * g - 6) * (k - 1) + (10 - 6 * g) * H
     b = (2 * g - 4) * k**2 + 8 * h * k + (2 * g - 14 * h - 4) * H - 8 * h + 4 * g - 6
     c = (6 * h + 2 * g - 2) * k**2 + (4 * h - 4 * g + 6) * k + (2 * h - 6) * H + 4 * h
     d = (2 * h + 6) * k**2 - 4 * h * k
     sigmasq = (a * N**3 + b * N**2 + c * N + d) / ((N - 1.0) * (N - 2.0) * (N - 3.0))
     m = k - 1
-    A2 = (A2kN - m) / jnp.sqrt(sigmasq)
+    A2 = (A2kN - m) / np.sqrt(sigmasq)
 
     # The b_i values are the interpolation coefficients from Table 2
     # of Scholz and Stephens 1987
-    b0 = jnp.array([0.675, 1.281, 1.645, 1.96, 2.326, 2.573, 3.085])
-    b1 = jnp.array([-0.245, 0.25, 0.678, 1.149, 1.822, 2.364, 3.615])
-    b2 = jnp.array([-0.105, -0.305, -0.362, -0.391, -0.396, -0.345, -0.154])
-    critical = b0 + b1 / jnp.sqrt(m) + b2 / m
+    b0 = np.array([0.675, 1.281, 1.645, 1.96, 2.326, 2.573, 3.085])
+    b1 = np.array([-0.245, 0.25, 0.678, 1.149, 1.822, 2.364, 3.615])
+    b2 = np.array([-0.105, -0.305, -0.362, -0.391, -0.396, -0.345, -0.154])
+    critical = b0 + b1 / np.sqrt(m) + b2 / m
 
-    sig = jnp.array([0.25, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001])
+    sig = np.array([0.25, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001])
     if A2 < critical.min():
         p = sig.max()
-        warnings.warn(f"p-value capped: true value larger than {p}", stacklevel=2)
     elif A2 > critical.max():
         p = sig.min()
-        warnings.warn(f"p-value floored: true value smaller than {p}", stacklevel=2)
     else:
         # interpolation of probit of significance level
-        pf = jnp.polyfit(critical, jnp.log(sig), 2)
-        p = jnp.exp(jnp.polyval(pf, A2))
+        pf = np.polyfit(critical, np.log(sig), 2)
+        p = np.exp(np.polyval(pf, A2))
 
-    res = Anderson_ksampResult(A2, critical, p)
+    res = (A2, critical, p)
     return res
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 3fbea13..c952634 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -79,7 +79,8 @@ def test_ilisi_clisi_knn():
 
 def test_cms():
     X, _, batches = dummy_x_labels_batch(x_is_neighbors_graph=True)
-    scib_metrics.cell_mixing_score(X, batches)
+    score = scib_metrics.cell_mixing_score(X, batches)
+    assert len(score) == X.shape[0]
 
 
 def test_nmi_ari_cluster_labels_kmeans():
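After this patch the in-repo `anderson_ksamp` returns a plain tuple rather than the `Anderson_ksampResult` dataclass (a tuple is easier for numba's `njit` to compile), while `_cms.py` now calls `scipy.stats.anderson_ksamp`, whose result object still exposes `.significance_level`. A hedged sketch of the two calling conventions, with `samples` standing in for any list of 1-D arrays:

    # in-repo numba version after this patch: a plain (statistic, critical_values, p) tuple
    statistic, critical_values, significance_level = anderson_ksamp(samples)

    # scipy version now used by _cms.py: a result object
    from scipy.stats import anderson_ksamp as scipy_anderson_ksamp
    p = scipy_anderson_ksamp(samples).significance_level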
From d3666f95b8a198bcd3ac57facf4f15725361cb3d Mon Sep 17 00:00:00 2001
From: adamgayoso
Date: Thu, 16 Feb 2023 12:54:20 -0800
Subject: [PATCH 3/3] use basic numpy

---
 src/scib_metrics/utils/__init__.py  |   2 -
 src/scib_metrics/utils/_anderson.py | 202 ----------------------------
 tests/utils/test_anderson.py        |  16 ---
 3 files changed, 220 deletions(-)
 delete mode 100644 src/scib_metrics/utils/_anderson.py
 delete mode 100644 tests/utils/test_anderson.py

diff --git a/src/scib_metrics/utils/__init__.py b/src/scib_metrics/utils/__init__.py
index 4d5367b..f555528 100644
--- a/src/scib_metrics/utils/__init__.py
+++ b/src/scib_metrics/utils/__init__.py
@@ -1,4 +1,3 @@
-from ._anderson import anderson_ksamp
 from ._diffusion_nn import diffusion_nn
 from ._dist import cdist, pdist_squareform
 from ._kmeans import KMeansJax
@@ -21,6 +20,4 @@
     "convert_knn_graph_to_idx",
     "check_square",
     "diffusion_nn",
-    "anderson_ksamp",
 ]
diff --git a/src/scib_metrics/utils/_anderson.py b/src/scib_metrics/utils/_anderson.py
deleted file mode 100644
index e720dd5..0000000
--- a/src/scib_metrics/utils/_anderson.py
+++ /dev/null
@@ -1,202 +0,0 @@
-from typing import Sequence, Tuple
-
-import numba
-import numpy as np
-
-
-@numba.njit
-def _anderson_ksamp_midrank(
-    samples: Sequence[np.ndarray], Z: np.ndarray, Zstar: np.ndarray, k: int, n: np.ndarray, N: int
-):
-    """Compute A2akN equation 7 of Scholz and Stephens.
-
-    Parameters
-    ----------
-    samples
-        Array of sample arrays.
-    Z
-        Sorted array of all observations.
-    Zstar
-        Sorted array of unique observations.
-    k
-        Number of samples.
-    n
-        Number of observations in each sample.
-    N
-        Total number of observations.
-
-    Returns
-    -------
-    A2aKN
-        The A2aKN statistics of Scholz and Stephens 1987.
-    """
-    A2akN = 0.0
-    Z_ssorted_left = np.searchsorted(Z, Zstar, "left").astype(np.float32)
-    if N == Zstar.size:
-        lj = np.ones_like(Z_ssorted_left)
-    else:
-        lj = np.searchsorted(Z, Zstar, "right").astype(np.float32) - Z_ssorted_left
-    Bj = Z_ssorted_left + lj / 2.0
-    for i in np.arange(0, k):
-        s = np.sort(samples[i])
-        s_ssorted_right = np.searchsorted(s, Zstar, side="right").astype(np.float32)
-        Mij = s_ssorted_right
-        fij = s_ssorted_right - np.searchsorted(s, Zstar, "left").astype(np.float32)
-        Mij -= fij / 2.0
-        inner = lj / numba.float32(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj) - N * lj / 4.0)
-        A2akN += inner.sum() / n[i]
-    A2akN *= (N - 1.0) / N
-    return A2akN
-
-
-@numba.njit
-def _anderson_ksamp_right(
-    samples: Sequence[np.ndarray], Z: np.ndarray, Zstar: np.ndarray, k: int, n: np.ndarray, N: int
-):
-    """Compute A2kN equation 6 of Scholz & Stephens.
-
-    Parameters
-    ----------
-    samples
-        Array of sample arrays.
-    Z
-        Sorted array of all observations.
-    Zstar
-        Sorted array of unique observations.
-    k
-        Number of samples.
-    n
-        Number of observations in each sample.
-    N
-        Total number of observations.
-
-    Returns
-    -------
-    A2KN
-        The A2KN statistics of Scholz and Stephens 1987.
- """ - A2kN = 0.0 - lj = np.searchsorted(Z, Zstar[:-1], "right") - np.searchsorted(Z, Zstar[:-1], "left") - Bj = lj.cumsum() - for i in np.arange(0, k): - s = np.sort(samples[i]) - Mij = np.searchsorted(s, Zstar[:-1], side="right") - inner = lj / float(N) * (N * Mij - Bj * n[i]) ** 2 / (Bj * (N - Bj)) - A2kN += inner.sum() / n[i] - return A2kN - - -@numba.njit -def anderson_ksamp(samples: Sequence[np.ndarray], midrank: bool = True) -> Tuple[float, np.ndarray, float]: - """Numba-friendly implementation of :func:`scipy.stats.anderson_ksamp`. - - The k-sample Anderson-Darling test is a modification of the - one-sample Anderson-Darling test. It tests the null hypothesis - that k-samples are drawn from the same population without having - to specify the distribution function of that population. The - critical values depend on the number of samples. - - Parameters - ---------- - samples - Array of sample data in arrays. - midrank - Type of Anderson-Darling test which is computed. Default - (True) is the midrank test applicable to continuous and - discrete populations. If False, the right side empirical - distribution is used. - - Returns - ------- - result, tuple of (statistic, critical_values, significance_level) - - Raises - ------ - ValueError - If less than 2 samples are provided, a sample is empty, or no - distinct observations are in the samples. - - Notes - ----- - [1]_ defines three versions of the k-sample Anderson-Darling test: - one for continuous distributions and two for discrete - distributions, in which ties between samples may occur. The - default of this routine is to compute the version based on the - midrank empirical distribution function. This test is applicable - to continuous and discrete data. If midrank is set to False, the - right side empirical distribution is used for a test for discrete - data. According to [1]_, the two discrete test statistics differ - only slightly if a few collisions due to round-off errors occur in - the test not adjusted for ties between samples. - The critical values corresponding to the significance levels from 0.01 - to 0.25 are taken from [1]_. p-values are floored / capped - at 0.1% / 25%. Since the range of critical values might be extended in - future releases, it is recommended not to test ``p == 0.25``, but rather - ``p >= 0.25`` (analogously for the lower bound). - - References - ---------- - .. [1] Scholz, F. W and Stephens, M. A. (1987), K-Sample - Anderson-Darling Tests, Journal of the American Statistical - Association, Vol. 82, pp. 918-924. 
- """ - k = len(samples) - if k < 2: - raise ValueError("anderson_ksamp needs at least two samples") - - # join all samples into one long sample - samples = list(map(np.asarray, samples)) - n_samples = sum(list(map(len, samples))) - long_samples = np.empty(n_samples) - i = 0 - for sample in samples: - for s in sample: - long_samples[i] = s - i += 1 - Z = np.sort(long_samples) - N = np.array(Z.size) - Zstar = np.unique(Z) - if Zstar.size < 2: - raise ValueError("anderson_ksamp needs more than one distinct " "observation") - - n = np.array([sample.size for sample in samples]) - if np.any(n == 0): - raise ValueError("anderson_ksamp encountered sample without " "observations") - - if midrank: - A2kN = _anderson_ksamp_midrank(samples, Z, Zstar, k, n, N) - else: - A2kN = _anderson_ksamp_right(samples, Z, Zstar, k, n, N) - - H = (1.0 / n).sum() - hs_cs = (1.0 / np.arange(N - 1, 1, -1)).cumsum() - h = hs_cs[-1] + 1 - g = (hs_cs / np.arange(2, N)).sum() - - a = (4 * g - 6) * (k - 1) + (10 - 6 * g) * H - b = (2 * g - 4) * k**2 + 8 * h * k + (2 * g - 14 * h - 4) * H - 8 * h + 4 * g - 6 - c = (6 * h + 2 * g - 2) * k**2 + (4 * h - 4 * g + 6) * k + (2 * h - 6) * H + 4 * h - d = (2 * h + 6) * k**2 - 4 * h * k - sigmasq = (a * N**3 + b * N**2 + c * N + d) / ((N - 1.0) * (N - 2.0) * (N - 3.0)) - m = k - 1 - A2 = (A2kN - m) / np.sqrt(sigmasq) - - # The b_i values are the interpolation coefficients from Table 2 - # of Scholz and Stephens 1987 - b0 = np.array([0.675, 1.281, 1.645, 1.96, 2.326, 2.573, 3.085]) - b1 = np.array([-0.245, 0.25, 0.678, 1.149, 1.822, 2.364, 3.615]) - b2 = np.array([-0.105, -0.305, -0.362, -0.391, -0.396, -0.345, -0.154]) - critical = b0 + b1 / np.sqrt(m) + b2 / m - - sig = np.array([0.25, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001]) - if A2 < critical.min(): - p = sig.max() - elif A2 > critical.max(): - p = sig.min() - else: - # interpolation of probit of significance level - pf = np.polyfit(critical, np.log(sig), 2) - p = np.exp(np.polyval(pf, A2)) - - res = (A2, critical, p) - return res diff --git a/tests/utils/test_anderson.py b/tests/utils/test_anderson.py deleted file mode 100644 index 59a2e9e..0000000 --- a/tests/utils/test_anderson.py +++ /dev/null @@ -1,16 +0,0 @@ -import jax.numpy as jnp -import numpy as np -from scipy import stats - -from scib_metrics.utils import anderson_ksamp - - -def test_anderson_vs_scipy(): - """Test that the Anderson-Darling test gives the same results as scipy.stats""" - rng = np.random.default_rng() - data = [rng.normal(size=50), rng.normal(loc=0.5, size=30)] - orig_res = stats.anderson_ksamp(data) - jax_res = anderson_ksamp([jnp.asarray(d) for d in data]) - assert np.isclose(orig_res.statistic, jax_res.statistic) - assert np.allclose(orig_res.critical_values, jax_res.critical_values) - assert np.isclose(orig_res.significance_level, jax_res.significance_level)