import functools as fct
import os
import warnings
from typing import List, Optional, Tuple, Union
import cv2
import cvxpy as cvx
import dask as da
import dask.array as darr
import networkx as nx
import numba as nb
import numpy as np
import pandas as pd
import pyfftw.interfaces.numpy_fft as numpy_fft
import pymetis
import scipy.sparse
import sparse
import xarray as xr
import zarr
from distributed import get_client
from scipy.linalg import lstsq, toeplitz
from scipy.ndimage import label
from scipy.signal import butter, lfilter, welch
from scipy.sparse import dia_matrix
from skimage import morphology as moph
from sklearn.linear_model import LassoLars
from statsmodels.tsa.stattools import acovf
from .utilities import (
custom_arr_optimize,
custom_delay_optimize,
open_minian,
rechunk_like,
save_minian,
med_baseline,
)
def get_noise_fft(
    varr: xr.DataArray, noise_range=(0.25, 0.5), noise_method="logmexp"
) -> xr.DataArray:
    """
    Estimate noise along the "frame" dimension by aggregating power spectral
    density within `noise_range`.

    This function applies :func:`noise_fft` along the "frame" dimension in a
    vectorized fashion, and estimates noise by aggregating the power spectral
    density (PSD). Note that `noise_range` is specified relative to the
    sampling frequency, so 0.5 represents the Nyquist frequency. Four
    `noise_method` are available for aggregating the PSD: "mean", "median" and
    "sum" use the mean, median and sum across the selected frequencies as the
    estimation of noise. "logmexp" takes the mean of the logarithmic PSD, then
    transforms it back with an exponential function.

    Parameters
    ----------
    varr : xr.DataArray
        Input data, should have a "frame" dimension.
    noise_range : tuple, optional
        Range of noise frequency to be aggregated as a fraction of sampling
        frequency. By default `(0.25, 0.5)`.
    noise_method : str, optional
        Method of aggregation for noise. Should be one of `"mean"` `"median"`
        `"logmexp"` or `"sum"`. By default `"logmexp"`.

    Returns
    -------
    sn : xr.DataArray
        Spectral density of the noise. Same shape as `varr` with the "frame"
        dimension removed.

    See Also
    --------
    noise_fft
    """
    # Use as many FFT threads as the least-provisioned worker offers; fall
    # back to a single thread when no distributed client can be obtained
    # (signalled by ValueError).
    try:
        clt = get_client()
        threads = min(clt.nthreads().values())
    except ValueError:
        threads = 1
    sn = xr.apply_ufunc(
        noise_fft,
        varr,
        input_core_dims=[["frame"]],
        output_core_dims=[[]],
        dask="parallelized",
        vectorize=True,
        kwargs=dict(
            noise_range=noise_range, noise_method=noise_method, threads=threads
        ),
        # `np.float` was a plain alias of the builtin `float` and no longer
        # exists in recent NumPy releases; the builtin is equivalent.
        output_dtypes=[float],
    )
    return sn
def noise_fft(
    px: np.ndarray, noise_range=(0.25, 0.5), noise_method="logmexp", threads=1
) -> float:
    """
    Estimate noise of the input by aggregating power spectral density within
    `noise_range`.

    The PSD is estimated using FFT.

    Parameters
    ----------
    px : np.ndarray
        Input data.
    noise_range : tuple, optional
        Range of noise frequency to be aggregated as a fraction of sampling
        frequency. By default `(0.25, 0.5)`.
    noise_method : str, optional
        Method of aggregation for noise. Should be one of `"mean"` `"median"`
        `"logmexp"` or `"sum"`. By default `"logmexp"`.
    threads : int, optional
        Number of threads to use for pyfftw. By default `1`.

    Returns
    -------
    noise : float
        The estimated noise level of input.

    Raises
    ------
    ValueError
        If `noise_method` is not one of the supported methods.

    See Also
    --------
    get_noise_fft
    """
    # Reject unknown methods up-front instead of silently returning `None`
    # after the FFT has been computed.
    if noise_method not in ("mean", "median", "logmexp", "sum"):
        raise ValueError(
            "noise_method must be one of 'mean', 'median', 'logmexp' or 'sum', "
            "got {!r}".format(noise_method)
        )
    _T = len(px)
    # Convert the fractional frequency range into indices of the FFT output.
    nr = np.around(np.array(noise_range) * _T).astype(int)
    px = 1 / _T * np.abs(numpy_fft.rfft(px, threads=threads)[nr[0] : nr[1]]) ** 2
    if noise_method == "mean":
        return np.sqrt(px.mean())
    if noise_method == "median":
        # ndarray has no `.median()` method; use the module-level function.
        return np.sqrt(np.median(px))
    if noise_method == "logmexp":
        # eps keeps the logarithm finite when a PSD bin is exactly zero
        eps = np.finfo(px.dtype).eps
        return np.sqrt(np.exp(np.log(px + eps).mean()))
    return np.sqrt(px.sum())
def get_noise_welch(
    varr: xr.DataArray, noise_range=(0.25, 0.5), noise_method="logmexp"
) -> xr.DataArray:
    """
    Estimate noise along the "frame" dimension by aggregating power spectral
    density within `noise_range`.

    The PSD is estimated using welch method as an alternative to FFT. The welch
    method assumes the noise in the signal to be a stochastic process and
    attenuates noise by windowing the original signal into segments and
    averaging over them.

    Parameters
    ----------
    varr : xr.DataArray
        Input data. Should have a "frame" dimension.
    noise_range : tuple, optional
        Range of noise frequency to be aggregated as a fraction of sampling
        frequency. By default `(0.25, 0.5)`.
    noise_method : str, optional
        Method of aggregation for noise, passed on to :func:`noise_welch`.
        By default `"logmexp"`.

    Returns
    -------
    sn : xr.DataArray
        Spectral density of the noise. Same shape as `varr` with the "frame"
        dimension removed. Keeps the dtype of `varr` when it is a floating
        point type, otherwise the result is `float`.

    See Also
    --------
    get_noise_fft : For more details on the parameters.
    noise_welch
    """
    # The noise estimate is real-valued. With `vectorize=True` the declared
    # output dtype is what results are cast to, so declaring an integer input
    # dtype would truncate the estimate. Keep floating input precision, and
    # fall back to float for everything else.
    if np.issubdtype(varr.dtype, np.floating):
        out_dtype = varr.dtype
    else:
        out_dtype = float
    sn = xr.apply_ufunc(
        noise_welch,
        # the whole "frame" axis must live in a single chunk per pixel
        varr.chunk(dict(frame=-1)),
        input_core_dims=[["frame"]],
        dask="parallelized",
        vectorize=True,
        kwargs=dict(noise_range=noise_range, noise_method=noise_method),
        output_dtypes=[out_dtype],
    )
    return sn
def noise_welch(
    y: np.ndarray, noise_range=(0.25, 0.5), noise_method="logmexp"
) -> float:
    """
    Estimate noise of the input by aggregating power spectral density within
    `noise_range`.

    The PSD is estimated using welch method, called with its default
    arguments.

    Parameters
    ----------
    y : np.ndarray
        Input data.
    noise_range : tuple, optional
        Range of noise frequency to be aggregated as a fraction of sampling
        frequency. Both bounds are exclusive. By default `(0.25, 0.5)`.
    noise_method : str, optional
        Method of aggregation for noise. Should be one of `"mean"` `"median"`
        `"logmexp"` or `"sum"`. By default `"logmexp"`.

    Returns
    -------
    noise : float
        The estimated noise level of input.

    Raises
    ------
    KeyError
        If `noise_method` is not one of the supported methods.

    See Also
    --------
    get_noise_welch
    """
    ff, Pxx = welch(y)
    # keep only the frequency bins strictly inside the requested band
    in_band = np.logical_and(ff > noise_range[0], ff < noise_range[1])
    Pxx_ind = Pxx[in_band]
    # The PSD is halved before aggregation in every method (presumably to
    # account for the one-sided density returned by welch -- TODO confirm).
    aggregators = {
        "mean": lambda x: np.sqrt(np.mean(x / 2)),
        "median": lambda x: np.sqrt(np.median(x / 2)),
        "logmexp": lambda x: np.sqrt(np.exp(np.mean(np.log(x / 2)))),
        # documented alongside the other methods but previously missing here
        "sum": lambda x: np.sqrt(np.sum(x / 2)),
    }
    return aggregators[noise_method](Pxx_ind)
def update_spatial(
    Y: xr.DataArray,
    A: xr.DataArray,
    C: xr.DataArray,
    sn: xr.DataArray,
    b: xr.DataArray = None,
    f: xr.DataArray = None,
    dl_wnd=5,
    sparse_penal=0.5,
    update_background=False,
    normalize=True,
    size_thres=(9, None),
    in_memory=False,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]:
    """
    Update spatial components given the input data and temporal dynamic for each
    cell.

    This function carries out spatial update of the CNMF algorithm. The update
    is done in parallel and independently for each pixel. To save computation
    time, we compute a subsetting matrix `sub` by dilating the initial
    spatial footprint of each cell. The window size of the dilation is
    controlled by `dl_wnd`. Then for each pixel, only cells that have a non-zero
    value in `sub` at the current pixel will be considered for update.
    Optionally, the spatial footprint of the background can be updated in the
    same fashion based on the temporal dynamic of the background. After the
    update, the spatial footprint of each cell can be optionally normalized by
    its maximum value, so that difference in fluorescent intensity will not be
    reflected in spatial footprint. A `size_thres` can be passed in to filter
    out cells whose size (number of non-zero values in spatial footprint) is
    outside the specified range. Finally, the temporal dynamic of cells `C` can
    either be loaded in memory before the update or lazy-loaded during the
    update. Note that if `in_memory` is `False`, then `C` must be stored as a
    zarr array under the intermediate folder specified as environment variable
    `MINIAN_INTERMEDIATE`. The environment variable is read in either case,
    since the updated footprints are saved to that folder.

    Parameters
    ----------
    Y : xr.DataArray
        Input movie data. Should have dimensions "height", "width" and "frame".
    A : xr.DataArray
        Previous estimation of spatial footprints. Should have dimension
        "height", "width" and "unit_id".
    C : xr.DataArray
        Estimation of temporal component for each cell. Should have dimension
        "frame" and "unit_id".
    sn : xr.DataArray
        Estimation of noise level for each pixel. Should have dimension "height"
        and "width".
    b : xr.DataArray, optional
        Previous estimation of spatial footprint of background. Should have
        dimension "height" and "width".
    f : xr.DataArray, optional
        Estimation of temporal dynamic of background. Should have dimension
        "frame".
    dl_wnd : int, optional
        Window of morphological dilation in pixel when computing the subsetting
        matrix. By default `5`.
    sparse_penal : float, optional
        Global scalar controlling sparsity of the result. The higher the value,
        the sparser the spatial footprints. By default `0.5`.
    update_background : bool, optional
        Whether to update the spatial footprint of background. If `True`, then
        both `b` and `f` need to be provided. By default `False`.
    normalize : bool, optional
        Whether to normalize resulting spatial footprints of each cell by their
        maximum value. By default `True`
    size_thres : tuple, optional
        The range of size in pixel allowed for the resulting spatial footprints,
        given as `(low, high)`; both bounds are exclusive and a falsy bound is
        not applied. If `None`, then only cells whose footprint sums to zero are
        dropped. By default `(9, None)`.
    in_memory : bool, optional
        Whether to load `C` into memory before spatial update. By default
        `False`.

    Returns
    -------
    A_new : xr.DataArray
        New estimation of spatial footprints. Same shape as `A` except the
        "unit_id" dimension might be smaller due to filtering.
    mask : xr.DataArray
        Boolean mask of whether a cell passed size filtering. Has dimension
        "unit_id" that is same as input `A`. Useful for subsetting other
        variables based on the result of spatial update.
    b_new : xr.DataArray
        New estimation of spatial footprint of background. Only returned if
        `update_background` is `True`. Same shape as `b`.
    norm_fac : xr.DataArray
        Normalizing factor. Useful to scale temporal activity of cells. Only
        returned if `normalize` is `True`.

    The returned tuple therefore has between two and four elements, always in
    the order listed above.

    Notes
    -------
    During spatial update, the algorithm solve the following optimization
    problem for each pixel:

    .. math::
        \\begin{aligned}
        & \\underset{\mathbf{a}}{\\text{minimize}}
        & & \\left \\lVert \mathbf{y} - \mathbf{a}^T \mathbf{C} \\right \\rVert
        ^2 + \\alpha \\left \\lvert \mathbf{a} \\right \\rvert \\\\
        & \\text{subject to} & & \mathbf{a} \geq 0
        \\end{aligned}

    Where :math:`\mathbf{y}` is the fluorescent dynamic of the pixel,
    :math:`\mathbf{a}` is spatial footprint values across all cells on that
    pixel, :math:`\mathbf{C}` is temporal component matrix across all cells. The
    parameter :math:`\\alpha` is the product of the noise level on each pixel
    `sn` and the global scalar `sparse_penal`. Higher value of :math:`\\alpha`
    will result in more sparse estimation of spatial footprints.
    """
    # raises KeyError when the environment variable is not set
    intpath = os.environ["MINIAN_INTERMEDIATE"]
    if in_memory:
        C_store = C.compute().values
    else:
        # lazily open the zarr array backing `C`; the path is derived from
        # `C.name`, so `C` must be a named array saved in the intermediate folder
        C_path = os.path.join(intpath, C.name + ".zarr", C.name)
        C_store = zarr.open_array(C_path)
    print("estimating penalty parameter")
    # per-pixel penalty: global scalar times the pixel noise level
    alpha = sparse_penal * sn
    alpha = rechunk_like(alpha.compute(), sn)
    print("computing subsetting matrix")
    selem = moph.disk(dl_wnd)
    # dilate every footprint so each pixel also considers nearby cells
    sub = xr.apply_ufunc(
        cv2.dilate,
        A,
        input_core_dims=[["height", "width"]],
        output_core_dims=[["height", "width"]],
        vectorize=True,
        kwargs=dict(kernel=selem),
        dask="parallelized",
        output_dtypes=[A.dtype],
    )
    sub = sub > 0
    sub.data = sub.data.map_blocks(sparse.COO)
    if update_background:
        assert b is not None, "`b` must be provided when updating background"
        assert f is not None, "`f` must be provided when updating background"
        # the background is appended to `sub` as an extra unit with id -1, so it
        # ends up as the last entry along "unit_id"
        b_in = rechunk_like(b > 0, Y).assign_coords(unit_id=-1).expand_dims("unit_id")
        b_in.data = b_in.data.map_blocks(sparse.COO)
        b_in = b_in.compute()
        sub = xr.concat([sub, b_in], "unit_id")
        f_in = f.compute().data
    else:
        f_in = None
    sub = rechunk_like(sub.transpose("height", "width", "unit_id").compute(), Y)
    print("fitting spatial matrix")
    # one boolean per spatial block: does the block contain any candidate unit?
    ssub = darr.map_blocks(
        sps_any,
        sub.data,
        drop_axis=2,
        chunks=((1, 1)),
        meta=sparse.ones(1).astype(bool),
    ).compute()
    Y_trans = Y.transpose("height", "width", "frame")
    # take fast route if a lot of chunks are empty
    # (i.e. fewer than 500 blocks contain any unit: build the result block by
    # block and fill the empty blocks with sparse zeros without fitting)
    if ssub.sum() < 500:
        A_new = np.empty(sub.data.numblocks, dtype=object)
        for (hblk, wblk), has_unit in np.ndenumerate(ssub):
            cur_sub = sub.data.blocks[hblk, wblk, :]
            if has_unit:
                cur_blk = update_spatial_block(
                    Y_trans.data.blocks[hblk, wblk, :],
                    alpha.data.blocks[hblk, wblk],
                    cur_sub,
                    C_store=C_store,
                    f=f_in,
                )
            else:
                cur_blk = darr.array(sparse.zeros((cur_sub.shape)))
            A_new[hblk, wblk, 0] = cur_blk
        A_new = darr.block(A_new.tolist())
    else:
        A_new = update_spatial_block(
            Y_trans.data,
            alpha.data,
            sub.data,
            C_store=C_store,
            f=f_in,
        )
    with da.config.set(**{"optimization.fuse.ave-width": 6}):
        A_new = da.optimize(A_new)[0]
    # move "unit_id" to the front and densify the sparse blocks
    A_new = xr.DataArray(
        darr.moveaxis(A_new, -1, 0).map_blocks(lambda a: a.todense(), dtype=A.dtype),
        dims=["unit_id", "height", "width"],
        coords={
            "unit_id": sub.coords["unit_id"],
            "height": A.coords["height"],
            "width": A.coords["width"],
        },
    )
    A_new = save_minian(
        A_new.rename("A_new"),
        intpath,
        overwrite=True,
        chunks={"unit_id": 1, "height": -1, "width": -1},
    )
    # optional extra return values, appended in a fixed order
    add_rets = []
    if update_background:
        # split the background (unit_id == -1, last along "unit_id") back out
        b_new = A_new.sel(unit_id=-1).compute()
        A_new = A_new[:-1, :, :]
        add_rets.append(b_new)
    if size_thres:
        low, high = size_thres
        A_bin = A_new > 0
        mask = np.ones(A_new.sizes["unit_id"], dtype=bool)
        # each bound is only applied when it is truthy (not None and not 0)
        if low:
            mask = np.logical_and(
                (A_bin.sum(["height", "width"]) > low).compute(), mask
            )
        if high:
            mask = np.logical_and(
                (A_bin.sum(["height", "width"]) < high).compute(), mask
            )
        mask = xr.DataArray(
            mask, dims=["unit_id"], coords={"unit_id": A_new.coords["unit_id"].values}
        )
    else:
        # without a size threshold only drop units whose footprint sums to zero
        mask = (A_new.sum(["height", "width"]) > 0).compute()
    print("{} out of {} units dropped".format(len(mask) - mask.sum().values, len(mask)))
    A_new = A_new.sel(unit_id=mask)
    if normalize:
        # normalizing factor is the per-unit maximum over the field of view
        norm_fac = A_new.max(["height", "width"]).compute()
        A_new = A_new / norm_fac
        add_rets.append(norm_fac)
    return (A_new, mask, *add_rets)
def sps_any(x: sparse.COO) -> np.ndarray:
    """
    Tell whether a sparse array holds any stored element.

    Parameters
    ----------
    x : sparse.COO
        Input sparse array.

    Returns
    -------
    x_any : np.ndarray
        2d boolean numpy array with a single element, which is `True` when
        `x` has at least one stored (non-zero) entry.
    """
    has_entries = x.nnz > 0
    # promote the scalar to 2d so it can serve as a block of a 2d array
    return np.atleast_2d(has_entries)
def update_spatial_perpx(
    y: np.ndarray,
    alpha: float,
    sub: sparse.COO,
    C_store: Union[np.ndarray, zarr.core.Array],
    f: Optional[np.ndarray],
) -> sparse.COO:
    """
    Fit the spatial footprint values of all candidate cells on one pixel.

    The fit is a non-negative :class:`sklearn.linear_model.LassoLars`
    regression of the pixel trace `y` onto the temporal traces of the cells
    selected by `sub`. `C_store` may be an in-memory numpy array or a zarr
    array; in the latter case only the rows that are needed are read. When `f`
    is given, the last element of `sub` is treated as the subsetting flag of
    the background and the last element of the result is the background
    footprint value.

    Parameters
    ----------
    y : np.ndarray
        Input fluorescent trace for the given pixel.
    alpha : float
        Parameter of the optimization problem controlling sparsity.
    sub : sparse.COO
        Subsetting vector for this pixel.
    C_store : Union[np.ndarray, zarr.core.Array]
        Estimation of temporal dynamics of cells.
    f : np.ndarray, optional
        Temporal dynamic of background.

    Returns
    -------
    A_px : sparse.COO
        Spatial footprint values across all cells for the given pixel. Only
        strictly positive coefficients are stored.

    See Also
    --------
    update_spatial : for more explanation of parameters
    """
    with_background = f is not None
    # the trailing entry of `sub` belongs to the background, not to a cell
    cell_sub = sub[:-1] if with_background else sub
    idx = cell_sub.nonzero()[0]
    try:
        # zarr arrays expose orthogonal selection for row subsetting
        C = C_store.get_orthogonal_selection((idx, slice(None))).T
    except AttributeError:
        # plain numpy arrays are indexed directly
        C = C_store[idx, :].T
    if with_background and sub[-1]:
        # add the background trace as one more regressor, mapped to the last slot
        bg_column = f.reshape((-1, 1))
        C = np.concatenate([C, bg_column], axis=1)
        idx = np.concatenate([idx, np.array([len(sub) - 1])])
    model = LassoLars(alpha=alpha, positive=True)
    coef = model.fit(C, y).coef_
    keep = coef > 0
    return sparse.COO(coords=idx[keep], data=coef[keep], shape=sub.shape)
@darr.as_gufunc(signature="(f),(),(u)->(u)", output_dtypes=float)
def update_spatial_block(
    y: np.ndarray, alpha: np.ndarray, sub: sparse.COO, **kwargs
) -> sparse.COO:
    """
    Run the per-pixel spatial update over a 3d block of data.

    Every pixel of the block that has at least one candidate unit in `sub` is
    passed to :func:`update_spatial_perpx`, and the per-pixel results are
    assembled into one sparse array. The keyword arguments `C_store` and `f`
    are forwarded to :func:`update_spatial_perpx`.

    Parameters
    ----------
    y : np.ndarray
        Input data, should have dimension (height, width, frame).
    alpha : np.ndarray
        Alpha parameter for the optimization problem. Should have dimension
        (height, width).
    sub : sparse.COO
        Subsetting matrix. Should have dimension (height, width, unit_id).

    Returns
    -------
    A_blk : sparse.COO
        Resulting spatial footprints. Should have dimension (height, width,
        unit_id).

    See Also
    --------
    update_spatial_perpx
    update_spatial
    """
    C_store = kwargs.get("C_store")
    f = kwargs.get("f")
    coord_chunks = []
    value_chunks = []
    # visit only the pixels where some unit is a candidate
    active_pixels = zip(*sub.any(axis=-1).nonzero())
    for h, w in active_pixels:
        px_res = update_spatial_perpx(
            y[h, w, :], alpha[h, w], sub[h, w, :], C_store, f
        )
        unit_crd = px_res.coords
        # prepend the pixel position to the unit coordinates of this pixel
        full_crd = np.concatenate(
            [np.full_like(unit_crd, h), np.full_like(unit_crd, w), unit_crd], axis=0
        )
        coord_chunks.append(full_crd)
        value_chunks.append(px_res.data)
    if not value_chunks:
        # no pixel in this block had any candidate unit
        return sparse.zeros(sub.shape)
    return sparse.COO(
        coords=np.concatenate(coord_chunks, axis=1),
        data=np.concatenate(value_chunks),
        shape=sub.shape,
    )
def compute_trace(
    Y: xr.DataArray, A: xr.DataArray, b: xr.DataArray, C: xr.DataArray, f: xr.DataArray
) -> xr.DataArray:
    """
    Compute the residual traces `YrA` for each cell.

    `YrA` is computed as `C + A_norm(YtA - CtA)`, clipped at zero. `YtA` is
    the projection of the background-subtracted movie `Y - f * b` onto the
    spatial footprints, and `CtA` is `C` multiplied by `AtA`, the pairwise
    inner products of spatial footprints, so that `CtA` represents for each
    cell the temporal activity it shares with overlapping cells. `A_norm` is
    a "unit_id"x"unit_id" diagonal matrix that normalizes both terms by the
    sum of squares of each spatial footprint.

    Parameters
    ----------
    Y : xr.DataArray
        Input movie data. Should have dimensions ("frame", "height", "width").
    A : xr.DataArray
        Spatial footprints of cells. Should have dimensions ("unit_id", "height",
        "width").
    b : xr.DataArray
        Spatial footprint of background. Should have dimensions ("height", "width").
        Missing values are treated as zero.
    C : xr.DataArray
        Temporal components of cells. Should have dimensions "frame" and
        "unit_id". The underlying array is transposed before use, so it is
        presumably stored as ("unit_id", "frame") -- TODO confirm.
    f : xr.DataArray
        Temporal dynamic of background. Should have dimension "frame". Missing
        values are treated as zero.

    Returns
    -------
    YrA : xr.DataArray
        Residual traces for each cell, with dimensions ("unit_id", "frame").
    """
    frame_crd = Y.coords["frame"]
    unit_crd = A.coords["unit_id"]
    Y_arr = Y.data
    # footprints are materialized once as a single sparse chunk
    A_sps = darr.from_array(A.data.map_blocks(sparse.COO).compute(), chunks=-1)
    C_sps = C.data.map_blocks(sparse.COO).T
    b_filled = b.fillna(0)
    b_sps = (
        b_filled.data.map_blocks(sparse.COO)
        .reshape((1, Y_arr.shape[1], Y_arr.shape[2]))
        .compute()
    )
    f_col = f.fillna(0).data.reshape((-1, 1))
    # pairwise inner products of footprints
    AtA = darr.tensordot(A_sps, A_sps, axes=[(1, 2), (1, 2)]).compute()
    # diagonal matrix holding 1 / sum-of-squares of each footprint
    inv_sq_norm = 1 / (A_sps ** 2).sum(axis=(1, 2))
    A_norm = inv_sq_norm.map_blocks(
        lambda a: sparse.diagonalize(sparse.COO(a)),
        chunks=(A_sps.shape[0], A_sps.shape[0]),
    ).compute()
    # background movie, subtracted from the data before projection
    background = darr.tensordot(f_col, b_sps, axes=[1, 0])
    Y_nobg = Y_arr - background
    YtA = darr.dot(darr.tensordot(Y_nobg, A_sps, axes=[(1, 2), (1, 2)]), A_norm)
    CtA = darr.dot(darr.dot(C_sps, AtA), A_norm)
    trace = (YtA - CtA + C_sps).clip(0)
    optimizer = fct.partial(
        custom_arr_optimize,
        inline_patterns=["from-getitem-transpose"],
        rename_dict={"tensordot": "tensordot_restricted"},
    )
    with da.config.set(array_optimize=optimizer):
        trace = da.optimize(trace)[0]
    YrA = xr.DataArray(
        trace,
        dims=["frame", "unit_id"],
        coords={"frame": frame_crd, "unit_id": unit_crd},
    )
    return YrA.transpose("unit_id", "frame")
def update_temporal(
    A: xr.DataArray,
    C: xr.DataArray,
    b: Optional[xr.DataArray] = None,
    f: Optional[xr.DataArray] = None,
    Y: Optional[xr.DataArray] = None,
    YrA: Optional[xr.DataArray] = None,
    noise_freq=0.25,
    p=2,
    add_lag="p",
    jac_thres=0.1,
    sparse_penal=1,
    bseg: Optional[np.ndarray] = None,
    med_wd: Optional[int] = None,
    zero_thres=1e-8,
    max_iters=200,
    use_smooth=True,
    normalize=True,
    warm_start=False,
    post_scal=False,
    scs_fallback=False,
    concurrent_update=False,
) -> Tuple[
    xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray
]:
    """
    Update temporal components and deconvolve calcium traces for each cell given
    spatial footprints.

    This function carries out temporal update of the CNMF algorithm. The update
    is done in parallel and independently for each group of cells. The grouping
    of cells is controlled by `jac_thres`. The relationship between calcium and
    deconvolved spikes is modeled as an Autoregressive process (AR) of order
    `p`. The AR coefficients are estimated from autocovariances of `YrA` traces
    for each cell, with `add_lag` controls how many timesteps of autocovariances
    are used. Optionally, the `YrA` traces can be smoothed for the estimation of
    AR coefficients only. The noise level for each cell is estimated using FFT
    with `noise_freq` as cut-off, and controls the sparsity of the result
    together with the global `sparse_penal` parameter. `YrA` traces for each
    cells can be optionally normalized to unit sum to make `sparse_penal` to
    have comparable effects across cells. If abrupt change of baseline
    fluorescence is expected, a `bseg` vector can be passed to enable estimation
    of independent baseline for different segments of time. The temporal update
    itself is performed by solving an optimization problem using `cvxpy`, with
    `concurrent_update`, `warm_start`, `max_iters`, `scs_fallback` controlling
    different aspects of the optimization. Finally, the results can be filtered
    with `zero_thres` to suppress small values caused by numerical errors, and a
    post-hoc scaling process can be optionally used to scale the result based on
    `YrA` to get around unwanted effects from sparse penalty or normalization.

    Intermediate results are written to, and read back from, the folder given
    by the environment variable `MINIAN_INTERMEDIATE`, which must be set.

    Parameters
    ----------
    A : xr.DataArray
        Estimation of spatial footprints for each cell. Should have dimensions
        ("unit_id", "height", "width").
    C : xr.DataArray
        Previous estimation of calcium dynamic of cells. Should have dimensions
        "frame" and "unit_id". Only used if `warm_start = True` or if `YrA is
        None`.
    b : xr.DataArray, optional
        Estimation of spatial footprint of background. Should have dimensions
        ("height", "width"). Only used if `YrA is None`. By default `None`.
    f : xr.DataArray, optional
        Estimation of temporal dynamic of background. Should have dimension
        "frame". Only used if `YrA is None`. By default `None`.
    Y : xr.DataArray, optional
        Input movie data. Should have dimensions ("frame", "height", "width").
        Only used if `YrA is None`. By default `None`.
    YrA : xr.DataArray, optional
        Estimation of residual traces for each cell. Should have dimensions
        ("unit_id", "frame"); the first axis is treated as the cell axis. If
        `None` then one will be computed using `compute_trace` with relevant
        inputs. By default `None`.
    noise_freq : float, optional
        Frequency cut-off for both the estimation of noise level and the
        optional smoothing, specified as a fraction of sampling frequency. By
        default `0.25`.
    p : int, optional
        Order of the AR process. By default `2`.
    add_lag : int or str, optional
        Additional number of timesteps in covariance to use for the estimation
        of AR coefficients. If `0`, then only the first `p` number of timesteps
        will be used to estimate the `p` number of AR coefficients. If greater
        than `0`, then the system is over-determined and least square will be
        used to estimate AR coefficients. If `"p"`, then `p` number of
        additional timesteps will be used. By default `"p"`.
    jac_thres : float, optional
        Threshold for Jaccard Index. Cells whose overlap in spatial footprints
        (number of common pixels divided by number of total pixels) exceeding
        this threshold will be grouped together transitively for temporal
        update. By default `0.1`.
    sparse_penal : int, optional
        Global scalar controlling sparsity of the result. The higher the value,
        the sparser the deconvolved spikes. By default `1`.
    bseg : np.ndarray, optional
        1d vector with length "frame" representing segments for which baseline
        should be estimated independently. An independent baseline will be
        estimated for frames corresponding to each unique label in this vector.
        If `None` then a single scalar baseline will be estimated for each cell.
        By default `None`.
    med_wd : int, optional
        Window size for the median filter used for baseline correction. For each
        cell, the baseline fluorescence is estimated by median-filtering the
        temporal activity. Then the baseline is subtracted from the temporal
        activity right before the optimization step. If `None` then no baseline
        correction will be performed. By default `None`.
    zero_thres : float, optional
        Threshold to filter out small values in the result. Any values smaller
        than this threshold will be set to zero. By default `1e-8`.
    max_iters : int, optional
        Maximum number of iterations for optimization. Can be increased to get
        around sub-optimal results. By default `200`.
    use_smooth : bool, optional
        Whether to smooth the `YrA` for the estimation of AR coefficients. If
        `True`, then a smoothed version of `YrA` will be computed by low-pass
        filter with `noise_freq` and used for the estimation of AR coefficients
        only. By default `True`.
    normalize : bool, optional
        Whether to normalize `YrA` for each cell to unit sum such that sparse
        penalty has similar effect for all the cells. Each group of cell will be
        normalized together (with mean of the sum for each cell) to preserve
        relative amplitude of fluorescence between overlapping cells. By default
        `True`.
    warm_start : bool, optional
        Whether to use previous estimation of `C` to warm start the
        optimization. Can lead to faster convergence in theory. Experimental. By
        default `False`.
    post_scal : bool, optional
        Whether to apply the post-hoc scaling process, where a scalar will be
        estimated with least square for each cell to scale the amplitude of
        temporal component to `YrA`. Useful to get around unwanted dampening of
        result values caused by high `sparse_penal` or to revert the per-cell
        normalization. By default `False`.
    scs_fallback : bool, optional
        Whether to fall back to `scs` solver if the default `ecos` solver fails.
        By default `False`.
    concurrent_update : bool, optional
        Whether to update a group of cells as a single optimization problem.
        Yields slightly more accurate estimation when cross-talk between cells
        are severe, but significantly increase convergence time and memory
        demand. By default `False`.

    Returns
    -------
    C_new : xr.DataArray
        New estimation of the calcium dynamic for each cell. Has dimensions
        ("unit_id", "frame"); the "unit_id" dimension might be smaller than
        that of `C` due to dropping of cells and filtering.
    S_new : xr.DataArray
        New estimation of the deconvolved spikes for each cell. Has
        dimensions ("unit_id", "frame") and same shape as `C_new`.
    b0_new : xr.DataArray
        New estimation of baseline fluorescence for each cell. Has
        dimensions ("unit_id", "frame") and same shape as `C_new`. Each cell
        should only have one unique value if `bseg is None`.
    c0_new : xr.DataArray
        New estimation of a initial calcium decay, in theory triggered by
        calcium events happened before the recording starts. Has
        dimensions ("unit_id", "frame") and same shape as `C_new`.
    g : xr.DataArray
        Estimation of AR coefficient for each cell. Useful for visualizing
        modeled calcium dynamic. Has dimensions ("unit_id", "lag") with
        "lag" having length `p`.
    mask : xr.DataArray
        Boolean mask of whether a cell has any deconvolved activity after the
        update. Has dimension "unit_id" covering the cells that entered the
        update, i.e. those with any positive value in `YrA`. Useful for
        subsetting other variables based on the result of temporal update.

    Notes
    -------
    During temporal update, the algorithm solve the following optimization
    problem for each cell:

    .. math::
        \\begin{aligned}
        & \\underset{\mathbf{c} \, \mathbf{b_0} \,
        \mathbf{c_0}}{\\text{minimize}}
        & & \\left \\lVert \mathbf{y} - \mathbf{c} - \mathbf{c_0} -
        \mathbf{b_0} \\right \\rVert ^2 + \\alpha \\left \\lvert \mathbf{G}
        \mathbf{c} \\right \\rvert \\\\
        & \\text{subject to}
        & & \mathbf{c} \geq 0, \; \mathbf{G} \mathbf{c} \geq 0
        \\end{aligned}

    Where :math:`\mathbf{y}` is the estimated residual trace (`YrA`) for the
    cell, :math:`\mathbf{c}` is the calcium dynamic of the cell,
    :math:`\mathbf{G}` is a "frame"x"frame" matrix constructed from the
    estimated AR coefficients of cell, such that the deconvolved spikes of the
    cell is given by :math:`\mathbf{G}\mathbf{c}`. If `bseg is None`, then
    :math:`\mathbf{b_0}` is a single scalar, otherwise it is a 1d vector with
    dimension "frame" constrained to have multiple independent values, each
    corresponding to a segment of time specified in `bseg`. :math:`\mathbf{c_0}`
    is a 1d vector with dimension "frame" constrained to be the product of a
    scalar (representing initial calcium concentration) and the decay dynamic
    given by the estimated AR coefficients. The parameter :math:`\\alpha` is the
    product of estimated noise level of the cell and the global scalar
    `sparse_penal`. Higher value of :math:`\\alpha` will result in more sparse
    estimation of deconvolved spikes.
    """
    # raises KeyError when the environment variable is not set
    intpath = os.environ["MINIAN_INTERMEDIATE"]
    if YrA is None:
        YrA = compute_trace(Y, A, b, C, f).persist()
    # drop cells whose residual trace is never positive before doing any work
    Ymask = (YrA > 0).any("frame").compute()
    A, C, YrA = A.sel(unit_id=Ymask), C.sel(unit_id=Ymask), YrA.sel(unit_id=Ymask)
    print("grouping overlaping units")
    # binarized footprints, used to compute pairwise Jaccard indices
    A_sps = (A.data.map_blocks(sparse.COO) > 0).compute().astype(np.float32)
    # pairwise number of shared pixels (intersection)
    A_inter = sparse.tensordot(A_sps, A_sps, axes=[(1, 2), (1, 2)])
    # pairwise sum of footprint sizes; minus the intersection gives the union
    A_usum = np.tile(A_sps.sum(axis=(1, 2)).todense(), (A_sps.shape[0], 1))
    A_usum = A_usum + A_usum.T
    jac = scipy.sparse.csc_matrix(A_inter / (A_usum - A_inter) > jac_thres)
    # one group label per unit, derived from the thresholded overlap matrix
    unit_labels = label_connected(jac)
    YrA = YrA.assign_coords(unit_labels=("unit_id", unit_labels))
    print("updating temporal components")
    # per-group results, concatenated along "unit_id" after the loop
    c_ls = []
    s_ls = []
    b_ls = []
    c0_ls = []
    g_ls = []
    uid_ls = []
    grp_dim = "unit_labels"
    C = C.assign_coords(unit_labels=("unit_id", unit_labels))
    if warm_start:
        C.data = C.data.map_blocks(scipy.sparse.csr_matrix)
    inline_opt = fct.partial(
        custom_delay_optimize,
        inline_patterns=["getitem", "rechunk-merge"],
    )
    for cur_YrA, cur_C in zip(YrA.groupby(grp_dim), C.groupby(grp_dim)):
        # groupby yields (label, data) pairs; remember which units are in
        # this group so coordinates can be rebuilt in the same order later
        uid_ls.append(cur_YrA[1].coords["unit_id"].values.reshape(-1))
        cur_YrA, cur_C = cur_YrA[1].data.rechunk(-1), cur_C[1].data.rechunk(-1)
        # peak memory demand for cvxpy is roughly 500 times input
        mem_cvx = cur_YrA.nbytes if concurrent_update else cur_YrA[0].nbytes
        mem_cvx = mem_cvx * 500
        mem_demand = max(mem_cvx, cur_YrA.nbytes * 5) / 1e6
        # issue a warning if expected memory demand is larger than 1G
        if mem_demand > 1e3:
            warnings.warn(
                "{} cells will be updated togeter, "
                "which takes roughly {} MB of memory. "
                "Consider merging the units "
                "or changing jac_thres".format(cur_YrA.shape[0], mem_demand)
            )
        if not warm_start:
            cur_C = None
        # groups with more than one cell use the inlining optimizer
        if cur_YrA.shape[0] > 1:
            dl_opt = inline_opt
        else:
            dl_opt = custom_delay_optimize
        # explicitly using delay (rather than gufunc) seem to promote the
        # depth-first behavior of dask
        with da.config.set(delayed_optimize=dl_opt):
            res = da.optimize(
                da.delayed(update_temporal_block)(
                    cur_YrA,
                    noise_freq=noise_freq,
                    p=p,
                    add_lag=add_lag,
                    normalize=normalize,
                    concurrent=concurrent_update,
                    use_smooth=use_smooth,
                    c_last=cur_C,
                    bseg=bseg,
                    med_wd=med_wd,
                    sparse_penal=sparse_penal,
                    max_iters=max_iters,
                    scs_fallback=scs_fallback,
                    zero_thres=zero_thres,
                )
            )[0]
        # the delayed result is indexed as (c, s, b, c0, g); wrap each element
        # as a lazy dask array with the shape it is expected to have
        c_ls.append(darr.from_delayed(res[0], shape=cur_YrA.shape, dtype=cur_YrA.dtype))
        s_ls.append(darr.from_delayed(res[1], shape=cur_YrA.shape, dtype=cur_YrA.dtype))
        b_ls.append(darr.from_delayed(res[2], shape=cur_YrA.shape, dtype=cur_YrA.dtype))
        c0_ls.append(
            darr.from_delayed(res[3], shape=cur_YrA.shape, dtype=cur_YrA.dtype)
        )
        g_ls.append(
            darr.from_delayed(res[4], shape=(cur_YrA.shape[0], p), dtype=cur_YrA.dtype)
        )
    # unit ids in the order the groups were processed
    uids_new = np.concatenate(uid_ls)
    C_new = xr.DataArray(
        darr.concatenate(c_ls, axis=0),
        dims=["unit_id", "frame"],
        coords={
            "unit_id": uids_new,
            "frame": YrA.coords["frame"],
        },
        name="C_new",
    )
    S_new = xr.DataArray(
        darr.concatenate(s_ls, axis=0),
        dims=["unit_id", "frame"],
        coords={
            "unit_id": uids_new,
            "frame": YrA.coords["frame"].values,
        },
        name="S_new",
    )
    b0_new = xr.DataArray(
        darr.concatenate(b_ls, axis=0),
        dims=["unit_id", "frame"],
        coords={
            "unit_id": uids_new,
            "frame": YrA.coords["frame"].values,
        },
        name="b0_new",
    )
    c0_new = xr.DataArray(
        darr.concatenate(c0_ls, axis=0),
        dims=["unit_id", "frame"],
        coords={
            "unit_id": uids_new,
            "frame": YrA.coords["frame"].values,
        },
        name="c0_new",
    )
    g = xr.DataArray(
        darr.concatenate(g_ls, axis=0),
        dims=["unit_id", "lag"],
        coords={"unit_id": uids_new, "lag": np.arange(p)},
        name="g",
    )
    # compute all outputs in a single pass, saving them to the intermediate
    # folder, then reopen them from disk
    arr_opt = fct.partial(custom_arr_optimize, keep_patterns=["^update_temporal_block"])
    with da.config.set(array_optimize=arr_opt):
        da.compute(
            [
                save_minian(
                    var.chunk({"unit_id": 1}), intpath, compute=False, overwrite=True
                )
                for var in [C_new, S_new, b0_new, c0_new, g]
            ]
        )
    int_ds = open_minian(intpath, return_dict=True)
    C_new, S_new, b0_new, c0_new, g = (
        int_ds["C_new"],
        int_ds["S_new"],
        int_ds["b0_new"],
        int_ds["c0_new"],
        int_ds["g"],
    )
    # keep only cells with some deconvolved activity
    mask = (S_new.sum("frame") > 0).compute()
    # NOTE(review): the numerator counts cells dropped by `mask` only, while
    # the denominator `len(Ymask)` counts all input cells, including those
    # already removed by `Ymask` -- confirm this is the intended report.
    print("{} out of {} units dropped".format((~mask).sum().values, len(Ymask)))
    C_new, S_new, b0_new, c0_new, g = (
        C_new[mask],
        S_new[mask],
        b0_new[mask],
        c0_new[mask],
        g[mask],
    )
    sig_new = C_new + b0_new + c0_new
    sig_new = da.optimize(sig_new)[0]
    # NOTE(review): `mask` follows the unit order of `uids_new` (group order),
    # whereas `YrA` keeps its original unit order; verify that the two orders
    # agree, since rows are paired by position in the scaling below.
    YrA_new = YrA.sel(unit_id=mask)
    if post_scal and len(sig_new) > 0:
        print("post-hoc scaling")
        # one least-square scalar per cell, reshaped to broadcast over frames
        scal = lstsq_vec(sig_new.data, YrA_new.data).compute().reshape((-1, 1))
        C_new, S_new, b0_new, c0_new = (
            C_new * scal,
            S_new * scal,
            b0_new * scal,
            c0_new * scal,
        )
    return C_new, S_new, b0_new, c0_new, g, mask
@darr.as_gufunc(signature="(f),(f)->()", output_dtypes=float)
def lstsq_vec(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """
    Find the least-square scalar that maps `a` onto `b`, in vectorized fashion.

    Parameters
    ----------
    a : np.ndarray
        Source of the scaling.
    b : np.ndarray
        Target of the scaling.

    Returns
    -------
    scale : np.ndarray
        A scalar that scales `a` to `b`.
    """
    # treat `a` as a single-column design matrix
    design = a.reshape((-1, 1))
    target = b.squeeze()
    solution = np.linalg.lstsq(design, target, rcond=-1)
    # first element of the lstsq result is the fitted coefficient
    return solution[0]
def get_ar_coef(
    y: np.ndarray,
    sn: float,
    p: int,
    add_lag: Union[int, str],
    pad: Optional[int] = None,
) -> np.ndarray:
    """
    Estimate Autoregressive coefficients of order `p` given a timeseries `y`.

    The coefficients are the least-square solution of the noise-corrected
    autocovariance equations built from the first `max_lag` lags of `y`.

    Parameters
    ----------
    y : np.ndarray
        Input timeseries.
    sn : float
        Estimated noise level of the input `y`. Its square is subtracted from
        the diagonal of the autocovariance matrix.
    p : int
        Order of the autoregressive process.
    add_lag : int or str
        Additional number of timesteps of covariance to use for the
        estimation. The string `"p"` means "use `p` additional lags", i.e. a
        total of `2 * p` lags.
    pad : int, optional
        Length of the output. If truthy then the resulting coefficients will
        be zero-padded to this length. By default `None`.

    Returns
    -------
    g : np.ndarray
        The estimated AR coefficients.
    """
    # "p" is a sentinel requesting as many additional lags as the AR order
    max_lag = p * 2 if add_lag == "p" else p + add_lag
    cov = acovf(y, fft=True)
    # remove the noise variance from the lag-0 terms before solving
    C_mat = toeplitz(cov[:max_lag], cov[:p]) - sn ** 2 * np.eye(max_lag, p)
    g = lstsq(C_mat, cov[1 : max_lag + 1])[0]
    if not pad:
        return g
    res = np.zeros(pad)
    res[: len(g)] = g
    return res
def get_p(y: np.ndarray) -> int:
    """
    Count the samples in the rising period of `y` with the largest extent.

    A rising period is a run of consecutive samples whose next sample is
    strictly larger. The extent of a period is the difference between the
    values at its last and first flagged samples.

    Parameters
    ----------
    y : np.ndarray
        Input 1d timeseries.

    Returns
    -------
    p : int
        Number of samples in the rising period with the largest extent, or `0`
        if `y` never rises.
    """
    y = np.asarray(y)
    # pad with 0 so the mask has the same length as `y`
    rising = np.append(np.diff(y), 0) > 0
    prd_ris, num_ris = label(rising)
    if num_ris == 0:
        # no rising period at all: np.argmax would fail on an empty array
        return 0
    ext_prd = np.zeros(num_ris)
    for id_prd in range(num_ris):
        prd = y[prd_ris == id_prd + 1]
        ext_prd[id_prd] = prd[-1] - prd[0]
    id_max_prd = np.argmax(ext_prd)
    return int(np.sum(rising[prd_ris == id_max_prd + 1]))
def update_temporal_block(
    YrA: np.ndarray,
    noise_freq: float,
    p: int,
    add_lag="p",
    normalize=True,
    use_smooth=True,
    med_wd=None,
    concurrent=False,
    **kwargs
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Update temporal components given residule traces of a group of cells.

    This function wraps around :func:`update_temporal_cvxpy`, but also carry
    out additional initial steps given `YrA` of a group of cells. Additional
    keyword arguments are passed through to :func:`update_temporal_cvxpy`.
    The input array `YrA` is not modified.

    Parameters
    ----------
    YrA : np.ndarray
        Residule traces of a group of cells. Should have dimension ("unit_id",
        "frame").
    noise_freq : float
        Frequency cut-off for both the estimation of noise level and the
        optional smoothing. Specified as a fraction of sampling frequency.
    p : int
        Order of the AR process.
    add_lag : str, optional
        Additional number of timesteps in covariance to use for the estimation
        of AR coefficients. By default "p".
    normalize : bool, optional
        Whether to rescale `YrA` by a single factor (number of frames divided
        by the mean over cells of the per-cell sum) before the update. The
        returned `c`, `s`, `b` and `c0` are scaled back. By default `True`.
    use_smooth : bool, optional
        Whether to use a low-pass filtered copy of `YrA` for the estimation of
        AR coefficients. The traces passed to the solver are not smoothed. By
        default `True`.
    med_wd : int, optional
        Median window used for baseline correction. If `None` no baseline
        correction is done.
    concurrent : bool, optional
        Whether to update a group of cells as a single optimization problem. By
        default `False`.

    Returns
    -------
    c : np.ndarray
        New estimation of the calcium dynamic of the group of cells. Should have
        dimensions ("unit_id", "frame") and same shape as `YrA`.
    s : np.ndarray
        New estimation of the deconvolved spikes of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.
    b : np.ndarray
        New estimation of baseline fluorescence of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.
    c0 : np.ndarray
        New estimation of a initial calcium decay of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.
    g : np.ndarray
        Estimation of AR coefficient for each cell. Should have dimensions
        ("unit_id", "lag") with "lag" having length `p`.

    See Also
    -------
    update_temporal : for more explanation of parameters
    """
    vec_get_noise = np.vectorize(
        noise_fft,
        otypes=[float],
        excluded=["noise_range", "noise_method"],
        signature="(f)->()",
    )
    vec_get_ar_coef = np.vectorize(
        get_ar_coef,
        otypes=[float],
        excluded=["pad", "add_lag"],
        signature="(f),(),()->(l)",
    )
    if normalize:
        amean = YrA.sum(axis=1).mean()
        norm_factor = YrA.shape[1] / amean
        # out-of-place so the caller's array keeps its original scale
        YrA = YrA * norm_factor
    else:
        # scalar so that it broadcasts against the (unit_id, frame) results
        norm_factor = 1.0
    tn = vec_get_noise(YrA, noise_range=(noise_freq, 1))
    if use_smooth:
        # filt_fft_vec may write into its argument: hand it a copy so the
        # smoothing only affects the AR estimation, not the solver input
        YrA_ar = filt_fft_vec(YrA.copy(), noise_freq, "low")
        tn_ar = vec_get_noise(YrA_ar, noise_range=(noise_freq, 1))
    else:
        YrA_ar, tn_ar = YrA, tn
    # auto estimation of p is disabled since it's never used and makes it
    # impossible to pre-determine the shape of output
    pmax = np.max(p)
    g = vec_get_ar_coef(YrA_ar, tn_ar, p, pad=pmax, add_lag=add_lag)
    del YrA_ar, tn_ar
    if med_wd is not None:
        # write the baseline-corrected traces into a new buffer of the same
        # dtype instead of overwriting rows of the input
        YrA_bl = np.empty_like(YrA)
        for i, cur_yra in enumerate(YrA):
            YrA_bl[i, :] = med_baseline(cur_yra, med_wd)
        YrA = YrA_bl
    if concurrent:
        c, s, b, c0 = update_temporal_cvxpy(YrA, g, tn, **kwargs)
    else:
        res_ls = [
            update_temporal_cvxpy(cur_yra, cur_g, cur_tn, **kwargs)
            for cur_yra, cur_g, cur_tn in zip(YrA, g, tn)
        ]
        c, s, b, c0 = (
            np.concatenate([r[i] for r in res_ls], axis=0) for i in range(4)
        )
    # undo the normalization for both the concurrent and the per-cell path
    c, s, b, c0 = (
        c / norm_factor,
        s / norm_factor,
        b / norm_factor,
        c0 / norm_factor,
    )
    return c, s, b, c0, g
def update_temporal_cvxpy(
    y: np.ndarray, g: np.ndarray, sn: np.ndarray, A=None, bseg=None, **kwargs
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Solve the temporal update optimization problem using `cvxpy`

    A noise-constrained problem is only solved when `use_cons` is set. If it is
    skipped or does not reach an optimal status, a penalized (unconstrained)
    problem is solved with `ECOS`, optionally falling back to `SCS`. If every
    attempt fails, a warning is issued and zeros are returned.

    Parameters
    ----------
    y : np.ndarray
        Input residule trace of one or more cells.
    g : np.ndarray
        Estimated AR coefficients of one or more cells.
    sn : np.ndarray
        Noise level of one or more cells.
    A : np.ndarray, optional
        Spatial footprint of one or more cells. Not used. By default `None`.
    bseg : np.ndarray, optional
        1d vector with length "frame" representing segments for which baseline
        should be estimated independently. By default `None`.

    Returns
    -------
    c : np.ndarray
        New estimation of the calcium dynamic of the group of cells. Should have
        dimensions ("unit_id", "frame") and same shape as `y`.
    s : np.ndarray
        New estimation of the deconvolved spikes of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.
    b : np.ndarray
        New estimation of baseline fluorescence of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.
    c0 : np.ndarray
        New estimation of a initial calcium decay of the group of cells. Should
        have dimensions ("unit_id", "frame") and same shape as `c`.

    Other Parameters
    ----------------
    sparse_penal : float
        Sparse penalty parameter for all the cells. No default is applied.
    max_iters : int
        Maximum number of iterations, passed to the `ECOS` solve of the
        penalized problem.
    use_cons : bool, optional
        Whether to try constrained version of the problem first. By default
        `False`.
    scs_fallback : bool
        Whether to fall back to `scs` solver if the default `ecos` solver fails.
    c_last : np.ndarray, optional
        Initial estimation of calcium traces for each cell used to warm start.
    zero_thres : float
        Threshold to filter out small values in the result. Values not greater
        than it are set to zero. No default is applied.

    See Also
    -------
    update_temporal : for more explanation of parameters
    """
    # spatial:
    # (d, f), (u, p), (d), (d, u)
    # (d, f), (p), (d), (d)
    # trace:
    # (u, f), (u, p), (u)
    # (f), (p), ()
    # get_parameters
    sparse_penal = kwargs.get("sparse_penal")
    max_iters = kwargs.get("max_iters")
    use_cons = kwargs.get("use_cons", False)
    scs = kwargs.get("scs_fallback")
    c_last = kwargs.get("c_last")
    zero_thres = kwargs.get("zero_thres")
    # conform variables to generalize multiple unit case
    if y.ndim < 2:
        y = y.reshape((1, -1))
    if g.ndim < 2:
        g = g.reshape((1, -1))
    sn = np.atleast_1d(sn)
    if A is not None:
        if A.ndim < 2:
            A = A.reshape((-1, 1))
    # get count of frames and units
    _T = y.shape[-1]
    _u = g.shape[0]
    if A is not None:
        _d = A.shape[0]
    # construct G matrix and decay vector per unit
    dc_vec = np.zeros((_u, _T))
    G_ls = []
    for cur_u in range(_u):
        cur_g = g[cur_u, :]
        # construct first column and row
        cur_c = np.zeros(_T)
        cur_c[0] = 1
        cur_c[1 : len(cur_g) + 1] = -cur_g
        # update G with toeplitz matrix
        G_ls.append(
            cvx.Constant(
                dia_matrix(
                    (
                        np.tile(np.concatenate(([1], -cur_g)), (_T, 1)).T,
                        -np.arange(len(cur_g) + 1),
                    ),
                    shape=(_T, _T),
                ).tocsc()
            )
        )
        # update dc_vec: powers of the largest real part among the roots of
        # the AR polynomial
        cur_gr = np.roots(cur_c)
        dc_vec[cur_u, :] = np.max(cur_gr.real) ** np.arange(_T)
    # get noise threshold
    thres_sn = sn * np.sqrt(_T)
    # construct variables
    if bseg is not None:
        nseg = int(np.max(bseg) + 1)
        b_temp = np.zeros((nseg, _T))
        for iseg in range(nseg):
            b_temp[iseg, bseg == iseg] = 1
        b_cmp = cvx.Variable((_u, nseg))
    else:
        b_temp = np.ones((1, _T))
        b_cmp = cvx.Variable((_u, 1))
    b = b_cmp @ b_temp  # baseline fluorescence per unit
    c0 = cvx.Variable(_u)  # initial fluorescence per unit
    c = cvx.Variable((_u, _T))  # calcium trace per unit
    if c_last is not None:
        c.value = c_last
        warm_start = True
    else:
        warm_start = False
    s = cvx.vstack([G_ls[u] @ c[u, :] for u in range(_u)])  # spike train per unit
    # residual noise per unit
    if A is not None:
        sig = cvx.vstack(
            [
                (A * c)[px, :] + (A * b)[px, :] + (A * cvx.diag(c0) * dc_vec)[px, :]
                for px in range(_d)
            ]
        )
        noise = y - sig
    else:
        sig = cvx.vstack([c[u, :] + b[u, :] + c0[u] * dc_vec[u, :] for u in range(_u)])
        noise = y - sig
    noise = cvx.vstack([cvx.norm(noise[i, :], 2) for i in range(noise.shape[0])])
    # construct constraints
    cons = []
    cons.append(
        b >= np.broadcast_to(np.min(y, axis=-1).reshape((-1, 1)), y.shape)
    )  # baseline larger than minimum
    cons.append(c0 >= 0)  # initial fluorescence larger than 0
    cons.append(s >= 0)  # spike train non-negativity
    # noise constraints
    cons_noise = [noise[i] <= thres_sn[i] for i in range(thres_sn.shape[0])]
    try:
        obj = cvx.Minimize(cvx.sum(cvx.norm(s, 1, axis=1)))
        prob = cvx.Problem(obj, cons + cons_noise)
        if use_cons:
            _ = prob.solve(solver="ECOS")
        # when `use_cons` is off the problem was never solved, so the status
        # check below always routes to the penalized problem
        if not (prob.status == "optimal" or prob.status == "optimal_inaccurate"):
            if use_cons:
                warnings.warn("constrained version of problem infeasible")
            raise ValueError
    except (ValueError, cvx.SolverError):
        lam = sn * sparse_penal
        obj = cvx.Minimize(
            cvx.sum(cvx.sum(noise, axis=1) + cvx.multiply(lam, cvx.norm(s, 1, axis=1)))
        )
        prob = cvx.Problem(obj, cons)
        try:
            _ = prob.solve(solver="ECOS", warm_start=warm_start, max_iters=max_iters)
            if prob.status in ["infeasible", "unbounded", None]:
                raise ValueError
        except (cvx.SolverError, ValueError):
            try:
                if scs:
                    _ = prob.solve(solver="SCS", max_iters=200)
                # without the fallback the status of the failed ECOS attempt
                # is re-checked here and leads to the zero result
                if prob.status in ["infeasible", "unbounded", None]:
                    raise ValueError
            except (cvx.SolverError, ValueError):
                warnings.warn(
                    "problem status is {}, returning zero".format(prob.status),
                    RuntimeWarning,
                )
                # four independent arrays so callers can modify one result
                # without affecting the others
                return tuple(np.zeros(c.shape, dtype=float) for _ in range(4))
    if not (prob.status == "optimal"):
        warnings.warn("problem solved sub-optimally", RuntimeWarning)
    c = np.where(c.value > zero_thres, c.value, 0)
    s = np.where(s.value > zero_thres, s.value, 0)
    b = np.where(b.value > zero_thres, b.value, 0)
    c0 = c0.value.reshape((-1, 1)) * dc_vec
    c0 = np.where(c0 > zero_thres, c0, 0)
    return c, s, b, c0
def unit_merge(
    A: xr.DataArray,
    C: xr.DataArray,
    add_list: Optional[List[xr.DataArray]] = None,
    thres_corr=0.9,
    noise_freq: Optional[float] = None,
) -> Tuple[xr.DataArray, xr.DataArray, Optional[List[xr.DataArray]]]:
    """
    Merge cells given spatial footprints and temporal components

    This function merge all cells that have common pixels based on correlation
    of their temporal components. The cells to be merged will become one cell,
    with spatial and temporal components taken as mean across all the cells to
    be merged. Additionally any variables specified in `add_list` will be merged
    in the same manner. Optionally the temporal components can be smoothed
    before being used to caculate correlation. Despite the name any timeseries
    can be passed as `C` and used to calculate the correlation.

    Parameters
    ----------
    A : xr.DataArray
        Spatial footprints of the cells.
    C : xr.DataArray
        Temporal component of cells.
    add_list : List[xr.DataArray], optional
        List of additional variables to be merged. The list itself is left
        untouched. By default `None`.
    thres_corr : float, optional
        The threshold of correlation. Any pair of spatially overlapping cells
        with correlation higher than this threshold will be transitively grouped
        together and merged. By default `0.9`.
    noise_freq : float, optional
        The cut-off frequency used to smooth `C` before calculation of
        correlation. If `None` then no smoothing will be done. By default
        `None`.

    Returns
    -------
    A_merge : xr.DataArray
        Merged spatial footprints of cells.
    C_merge : xr.DataArray
        Merged temporal components of cells.
    add_list : List[xr.DataArray], optional
        New list of additional merged variables. Only returned if input
        `add_list` is a non-empty list.
    """
    print("computing spatial overlap")
    with da.config.set(
        array_optimize=darr.optimization.optimize,
        **{"optimization.fuse.subgraphs": False}
    ):
        A_sps = (A.data.map_blocks(sparse.COO) > 0).rechunk(-1).persist()
        # strictly lower triangle of the pairwise count of shared pixels
        A_inter = sparse.tril(
            darr.tensordot(
                A_sps.astype(np.float32),
                A_sps.astype(np.float32),
                axes=[(1, 2), (1, 2)],
            ).compute(),
            k=-1,
        )
    print("computing temporal correlation")
    nod_df = pd.DataFrame({"unit_id": A.coords["unit_id"].values})
    adj = adj_corr(C, A_inter, nod_df, freq=noise_freq)
    print("labeling units to be merged")
    adj = adj > thres_corr
    adj = adj + adj.T
    unit_labels = xr.apply_ufunc(
        label_connected,
        adj,
        input_core_dims=[["unit_id", "unit_id_cp"]],
        output_core_dims=[["unit_id"]],
    )
    print("merging units")

    def _merge_by_label(var: xr.DataArray) -> xr.DataArray:
        # average all units sharing a label; the label becomes the new unit_id
        return (
            var.assign_coords(unit_labels=("unit_id", unit_labels))
            .groupby("unit_labels")
            .mean("unit_id")
            .rename(unit_labels="unit_id")
        )

    A_merge = _merge_by_label(A)
    C_merge = _merge_by_label(C)
    if add_list:
        # build a new list rather than overwriting the caller's elements
        return A_merge, C_merge, [_merge_by_label(var) for var in add_list]
    return A_merge, C_merge
def label_connected(adj: np.ndarray, only_connected=False) -> np.ndarray:
    """
    Label connected components given adjacency matrix.

    Parameters
    ----------
    adj : np.ndarray
        Adjacency matrix. Should be 2d symmetric matrix. Either a dense array
        or a `scipy.sparse` matrix. It is not modified. For dense input only
        the strict upper triangle is used and any non-zero entry is an edge.
        For sparse input every stored entry is an edge.
    only_connected : bool, optional
        Whether to keep only the labels of connected components. If `True`, then
        all components with only one node (isolated) will have their labels set
        to -1. Otherwise all components will have unique label. By default
        `False`.

    Returns
    -------
    labels : np.ndarray
        The labels for each components. Should have length `adj.shape[0]`.
        Components are numbered from 0 in order of their lowest node index.
    """
    # function-scope import: only `scipy.sparse` itself is imported at file level
    from scipy.sparse.csgraph import connected_components

    n_nodes = adj.shape[0]
    if n_nodes == 0:
        return np.zeros(0, dtype=int)
    if scipy.sparse.issparse(adj):
        graph = scipy.sparse.csr_matrix(adj)
    else:
        # `np.triu(..., k=1)` returns a copy without diagonal and lower
        # triangle, so the caller's matrix is left as-is
        upper = np.triu(np.asarray(adj), k=1)
        graph = scipy.sparse.csr_matrix(upper != 0)
    _, labels = connected_components(graph, directed=False)
    labels = labels.astype(int)
    if only_connected:
        comp_size = np.bincount(labels)
        labels[comp_size[labels] == 1] = -1
    return labels
def smooth_sig(
    sig: xr.DataArray, freq: float, method="fft", btype="low"
) -> xr.DataArray:
    """
    Apply a low- or high-pass filter along the "frame" dimension of `sig`.

    Parameters
    ----------
    sig : xr.DataArray
        The input timeseries. Should have dimension "frame".
    freq : float
        The cut-off frequency.
    method : str, optional
        Which filter to use: `"fft"` dispatches to :func:`filt_fft` and
        `"butter"` dispatches to :func:`filt_butter`. By default `"fft"`.
    btype : str, optional
        Either `"low"` or `"high"` specify low or high pass filtering. By
        default `"low"`.

    Returns
    -------
    sig_smth : xr.DataArray
        The filtered signal. Has same shape as input `sig`.

    Raises
    ------
    NotImplementedError
        if `method` is not "fft" or "butter"
    """
    filters = {"fft": filt_fft, "butter": filt_butter}
    if method not in filters:
        raise NotImplementedError(method)
    # each 1d trace along "frame" is filtered independently
    return xr.apply_ufunc(
        filters[method],
        sig,
        input_core_dims=[["frame"]],
        output_core_dims=[["frame"]],
        vectorize=True,
        kwargs={"btype": btype, "freq": freq},
        dask="parallelized",
        output_dtypes=[sig.dtype],
    )
def filt_fft(x: np.ndarray, freq: float, btype: str) -> np.ndarray:
    """
    Filter 1d timeseries by zero-ing bands in the fft signal.

    Parameters
    ----------
    x : np.ndarray
        Input timeseries.
    freq : float
        Cut-off frequency, as a fraction of the sampling frequency. It is
        converted to a frequency bin with `int(freq * len(x))`.
    btype : str
        Either `"low"` or `"high"` specify low or high pass filtering.

    Returns
    -------
    x_filt : np.ndarray
        Filtered timeseries. Has the same length as `x`.

    Raises
    ------
    ValueError
        if `btype` is neither `"low"` nor `"high"`.
    """
    _T = len(x)
    cutoff = int(freq * _T)
    if btype == "low":
        # keep bins below the cut-off
        zero_range = slice(cutoff, None)
    elif btype == "high":
        # keep bins from the cut-off upwards
        zero_range = slice(None, cutoff)
    else:
        raise ValueError(
            "btype must be 'low' or 'high', got {!r}".format(btype)
        )
    xfft = numpy_fft.rfft(x)
    xfft[zero_range] = 0
    # pass the length explicitly so odd-length inputs round-trip
    return numpy_fft.irfft(xfft, _T)
def filt_butter(x: np.ndarray, freq: float, btype: str) -> np.ndarray:
    """
    Run a second-order digital Butterworth filter over a 1d timeseries.

    The filter is designed with :func:`scipy.signal.butter` and applied with
    :func:`scipy.signal.lfilter`.

    Parameters
    ----------
    x : np.ndarray
        Input timeseries.
    freq : float
        Cut-off frequency. It is doubled before being handed to
        :func:`scipy.signal.butter`.
    btype : str
        Either "low" or "high" specify low or high pass filtering.

    Returns
    -------
    x_filt : np.ndarray
        Filtered timeseries.
    """
    critical_freq = freq * 2
    numer, denom = butter(2, critical_freq, btype=btype, analog=False)
    x_filt = lfilter(numer, denom, x)
    return x_filt
def filt_fft_vec(x: np.ndarray, freq: float, btype: str) -> np.ndarray:
    """
    Vectorized wrapper of :func:`filt_fft`.

    Parameters
    ----------
    x : np.ndarray
        Input timeseries. Should have 2 dimensions, and the filtering will be
        applied along the last dimension. It is not modified.
    freq : float
        Cut-off frequency.
    btype : str
        Either `"low"` or `"high"` specify low or high pass filtering.

    Returns
    -------
    x_filt : np.ndarray
        Filtered timeseries, as a new array with the same shape and dtype as
        `x`.
    """
    # separate output buffer: the input may be shared with other computations
    # or be read-only
    x_filt = np.empty_like(x)
    for ix, xx in enumerate(x):
        x_filt[ix, :] = filt_fft(xx, freq, btype)
    return x_filt
def compute_AtC(A: xr.DataArray, C: xr.DataArray) -> xr.DataArray:
    """
    Reconstruct the movie estimated by spatial and temporal components.

    The temporal components are contracted with the spatial footprints over
    the "unit_id" axis, giving one estimated frame per time point.

    Parameters
    ----------
    A : xr.DataArray
        Spatial footprints of cells. Should have dimensions ("unit_id",
        "height", "width").
    C : xr.DataArray
        Temporal components of cells. Should have dimensions "frame" and
        "unit_id".

    Returns
    -------
    AtC : xr.DataArray
        The product representing estimated movie data. Has dimensions
        ("frame", "height", "width").
    """
    frames = C.coords["frame"].values
    heights = A.coords["height"].values
    widths = A.coords["width"].values
    # the footprints are computed eagerly as a sparse array and wrapped back
    # into a single-chunk dask array
    A_sparse = A.data.map_blocks(sparse.COO, dtype=A.dtype).compute()
    A_arr = darr.from_array(A_sparse, chunks=-1)
    C_arr = C.transpose("frame", "unit_id").data.map_blocks(
        sparse.COO, dtype=C.dtype
    )
    product = darr.tensordot(C_arr, A_arr, axes=(1, 0))
    # convert the sparse result blocks back to dense arrays
    AtC = product.map_blocks(lambda blk: blk.todense(), dtype=A_arr.dtype)
    optimizer = fct.partial(
        custom_arr_optimize, rename_dict={"tensordot": "tensordot_restricted"}
    )
    with da.config.set(array_optimize=optimizer):
        AtC = da.optimize(AtC)[0]
    return xr.DataArray(
        AtC,
        dims=["frame", "height", "width"],
        coords={"frame": frames, "height": heights, "width": widths},
    )
def graph_optimize_corr(
    varr: xr.DataArray,
    G: nx.Graph,
    freq: float,
    idx_dims=["height", "width"],
    chunk=600,
    step_size=50,
) -> pd.DataFrame:
    """
    Compute correlation in an optimized fashion given a computation graph.

    This function carry out out-of-core computation of large correaltion matrix.
    It takes in a computaion graph whose node represent timeseries and edges
    represent the desired pairwise correlation to be computed. The actual
    timeseries are stored in `varr` and indexed with node attributes. The
    function can carry out smoothing of timeseries before computation of
    correlation. To minimize re-computation of smoothing for each pixel, the
    graph is first partitioned using a minial-cut algorithm. Then the
    computation is performed in chunks with size `chunk`, with nodes from the
    same partition being in the same chunk as much as possible.

    Parameters
    ----------
    varr : xr.DataArray
        Input timeseries. Should have "frame" dimension in addition to those
        specified in `idx_dims`.
    G : nx.Graph
        Graph representing computation to be carried out. Should be undirected
        and un-weighted. Each node should have unique attributes with keys
        specified in `idx_dims`, which will be used to index the timeseries in
        `varr`. Each edge represent a desired correlation. A "part" attribute
        holding the partition membership is added to every node.
    freq : float
        Cut-off frequency for the optional smoothing. If `None` then no
        smoothing will be done.
    idx_dims : list, optional
        The dimension used to index the timeseries in `varr`. By default
        `["height", "width"]`.
    chunk : int, optional
        Chunk size of each computation. By default `600`.
    step_size : int, optional
        Step size to iterate through all edges. If too small then the iteration
        will take a long time. If too large then the variances in the actual
        chunksize of computation will be large. By default `50`.

    Returns
    -------
    eg_df : pd.DataFrame
        Dataframe representation of edge list. Has column "source" and "target"
        representing the node index of the edge (correlation), and column "corr"
        with computed value of correlation. Empty if `G` has no edges.
    """
    if G.number_of_edges() == 0:
        # nothing to correlate: skip partitioning and return an empty edge
        # list instead of failing when concatenating zero results
        return pd.DataFrame(
            {
                "source": np.array([], dtype=int),
                "target": np.array([], dtype=int),
                "corr": np.array([], dtype=float),
            }
        )
    # a heuristic to make number of partitions scale with nodes
    _, membership = pymetis.part_graph(
        max(int(np.ceil(G.number_of_nodes() / chunk)), 1), adjacency=adj_list(G)
    )
    nx.set_node_attributes(
        G, {k: {"part": v} for k, v in zip(sorted(G.nodes), membership)}
    )
    eg_df = nx.to_pandas_edgelist(G)
    part_map = nx.get_node_attributes(G, "part")
    eg_df["part_src"] = eg_df["source"].map(part_map)
    eg_df["part_tgt"] = eg_df["target"].map(part_map)
    # True for edges whose two nodes fall in different partitions
    eg_df["part_diff"] = (eg_df["part_src"] - eg_df["part_tgt"]).astype(bool)
    corr_ls = []
    idx_ls = []
    npxs = []
    egd_same, egd_diff = eg_df[~eg_df["part_diff"]], eg_df[eg_df["part_diff"]]
    idx_dict = {d: nx.get_node_attributes(G, d) for d in idx_dims}

    def construct_comput(edf, pxs):
        # build the lazy correlation computation for the edges in `edf`,
        # restricted to the timeseries of the nodes listed in `pxs`
        px_map = {k: v for v, k in enumerate(pxs)}
        ridx = edf["source"].map(px_map).values
        cidx = edf["target"].map(px_map).values
        idx_arr = {
            d: xr.DataArray([dd[p] for p in pxs], dims="pixels")
            for d, dd in idx_dict.items()
        }
        vsub = varr.sel(**idx_arr).data
        if len(idx_arr) > 1:  # vectorized indexing
            vsub = vsub.T
        else:
            vsub = vsub.rechunk(-1)
        with da.config.set(**{"optimization.fuse.ave-width": vsub.shape[0]}):
            return da.optimize(smooth_corr(vsub, ridx, cidx, freq=freq))[0]

    # edges within one partition: one computation per partition
    for _, eg_sub in egd_same.groupby("part_src"):
        pixels = list(set(eg_sub["source"]) | set(eg_sub["target"]))
        corr_ls.append(construct_comput(eg_sub, pixels))
        idx_ls.append(eg_sub.index)
        npxs.append(len(pixels))
    # edges across partitions: accumulate groups of `step_size` edges until
    # the number of involved nodes approaches `chunk`
    pixels = set()
    eg_ls = []
    grp = np.arange(len(egd_diff)) // step_size
    last_grp = grp.max() if len(grp) > 0 else None
    for igrp, eg_sub in egd_diff.sort_values("source").groupby(grp):
        pixels = pixels | set(eg_sub["source"]) | set(eg_sub["target"])
        eg_ls.append(eg_sub)
        if (len(pixels) > chunk - step_size / 2) or igrp == last_grp:
            pixels = list(pixels)
            edf = pd.concat(eg_ls)
            corr_ls.append(construct_comput(edf, pixels))
            idx_ls.append(edf.index)
            npxs.append(len(pixels))
            pixels = set()
            eg_ls = []
    print("pixel recompute ratio: {}".format(sum(npxs) / G.number_of_nodes()))
    print("computing correlations")
    corr_ls = da.compute(corr_ls)[0]
    # results are matched back to the edges through the edge-list index
    corr = pd.Series(np.concatenate(corr_ls), index=np.concatenate(idx_ls), name="corr")
    eg_df["corr"] = corr
    return eg_df
def adj_corr(
    varr: xr.DataArray, adj: np.ndarray, nod_df: pd.DataFrame, freq: float
) -> scipy.sparse.csr_matrix:
    """
    Compute the correlations requested by an adjacency matrix.

    A computation graph is assembled from `adj` (edges) and `nod_df` (node
    attributes) and handed to :func:`graph_optimize_corr`. The resulting edge
    list is then converted into a sparse matrix with the same shape as `adj`.

    Parameters
    ----------
    varr : xr.DataArray
        Input time series. Should have "frame" dimension in addition to column
        names of `nod_df`.
    adj : np.ndarray
        Adjacency matrix. Every non-zero entry requests one correlation.
    nod_df : pd.DataFrame
        Dataframe containing node attributes. Should have length `adj.shape[0]`
        and only contain columns relevant to index the time series.
    freq : float
        Cut-off frequency for the optional smoothing. If `None` then no
        smoothing will be done.

    Returns
    -------
    adj_corr : scipy.sparse.csr_matrix
        Sparse matrix of the same shape as `adj` but with values corresponding
        the computed correlation.
    """
    graph = nx.Graph()
    # node `i` carries row `i` of `nod_df` as its attributes
    graph.add_nodes_from(enumerate(nod_df.to_dict("records")))
    graph.add_edges_from(zip(*adj.nonzero()))
    corr_df = graph_optimize_corr(varr, graph, freq, idx_dims=nod_df.columns)
    values = corr_df["corr"]
    positions = (corr_df["source"], corr_df["target"])
    return scipy.sparse.csr_matrix((values, positions), shape=adj.shape)
def adj_list(G: nx.Graph) -> List[np.ndarray]:
    """
    Convert a graph into its adjacency list representation.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    adj_ls : List[np.ndarray]
        One array of neighbor nodes per node, ordered by sorted node.
    """
    neighbors = nx.to_dict_of_dicts(G)
    adj_ls = []
    for node in sorted(neighbors):
        adj_ls.append(np.array(list(neighbors[node])))
    return adj_ls
@darr.as_gufunc(signature="(p,f),(i),(i)->(i)", output_dtypes=[float])
def smooth_corr(
    X: np.ndarray, ridx: np.ndarray, cidx: np.ndarray, freq: float
) -> np.ndarray:
    """
    Optionally low-pass filter `X`, then compute the requested correlations.

    Combines :func:`filt_fft_vec` and :func:`idx_corr` in a single dask gufunc.

    Parameters
    ----------
    X : np.ndarray
        Input time series.
    ridx : np.ndarray
        Row index of the resulting correlation.
    cidx : np.ndarray
        Column index of the resulting correlation.
    freq : float
        Cut-off frequency for the smoothing. Smoothing is skipped when this is
        falsy.

    Returns
    -------
    corr : np.ndarray
        The correlations for the requested (row, column) pairs.
    """
    traces = filt_fft_vec(X, freq, "low") if freq else X
    return idx_corr(traces, ridx, cidx)
@nb.jit(nopython=True, nogil=True, cache=True)
def idx_corr(X: np.ndarray, ridx: np.ndarray, cidx: np.ndarray) -> np.ndarray:
    """
    Compute a subset of the pairwise correlation matrix of `X`.

    The pairs to compute are given by two index vectors `ridx` and `cidx` of
    same length, holding the row and column index into the full correlation
    matrix. Only those correlations are computed and they are returned as a
    flat vector.

    Parameters
    ----------
    X : np.ndarray
        Input time series. Should have 2 dimensions, where the last dimension
        should be the time dimension. Each row is mean-centered in place.
    ridx : np.ndarray
        Row index of the correlation.
    cidx : np.ndarray
        Column index of the correlation.

    Returns
    -------
    res : np.ndarray
        Flattened resulting correlations. Has same shape as `ridx` or `cidx`.
        Pairs involving a constant series are reported as 0.
    """
    n_series = X.shape[0]
    norms = np.zeros(n_series)
    for k in range(n_series):
        # centering is written back into `X`
        X[k, :] -= X[k, :].mean()
        norms[k] = np.sqrt((X[k, :] ** 2).sum())
    # entries stay 0 for pairs whose norm product is not positive
    res = np.zeros(ridx.shape[0])
    for k, (r, c) in enumerate(zip(ridx, cidx)):
        denom = norms[r] * norms[c]
        if denom > 0:
            res[k] = (X[r, :] * X[c, :]).sum() / denom
    return res
def update_background(
    Y: xr.DataArray, A: xr.DataArray, C: xr.DataArray, b: xr.DataArray = None
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Update background terms given spatial and temporal components of cells.

    The estimated cell activity movie (see :func:`compute_AtC`) is subtracted
    from the input movie and the result is clipped at zero and saved to the
    intermediate folder as "Yb". The spatial footprint of background `b` is
    the mean of this residule movie over "frame", unless a previous `b` is
    supplied. The temporal component of background `f` is the least-square
    solution between the residule movie and the spatial footprint `b`.

    Parameters
    ----------
    Y : xr.DataArray
        Input movie data. Should have dimensions ("frame", "height", "width").
    A : xr.DataArray
        Estimation of spatial footprints of cells. Should have dimensions
        ("unit_id", "height", "width").
    C : xr.DataArray
        Estimation of temporal activities of cells. Should have dimensions
        ("unit_id", "frame").
    b : xr.DataArray, optional
        Previous estimation of spatial footprint of background. If provided it
        will be returned as-is, and only temporal activity of background will be
        updated

    Returns
    -------
    b_new : xr.DataArray
        New estimation of the spatial footprint of background. Has
        dimensions ("height", "width").
    f_new : xr.DataArray
        New estimation of the temporal activity of background. Has dimension
        "frame".

    Raises
    ------
    KeyError
        if the environment variable "MINIAN_INTERMEDIATE" is not set.
    """
    intpath = os.environ["MINIAN_INTERMEDIATE"]
    residual = (Y - compute_AtC(A, C)).clip(0).rename("Yb")
    Yb = save_minian(residual, intpath, overwrite=True)
    b_new = (Yb.mean("frame") if b is None else b).persist()
    # flatten space so the footprint becomes a single-column design matrix
    b_flat = b_new.stack(spatial=["height", "width"]).transpose("spatial")
    b_stk = b_flat.expand_dims("dummy", axis=-1).chunk(-1)
    Yb_stk = Yb.stack(spatial=["height", "width"]).transpose("spatial", "frame")
    solution = darr.linalg.lstsq(b_stk.data, Yb_stk.data)[0]
    f_new = xr.DataArray(
        solution.squeeze(), dims=["frame"], coords={"frame": Yb.coords["frame"]}
    )
    return b_new, f_new.persist()