Source code for minian.visualization

import functools as fct
import itertools as itt
import os
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union
from uuid import uuid4

import colorcet as cc
import cv2
import dask
import dask.array as da
import ffmpeg
import holoviews as hv
import numpy as np
import pandas as pd
import panel as pn
import param
import scipy.sparse as scisps
import sklearn.mixture
import skvideo.io
import xarray as xr
from bokeh.palettes import Category10_10, Viridis256
from dask.diagnostics import ProgressBar
from datashader import count_cat
from holoviews.operation.datashader import datashade, dynspread
from holoviews.streams import (
    BoxEdit,
    DoubleTap,
    Pipe,
    RangeXY,
    Selection1D,
    Stream,
    Tap,
)
from holoviews.util import Dynamic
from matplotlib import cm
from panel import widgets as pnwgt
from scipy import linalg
from scipy.ndimage.measurements import center_of_mass
from scipy.spatial import cKDTree

from .cnmf import compute_AtC
from .motion_correction import apply_shifts
from .utilities import custom_arr_optimize, rechunk_like


class VArrayViewer:
    """
    Interactive visualization for movie data arrays.

    Hint
    ----
    .. figure:: img/vaviewer.png
        :width: 500px
        :align: left

    The visualization contains following panels from top to bottom:

    Play Toolbar
        A toolbar that controls playback of the video. Additionally, when the
        button "Update Mask" is clicked, the coordinates of the box drawn in
        *Current Frame* panel will be used to update the `mask` attribute of
        the `VArrayViewer` instance, which can be later used to subset the
        data. If multiple arrays are visualized and `layout` is `False`, then
        drop-down lists corresponding to each metadata dimensions will show up
        so the user can select which array to visualize.
    Current Frame
        Images of the current frame. If multiple movie arrays are passed in,
        multiple frames will be labeled and shown. To the side of each frame
        there is a histogram of intensity values. The "Box Select" tool can be
        used on the histogram to limit the range of intensity used for
        color-mapping. Additionally, the "Box Edit Tool" is available for use
        on the frame image, where you can hold "Shift" and draw a box, whose
        coordinates can be used to update the `mask` attribute of the
        `VArrayViewer` instance (remember to click "Update Mask" after
        drawing).
    Summary
        Summary statistics of each frame across time. Only shown if `summary`
        is not empty. The red vertical line indicates the current frame.

    Attributes
    ----------
    mask : dict
        Instance attribute that can be retrieved and used to subset data later.
        Keys are `tuple` with values corresponding to each `meta_dims` and
        uniquely identify each input array. If `meta_dims` is empty then keys
        will be empty `tuple` as well. Values are `dict` mapping dimension
        names (of the arrays) to subsetting slices. The slices are in the
        plotting coordinates and can be directly passed to `xr.DataArray.sel`
        method to subset data.
    """

    def __init__(
        self,
        varr: Union[xr.DataArray, List[xr.DataArray], xr.Dataset],
        framerate=30,
        summary=("mean",),
        meta_dims: List[str] = None,
        datashading=True,
        layout=False,
    ):
        """
        Parameters
        ----------
        varr : Union[xr.DataArray, List[xr.DataArray], xr.Dataset]
            Input array, list of arrays, or dataset to be visualized. Each
            array should contain dimensions "height", "width" and "frame". If
            a dataset, then the dimensions specified in `meta_dims` will be
            used as metadata dimensions that can uniquely identify each array.
            If a list, then a dimension "data_var" will be constructed and
            used as metadata dimension, and the `.name` attribute of each
            array will be used to identify each array. The input list itself
            is not modified.
        framerate : int, optional
            The framerate of playback when using the toolbar. By default `30`.
        summary : list, optional
            List (or tuple) of summary statistics to plot. The statistics
            should be one of `{"mean", "max", "min", "diff"}`. Unrecognized
            names are reported and skipped. An empty list disables the
            *Summary* panel. By default `["mean"]`.
        meta_dims : List[str], optional
            List of dimension names that can uniquely identify each input
            array in `varr`. Only used if `varr` is a `xr.Dataset`. By default
            `None`.
        datashading : bool, optional
            Whether to use datashading on the summary statistics. By default
            `True`.
        layout : bool, optional
            Whether to visualize all arrays together as layout. If `False`
            then only one array will be visualized and user can switch array
            using drop-down lists below the *Play Toolbar*. By default `False`.

        Raises
        ------
        NotImplementedError
            if `varr` is not a `xr.DataArray`, a `xr.Dataset` or a list of
            `xr.DataArray`
        """
        if isinstance(varr, list):
            # build a new list instead of overwriting the caller's elements
            varr = [v.assign_coords(data_var=v.name) for v in varr]
            self.ds = xr.concat(varr, dim="data_var")
            meta_dims = ["data_var"]
        elif isinstance(varr, xr.DataArray):
            self.ds = varr.to_dataset()
        elif isinstance(varr, xr.Dataset):
            self.ds = varr
        else:
            raise NotImplementedError(
                "video array of type {} not supported".format(type(varr))
            )
        if meta_dims:
            self.meta_dicts = OrderedDict(
                [(d, list(self.ds.coords[d].values)) for d in meta_dims]
            )
            self.cur_metas = OrderedDict(
                [(d, v[0]) for d, v in self.meta_dicts.items()]
            )
        else:
            self.meta_dicts = dict()
            self.cur_metas = dict()
        self._datashade = datashading
        self._layout = layout
        self.framerate = framerate
        self._f = self.ds.coords["frame"].values
        self._h = self.ds.sizes["height"]
        self._w = self.ds.sizes["width"]
        self.mask = dict()
        CStream = Stream.define(
            "CStream",
            f=param.Integer(
                default=int(self._f.min()), bounds=(self._f.min(), self._f.max())
            ),
        )
        self.strm_f = CStream()
        self.str_box = BoxEdit()
        self.widgets = self._widgets()
        if isinstance(summary, (list, tuple)):
            # reductions are lazy; only the requested ones are computed below
            summ_all = {
                "mean": self.ds.mean(["height", "width"]),
                "max": self.ds.max(["height", "width"]),
                "min": self.ds.min(["height", "width"]),
                "diff": self.ds.diff("frame").mean(["height", "width"]),
            }
            if any(k not in summ_all for k in summary):
                print("{} Not understood for specifying summary".format(summary))
            summ = {k: summ_all[k] for k in summary if k in summ_all}
            if summ:
                print("computing summary")
                sum_list = []
                for k, v in summ.items():
                    sum_list.append(v.compute().assign_coords(sum_var=k))
                summary = xr.concat(sum_list, dim="sum_var")
            else:
                # nothing to plot: disable the summary panel explicitly
                summary = None
        self.summary = summary
        if layout:
            self.ds_sub = self.ds
            self.sum_sub = self.summary
        else:
            self.ds_sub = self.ds.sel(**self.cur_metas)
            try:
                self.sum_sub = self.summary.sel(**self.cur_metas)
            except AttributeError:
                # summary is None (or not an xarray object)
                self.sum_sub = self.summary
        self.pnplot = pn.panel(self.get_hvobj())

    def get_hvobj(self):
        """
        Build the HoloViews object: frame image(s) with histograms and,
        when a summary is available, the summary traces with a frame marker.
        """

        def get_im_ovly(meta):
            def img(f, ds):
                return hv.Image(ds.sel(frame=f).compute(), kdims=["width", "height"])

            try:
                # str() so that non-string metadata values can be joined
                curds = self.ds_sub.sel(**meta).rename(
                    "_".join(map(str, meta.values()))
                )
            except ValueError:
                curds = self.ds_sub
            fim = fct.partial(img, ds=curds)
            im = hv.DynamicMap(fim, streams=[self.strm_f]).opts(
                frame_width=500, aspect=self._w / self._h, cmap="Viridis"
            )
            self.xyrange = RangeXY(source=im).rename(x_range="w", y_range="h")
            if not self._layout:
                hv_box = hv.Polygons([]).opts(
                    style={"fill_alpha": 0.3, "line_color": "white"}
                )
                self.str_box = BoxEdit(source=hv_box)
                im_ovly = im * hv_box
            else:
                im_ovly = im

            def hist(f, w, h, ds):
                if w and h:
                    # restrict the histogram to the currently visible range
                    cur_im = hv.Image(
                        ds.sel(frame=f).compute(), kdims=["width", "height"]
                    ).select(height=h, width=w)
                else:
                    cur_im = hv.Image(
                        ds.sel(frame=f).compute(), kdims=["width", "height"]
                    )
                return hv.operation.histogram(cur_im, num_bins=50).opts(
                    xlabel="fluorescence", ylabel="freq"
                )

            fhist = fct.partial(hist, ds=curds)
            his = hv.DynamicMap(fhist, streams=[self.strm_f, self.xyrange]).opts(
                frame_height=int(500 * self._h / self._w), width=150, cmap="Viridis"
            )
            im_ovly = (im_ovly << his).map(
                lambda p: p.opts(style=dict(cmap="Viridis"))
            )
            return im_ovly

        if self._layout and self.meta_dicts:
            im_dict = OrderedDict()
            for meta in itt.product(*list(self.meta_dicts.values())):
                mdict = {k: v for k, v in zip(list(self.meta_dicts.keys()), meta)}
                im_dict[meta] = get_im_ovly(mdict)
            ims = hv.NdLayout(im_dict, kdims=list(self.meta_dicts.keys()))
        else:
            ims = get_im_ovly(self.cur_metas)
        if self.summary is not None:
            hvsum = (
                hv.Dataset(self.sum_sub)
                .to(hv.Curve, kdims=["frame"])
                .overlay("sum_var")
            )
            if self._datashade:
                hvsum = datashade_ndcurve(hvsum, kdim="sum_var")
            try:
                hvsum = hvsum.layout(list(self.meta_dicts.keys()))
            except Exception:
                # best effort: keep the plain overlay if it cannot be laid
                # out by the metadata dims (e.g. none present)
                pass
            vl = hv.DynamicMap(lambda f: hv.VLine(f), streams=[self.strm_f]).opts(
                style=dict(color="red")
            )
            summ = (hvsum * vl).map(
                lambda p: p.opts(frame_width=500, aspect=3), [hv.RGB, hv.Curve]
            )
            hvobj = (ims + summ).cols(1)
        else:
            hvobj = ims
        return hvobj

    def show(self) -> pn.layout.Column:
        """
        Return visualizations that can be directly displayed.

        Returns
        -------
        pn.layout.Column
            Resulting visualizations containing both plots and toolbars.
        """
        return pn.layout.Column(self.widgets, self.pnplot)

    def _widgets(self):
        """Build the play toolbar, mask button and (optional) metadata selectors."""
        w_play = pnwgt.Player(
            length=len(self._f), interval=10, value=0, width=650, height=90
        )

        def play(f):
            if not f.old == f.new:
                self.strm_f.event(f=int(self._f[f.new]))

        w_play.param.watch(play, "value")
        w_box = pnwgt.Button(
            name="Update Mask", button_type="primary", width=100, height=30
        )
        w_box.param.watch(self._update_box, "clicks")
        if not self._layout:
            wgt_meta = {
                d: pnwgt.Select(name=d, options=v, height=45, width=120)
                for d, v in self.meta_dicts.items()
            }

            def make_update_func(meta_name):
                def _update(x):
                    self.cur_metas[meta_name] = x.new
                    self._update_subs()

                return _update

            for d, wgt in wgt_meta.items():
                cur_update = make_update_func(d)
                wgt.param.watch(cur_update, "value")
            wgts = pn.layout.WidgetBox(w_box, w_play, *list(wgt_meta.values()))
        else:
            wgts = pn.layout.WidgetBox(w_box, w_play)
        return wgts

    def _update_subs(self):
        """Re-select the displayed array/summary after a metadata change and redraw."""
        self.ds_sub = self.ds.sel(**self.cur_metas)
        if self.sum_sub is not None:
            self.sum_sub = self.summary.sel(**self.cur_metas)
        self.pnplot.objects[0].object = self.get_hvobj()

    def _update_box(self, click):
        """Store the currently drawn box as the mask for the current array."""
        box = self.str_box.data
        if not box or not box.get("x0"):
            # no box has been drawn yet: nothing to record
            return
        self.mask.update(
            {
                tuple(self.cur_metas.values()): {
                    "height": slice(box["y0"][0], box["y1"][0]),
                    "width": slice(box["x0"][0], box["x1"][0]),
                }
            }
        )
class CNMFViewer:
    """
    Interactive visualization for CNMF results.

    Hint
    ----
    .. figure:: img/cnmfviewer.png
        :width: 1000px

    The visualization can be divided into two parts vertically:

    Spatial
        Top part of the visualization. Shows spatial plots at a given time.
        From left to right:

        Spatial Footprints
            Shows the spatial footprints of all cells. The "Box Select" tool
            can be used in this panel to select a subset of cells to visualize
            for both the *Isolated Activities* panel and the *Temporal
            Activities* panel.
        Isolated Activities
            Shows activities of selected cells only. If the "UseAC" checkbox
            under *General Toolbox* is enabled, then the `AtC` variable
            computed with the selected cells will be visualized at the given
            frame (See :func:`minian.cnmf.compute_AtC`). Otherwise the spatial
            footprints of the cells will be plotted, which would be invariant
            across time. The "unit_id" coordinates for each cell are shown on
            top of each cell.
        Original Movie
            Shows a single frame of an arbitrary movie data supplied in `org`.
    Temporal
        Bottom part of the visualization. Shows temporal activities across
        time and various toolboxes. From left to right:

        General Toolbox
            Contains the following tools:

            * "Refresh" button, will refresh all visualization when clicked.
            * "Load Data" button, will load all data in memory for faster
              visualization, can be very memory-demanding.
            * "UseAC" checkbox, whether to plot spatial-temporal activities
              for the *Isolated Activities* panel.
            * "ShowC", "ShowS", "Normalize" checkboxes, whether to show the
              calcium traces, the spike signals, or to normalize both traces
              to unit range for each cell.
            * "Group" dropbox, "Previous Group" and "Next Group" buttons,
              select the group of cells to visualize. The grouping is
              controlled by `sortNN` parameter.
            * Playback toolbar, used to control which timepoint is visualized.
            * Additional metadata dropdown, if the input dataset contains
              additional metadata dimensions then dropdown will show up so
              user can select which dataset to visualize.
        Temporal Activities
            Shows temporal activities of selected subset of cells. The red
            vertical line indicates the current frame. Additionally user can
            double-click anywhere in the plot to move current frame to that
            location.
        Manual Label
            Shows tools to carry out manual labeling of cells. User can either
            manually assign unit label using the dropdown for each cell, or
            select some cells with the checkboxes corresponding to the
            "unit_id", and then merge or discard the units using the buttons.
            The "Unit Label" dropdowns should update and reflect the merging
            or discarding actions.

    Attributes
    ----------
    unit_labels : xr.DataArray
        1d array whose values represent the result of manual refinement of
        cells. The "unit_id" coordinate of this array is identical to input
        data. The values of this array can be interpreted as new "unit_id"
        after the manual refinement, where duplicated values indicate merged
        cells, and values of -1 indicate discarded cells.
    """

    def __init__(
        self,
        minian: Optional[xr.Dataset] = None,
        A: Optional[xr.DataArray] = None,
        C: Optional[xr.DataArray] = None,
        S: Optional[xr.DataArray] = None,
        org: Optional[xr.DataArray] = None,
        sortNN=True,
    ):
        """
        Parameters
        ----------
        minian : xr.Dataset, optional
            Input minian dataset containing all necessary variables. If `None`
            then all other arguments should be supplied. By default `None`.
        A : xr.DataArray, optional
            Spatial footprints of cells. If `None` then it will be retrieved
            as `minian["A"]`. By default `None`.
        C : xr.DataArray, optional
            Calcium dynamic of cells. If `None` then it will be retrieved as
            `minian["C"]`. By default `None`.
        S : xr.DataArray, optional
            Deconvolved spikes of cells. If `None` then it will be retrieved
            as `minian["S"]`. By default `None`.
        org : xr.DataArray, optional
            Arbitrary movie data to be visualized along with results of CNMF.
            If `None` then it will be retrieved as `minian["org"]`. If this
            array contains dimensions other than "height", "width" or "frame"
            then they will be used as metadata dimensions. By default `None`.
        sortNN : bool, optional
            Whether to sort the units using :func:`NNsort` so that cells close
            together will appear in same group for visualization. If `False`
            then cells are simply grouped in 5 by ascending "unit_id". By
            default `True`.
        """
        self._A = A if A is not None else minian["A"]
        self._C = C if C is not None else minian["C"]
        self._S = S if S is not None else minian["S"]
        self._org = org if org is not None else minian["org"]
        try:
            self.unit_labels = minian["unit_labels"].compute()
        except (KeyError, TypeError):
            # no dataset (TypeError on None) or no saved labels: start from
            # identity labels, i.e. every unit keeps its own "unit_id"
            self.unit_labels = xr.DataArray(
                self._A["unit_id"].values.copy(),
                dims=self._A["unit_id"].dims,
                coords=self._A["unit_id"].coords,
            ).rename("unit_labels")
        self._C_norm = xr.apply_ufunc(
            normalize,
            self._C.chunk(dict(frame=-1, unit_id="auto")),
            input_core_dims=[["frame"]],
            output_core_dims=[["frame"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[self._C.dtype],
        )
        self._S_norm = xr.apply_ufunc(
            normalize,
            self._S.chunk(dict(frame=-1, unit_id="auto")),
            input_core_dims=[["frame"]],
            output_core_dims=[["frame"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[self._C.dtype],
        )
        self.cents = centroid(self._A, verbose=True)
        print("computing sum projection")
        with ProgressBar():
            self.Asum = self._A.sum("unit_id").compute()
        self._NNsort = sortNN
        self._normalize = False
        self._useAC = True
        self._showC = True
        self._showS = True
        meta_dims = list(set(self._org.dims) - {"frame", "height", "width"})
        self.meta_dicts = {d: list(self._org.coords[d].values) for d in meta_dims}
        self.metas = {d: v[0] for d, v in self.meta_dicts.items()}
        if self._NNsort:
            try:
                self.cents["NNord"] = self.cents.groupby(
                    meta_dims, group_keys=False
                ).apply(NNsort)
            except ValueError:
                # no metadata dims to group by
                self.cents["NNord"] = NNsort(self.cents)
            NNcoords = self.cents.set_index(meta_dims + ["unit_id"])[
                "NNord"
            ].to_xarray()
            self._A = self._A.assign_coords(NNord=NNcoords)
            self._C = self._C.assign_coords(NNord=NNcoords)
            self._S = self._S.assign_coords(NNord=NNcoords)
            self._C_norm = self._C_norm.assign_coords(NNord=NNcoords)
            self._S_norm = self._S_norm.assign_coords(NNord=NNcoords)
        self.update_subs()
        self.strm_f = DoubleTap(rename=dict(x="f"))
        self.strm_f.add_subscriber(self.callback_f)
        self.strm_uid = Selection1D()
        self.strm_uid.add_subscriber(self.callback_uid)
        Stream_usub = Stream.define("Stream_usub", usub=param.List())
        self.strm_usub = Stream_usub()
        self.strm_usub.add_subscriber(self.callback_usub)
        # own list: checkbox callbacks append/remove on it
        self.usub_sel = list(self.strm_usub.usub)
        self._AC = self._org.sel(**self.metas)
        self._mov = self._org.sel(**self.metas)
        self.pipAC = Pipe([])
        self.pipmov = Pipe([])
        self.pipusub = Pipe([])
        self.wgt_meta = self._meta_wgt()
        self.wgt_spatial_all = self._spatial_all_wgt()
        self.spatial_all = self._spatial_all()
        self.temp_comp_sub = self._temp_comp_sub(self._u[:5])
        self.wgt_man = self._man_wgt()
        self.wgt_temp_comp = self._temp_comp_wgt()

    def update_subs(self):
        """Subset all arrays to the current metadata and refresh cached coords."""
        self.A_sub = self._A.sel(**self.metas)
        self.C_sub = self._C.sel(**self.metas)
        self.S_sub = self._S.sel(**self.metas)
        self.org_sub = self._org.sel(**self.metas)
        self.C_norm_sub = self._C_norm.sel(**self.metas)
        self.S_norm_sub = self._S_norm.sel(**self.metas)
        if self._NNsort:
            self.A_sub = self.A_sub.sortby("NNord")
            self.C_sub = self.C_sub.sortby("NNord")
            self.S_sub = self.S_sub.sortby("NNord")
            self.C_norm_sub = self.C_norm_sub.sortby("NNord")
            self.S_norm_sub = self.S_norm_sub.sortby("NNord")
        self._h = (
            self.A_sub.isel(unit_id=0)
            .dropna("height", how="all")
            .coords["height"]
            .values
        )
        self._w = (
            self.A_sub.isel(unit_id=0).dropna("width", how="all").coords["width"].values
        )
        self._f = self.C_sub.isel(unit_id=0).dropna("frame").coords["frame"].values
        self._u = self.C_sub.isel(frame=0).dropna("unit_id").coords["unit_id"].values
        if self.meta_dicts:
            sub = pd.concat(
                [self.cents[d] == v for d, v in self.metas.items()], axis="columns"
            ).all(axis="columns")
            self.cents_sub = self.cents[sub]
        else:
            self.cents_sub = self.cents

    def compute_subs(self, clicks=None):
        """Load all current subsets into memory (can be memory-demanding)."""
        self.A_sub = self.A_sub.compute()
        self.C_sub = self.C_sub.compute()
        self.S_sub = self.S_sub.compute()
        self.org_sub = self.org_sub.compute()
        self.C_norm_sub = self.C_norm_sub.compute()
        self.S_norm_sub = self.S_norm_sub.compute()

    def update_all(self, clicks=None):
        """Refresh subsets, reset selection and frame, and redraw spatial plots."""
        self.update_subs()
        self.strm_uid.event(index=[])
        self.strm_f.event(x=0)
        self.update_spatial_all()

    def callback_uid(self, index=None):
        self.update_temp()
        self.update_AC()
        self.update_usub_lab()

    def callback_f(self, f, y):
        """Push the frame nearest to `f` to the image pipes."""
        if len(self._AC) > 0 and len(self._mov) > 0:
            fidx = np.abs(self._f - f).argmin()
            f = self._f[fidx]
            if self._useAC:
                AC = self._AC.sel(frame=f)
            else:
                AC = self._AC
            mov = self._mov.sel(frame=f)
            self.pipAC.send(AC)
            self.pipmov.send(mov)
            try:
                # keep the player widget in sync; absent during __init__
                self.wgt_temp_comp[1].value = int(fidx)
            except AttributeError:
                pass
        else:
            self.pipAC.send([])
            self.pipmov.send([])

    def callback_usub(self, usub=None):
        self.update_temp_comp_sub(usub)
        self.update_AC(usub)
        self.update_usub_lab(usub)

    def _meta_wgt(self):
        """Build metadata selectors plus the "Refresh" and "Load Data" buttons."""
        wgt_meta = {
            d: pnwgt.Select(name=d, options=v, height=45, width=120)
            for d, v in self.meta_dicts.items()
        }

        def make_update_func(meta_name):
            def _update(x):
                self.metas[meta_name] = x.new
                self.update_subs()

            return _update

        for d, wgt in wgt_meta.items():
            cur_update = make_update_func(d)
            wgt.param.watch(cur_update, "value")
        wgt_update = pnwgt.Button(
            name="Refresh", button_type="primary", height=30, width=120
        )
        wgt_update.param.watch(self.update_all, "clicks")
        wgt_load = pnwgt.Button(
            name="Load Data", button_type="danger", height=30, width=120
        )
        wgt_load.param.watch(self.compute_subs, "clicks")
        return pn.layout.WidgetBox(
            *(list(wgt_meta.values()) + [wgt_update, wgt_load]), width=150
        )

    def show(self) -> pn.layout.Column:
        """
        Return visualizations that can be directly displayed.

        Returns
        -------
        pn.layout.Column
            Resulting visualizations containing both plots and toolboxes.
        """
        return pn.layout.Column(
            self.spatial_all,
            pn.layout.Row(
                pn.layout.Column(
                    pn.layout.Row(self.wgt_meta, self.wgt_spatial_all),
                    self.wgt_temp_comp,
                ),
                self.temp_comp_sub,
                self.wgt_man,
            ),
        )

    def _temp_comp_sub(self, usub=None):
        """Build the datashaded temporal traces panel for units `usub`."""
        if usub is None:
            usub = self.strm_usub.usub
        if self._normalize:
            C, S = self.C_norm_sub, self.S_norm_sub
        else:
            C, S = self.C_sub, self.S_sub
        cur_temp = dict()
        if self._showC:
            cur_temp["C"] = hv.Dataset(
                C.sel(unit_id=usub)
                .compute()
                .rename("Intensity (A. U.)")
                .dropna("frame", how="all")
            ).to(hv.Curve, "frame")
        if self._showS:
            cur_temp["S"] = hv.Dataset(
                S.sel(unit_id=usub)
                .compute()
                .rename("Intensity (A. U.)")
                .dropna("frame", how="all")
            ).to(hv.Curve, "frame")
        cur_vl = hv.DynamicMap(
            lambda f, y: hv.VLine(f) if f else hv.VLine(0), streams=[self.strm_f]
        ).opts(style=dict(color="red"))
        # empty curve that acts as the double-tap source for the frame stream
        # NOTE(review): vdim label spelling differs from the traces above
        cur_cv = hv.Curve([], kdims=["frame"], vdims=["Internsity (A.U.)"])
        self.strm_f.source = cur_cv
        h_cv = len(self._w) // 8
        w_cv = len(self._w) * 2
        temp_comp = (
            cur_cv
            * datashade_ndcurve(
                hv.HoloMap(cur_temp, "trace")
                .collate()
                .overlay("trace")
                .grid("unit_id")
                .add_dimension("time", 0, 0),
                "trace",
            )
            .opts(plot=dict(shared_xaxis=True))
            .map(
                lambda p: p.opts(plot=dict(frame_height=h_cv, frame_width=w_cv)),
                hv.RGB,
            )
            * cur_vl
        )
        temp_comp[temp_comp.keys()[0]] = temp_comp[temp_comp.keys()[0]].opts(
            plot=dict(height=h_cv + 75)
        )
        return pn.panel(temp_comp)

    def update_temp_comp_sub(self, usub=None):
        self.temp_comp_sub.object = self._temp_comp_sub(usub).object
        self.wgt_man.objects = self._man_wgt().objects

    def update_norm(self, norm):
        self._normalize = norm.new
        self.update_temp_comp_sub()

    def _temp_comp_wgt(self):
        """Build the general toolbox: checkboxes, group selector and player."""
        if self.strm_uid.index:
            cur_idxs = self.strm_uid.index
        else:
            cur_idxs = self._u
        # at least one group so np.array_split does not fail on no units
        ntabs = max(np.ceil(len(cur_idxs) / 5), 1)
        sub_idxs = np.array_split(cur_idxs, ntabs)
        idxs_dict = OrderedDict(
            [("group{}".format(i), g.tolist()) for i, g in enumerate(sub_idxs)]
        )
        def_idxs = list(idxs_dict.values())[0]
        wgt_grp = pnwgt.Select(
            name="", options=idxs_dict, width=120, height=30, value=def_idxs
        )

        def update_usub(usub):
            self.usub_sel = []
            self.strm_usub.event(usub=usub.new)

        wgt_grp.param.watch(update_usub, "value")
        wgt_grp.value = def_idxs
        self.strm_usub.event(usub=def_idxs)

        def step_group(step):
            # move `step` groups; stay put at either end instead of wrapping
            grp_vals = list(idxs_dict.values())
            ig = grp_vals.index(wgt_grp.value) + step
            if 0 <= ig < len(grp_vals):
                wgt_grp.value = grp_vals[ig]

        wgt_grp_prv = pnwgt.Button(
            name="Previous Group", width=120, height=30, button_type="primary"
        )

        def prv(clicks):
            step_group(-1)

        wgt_grp_prv.param.watch(prv, "clicks")
        wgt_grp_nxt = pnwgt.Button(
            name="Next Group", width=120, height=30, button_type="primary"
        )

        def nxt(clicks):
            step_group(1)

        wgt_grp_nxt.param.watch(nxt, "clicks")
        wgt_norm = pnwgt.Checkbox(
            name="Normalize", value=self._normalize, width=120, height=10
        )
        wgt_norm.param.watch(self.update_norm, "value")
        wgt_showC = pnwgt.Checkbox(
            name="ShowC", value=self._showC, width=120, height=10
        )

        def callback_showC(val):
            self._showC = val.new
            self.update_temp_comp_sub()

        wgt_showC.param.watch(callback_showC, "value")
        wgt_showS = pnwgt.Checkbox(
            name="ShowS", value=self._showS, width=120, height=10
        )

        def callback_showS(val):
            self._showS = val.new
            self.update_temp_comp_sub()

        wgt_showS.param.watch(callback_showS, "value")
        wgt_play = pnwgt.Player(length=len(self._f), interval=10, value=0, width=280)

        def play(f):
            if not f.old == f.new:
                self.strm_f.event(x=self._f[f.new])

        wgt_play.param.watch(play, "value")
        wgt_groups = pn.layout.Row(
            pn.layout.WidgetBox(wgt_norm, wgt_showC, wgt_showS, wgt_grp, width=150),
            pn.layout.WidgetBox(wgt_grp_prv, wgt_grp_nxt, width=150),
        )
        return pn.layout.Column(wgt_groups, wgt_play)

    def _man_wgt(self):
        """Build the manual labeling panel for the current group of units."""
        # sorted copy (descending): the stream's list is shared with the
        # group selector and must not be reordered in place
        usub = sorted(self.strm_usub.usub, reverse=True)
        ulabs = self.unit_labels.sel(unit_id=usub).values
        wgt_sel = {
            uid: pnwgt.Select(
                name="Unit Label",
                options=usub + [-1] + ulabs.tolist(),
                value=ulb,
                height=50,
                width=80,
            )
            for uid, ulb in zip(usub, ulabs)
        }

        def callback_ulab(value, uid):
            self.unit_labels.loc[uid] = value.new

        for uid, sel in wgt_sel.items():
            cb = fct.partial(callback_ulab, uid=uid)
            sel.param.watch(cb, "value")
        wgt_check = {
            uid: pnwgt.Checkbox(
                name="Unit ID: {}".format(uid), value=False, height=50, width=100
            )
            for uid in usub
        }

        def callback_chk(val, uid):
            if not val.old == val.new:
                if val.new:
                    self.usub_sel.append(uid)
                else:
                    self.usub_sel.remove(uid)

        for uid, chk in wgt_check.items():
            cb = fct.partial(callback_chk, uid=uid)
            chk.param.watch(cb, "value")
        wgt_discard = pnwgt.Button(
            name="Discard Selected", button_type="primary", width=180
        )

        def callback_discard(clicks):
            for uid in self.usub_sel:
                wgt_sel[uid].value = -1

        wgt_discard.param.watch(callback_discard, "clicks")
        wgt_merge = pnwgt.Button(
            name="Merge Selected", button_type="primary", width=180
        )

        def callback_merge(clicks):
            # merged units all take the label of the first selected unit
            for uid in self.usub_sel:
                wgt_sel[uid].value = self.usub_sel[0]

        wgt_merge.param.watch(callback_merge, "clicks")
        return pn.layout.Column(
            pn.layout.WidgetBox(wgt_discard, wgt_merge, width=200),
            pn.layout.Row(
                pn.layout.WidgetBox(*wgt_check.values(), width=100),
                pn.layout.WidgetBox(*wgt_sel.values(), width=100),
            ),
        )

    def update_temp_comp_wgt(self):
        self.wgt_temp_comp.objects = self._temp_comp_wgt().objects

    def update_temp(self):
        self.update_temp_comp_wgt()

    def update_AC(self, usub=None):
        """Recompute the isolated-activity (A*C or summed A) and movie windows."""
        if usub is None:
            usub = self.strm_usub.usub
        if usub:
            if self._useAC:
                umask = (self.A_sub.sel(unit_id=usub) > 0).any("unit_id")
                A_sub = self.A_sub.sel(unit_id=usub).where(umask, drop=True).fillna(0)
                C_sub = self.C_sub.sel(unit_id=usub)
                AC = xr.apply_ufunc(
                    da.dot,
                    A_sub,
                    C_sub,
                    input_core_dims=[
                        ["height", "width", "unit_id"],
                        ["unit_id", "frame"],
                    ],
                    output_core_dims=[["height", "width", "frame"]],
                    dask="allowed",
                )
                self._AC = AC.compute()
                wndh, wndw = AC.coords["height"].values, AC.coords["width"].values
                window = self.A_sub.sel(
                    height=slice(wndh.min(), wndh.max()),
                    width=slice(wndw.min(), wndw.max()),
                )
                self._AC = self._AC.reindex_like(window).fillna(0)
                self._mov = (self.org_sub.reindex_like(window)).compute()
            else:
                self._AC = self.A_sub.sel(unit_id=usub).sum("unit_id")
                self._mov = self.org_sub
            self.strm_f.event(x=0)
        else:
            self._AC = xr.DataArray([])
            self._mov = xr.DataArray([])
            self.strm_f.event(x=0)

    def update_usub_lab(self, usub=None):
        """Send centroid labels of the units in `usub` to the label pipe."""
        if usub is None:
            usub = self.strm_usub.usub
        if usub:
            self.pipusub.send(self.cents_sub[self.cents_sub["unit_id"].isin(usub)])
        else:
            self.pipusub.send([])

    def _spatial_all_wgt(self):
        wgt_useAC = pnwgt.Checkbox(
            name="UseAC", value=self._useAC, width=120, height=15
        )

        def callback_useAC(val):
            self._useAC = val.new
            self.update_AC()

        wgt_useAC.param.watch(callback_useAC, "value")
        return pn.layout.WidgetBox(wgt_useAC, width=150)

    def _spatial_all(self):
        """Build the spatial panels: footprints, isolated activity and movie."""
        metas = self.metas
        Asum = hv.Image(self.Asum.sel(**metas), ["width", "height"]).opts(
            plot=dict(frame_height=len(self._h), frame_width=len(self._w)),
            style=dict(cmap="Viridis"),
        )
        cents = (
            hv.Dataset(
                self.cents_sub.drop(list(self.meta_dicts.keys()), axis="columns"),
                kdims=["width", "height", "unit_id"],
            )
            .to(hv.Points, ["width", "height"])
            .opts(
                style=dict(
                    alpha=0.1,
                    line_alpha=0,
                    size=5,
                    nonselection_alpha=0.1,
                    selection_alpha=0.9,
                )
            )
            .collate()
            .overlay("unit_id")
            .opts(plot=dict(tools=["hover", "box_select"]))
        )
        self.strm_uid.source = cents
        fim = fct.partial(hv.Image, kdims=["width", "height"])
        AC = hv.DynamicMap(fim, streams=[self.pipAC]).opts(
            plot=dict(frame_height=len(self._h), frame_width=len(self._w)),
            style=dict(cmap="Viridis"),
        )
        mov = hv.DynamicMap(fim, streams=[self.pipmov]).opts(
            plot=dict(frame_height=len(self._h), frame_width=len(self._w)),
            style=dict(cmap="Viridis"),
        )
        lab = fct.partial(hv.Labels, kdims=["width", "height"], vdims=["unit_id"])
        ulab = hv.DynamicMap(lab, streams=[self.pipusub]).opts(
            style=dict(text_color="red")
        )
        return pn.panel(Asum * cents + AC * ulab + mov)

    def update_spatial_all(self):
        self.spatial_all.objects = self._spatial_all().objects
class AlignViewer:
    """
    Interactive visualization of cross-registration results.

    Hint
    ----
    .. image:: img/alignviewer.png
        :width: 700px

    This class visualizes the result of cross-registration by color-mapping
    spatial footprints of cells from three selected sessions as red, green and
    blue channel and showing an overlay image. In addition to the overlay
    image, the following tools are available:

    Channel Selector
        Contains "sessionR", "sessionG", and "sessionB" dropdowns, allowing
        the user to select which sessions are colormapped to each channel.
    Display Settings
        Contains the following tools:

        * "erode" dropdown, set window size of an optional erode operation
          applied to the spatial footprints for display to reduce overlaps.
        * "show matched" and "show unmatched" checkboxes, set whether to show
          cells that are matched or not matched across all three selected
          sessions.
    Metadata Selector
        If additional metadata are present, dropdowns corresponding to each
        metadata dimensions will be shown.
    """

    def __init__(
        self,
        minian_ds: xr.Dataset,
        cents: pd.DataFrame,
        mappings: pd.DataFrame,
        shiftds: xr.Dataset,
        brt_offset=0,
    ) -> None:
        """
        Parameters
        ----------
        minian_ds : xr.Dataset
            Input dataset. Should contain `minian_ds["A"]`.
        cents : pd.DataFrame
            Input centroids of cells.
        mappings : pd.DataFrame
            Input mappings of cells.
        shiftds : xr.Dataset
            Input dataset of shift results. Should contain
            `shiftds["shifts"]`.
        brt_offset : int, optional
            Brightness offset added on top of the color-mapped image. Useful
            to make the image visually brighter. By default `0`.
        """
        # init
        self.minian_ds = minian_ds
        self.cents = cents
        self.mappings = mappings
        self.shiftds = shiftds
        self.brt_offset = brt_offset
        A = self.minian_ds["A"]
        self.shifts = rechunk_like(self.shiftds["shifts"], A)
        self.Ash = apply_shifts(A, self.shifts, fill=0)
        # option widgets
        self.erode = 3
        wgt_er = pnwgt.Select(name="erode", options=np.arange(0, 20).tolist(), value=3)
        wgt_er.param.watch(self.cb_update_erd, "value")
        self.show_ma = True
        wgt_ma = pnwgt.Checkbox(name="show matched", value=True)
        wgt_ma.param.watch(self.cb_showma, "value")
        self.show_uma = True
        wgt_uma = pnwgt.Checkbox(name="show unmatched", value=True)
        wgt_uma.param.watch(self.cb_showuma, "value")
        self.wgt_opt = pn.layout.WidgetBox(wgt_er, wgt_ma, wgt_uma)
        self.processA()
        # handling meta
        try:
            # DataFrame.items() replaces iteritems(), removed in pandas 2.0
            self.meta_dict = {
                col: c.unique().tolist() for col, c in mappings["meta"].items()
            }
        except KeyError:
            # mappings has no "meta" column group
            self.meta_dict = None
        if self.meta_dict:
            self.meta = {d: v[0] for d, v in self.meta_dict.items()}
            wgt_meta = [
                pnwgt.Select(name=dim, options=vals)
                for dim, vals in self.meta_dict.items()
            ]
            for w in wgt_meta:
                # bind the widget name now; a plain closure would late-bind
                w.param.watch(lambda v, n=w.name: self.cb_update_meta(n, v), "value")
            self.wgt_meta = pn.layout.WidgetBox(*wgt_meta)
        else:
            self.wgt_meta = None
        self.update_meta()
        # sessionRGB
        sess = list(mappings["session"].columns)
        self.sess_rgb = {"r": sess[0], "g": sess[0], "b": sess[0]}
        wgt_sess = {
            c: pnwgt.Select(name="session{}".format(c.upper()), options=sess)
            for c in ["r", "g", "b"]
        }
        for wname, w in wgt_sess.items():
            w.param.watch(lambda v, n=wname: self.cb_update_rgb(n, v), "value")
        self.wgt_rgb = pn.layout.WidgetBox(*list(wgt_sess.values()))
        self.plot = self.update_plot()

    def processA(self):
        """Optionally erode the shifted footprints, then normalize each one."""
        A = self.Ash
        if self.erode >= 3:
            A = xr.apply_ufunc(
                cv2.erode,
                A,
                input_core_dims=[["height", "width"]],
                output_core_dims=[["height", "width"]],
                vectorize=True,
                dask="parallelized",
                kwargs={"kernel": np.ones((self.erode, self.erode))},
                output_dtypes=[float],
            )
        self.dataA = xr.apply_ufunc(
            norm,
            A,
            input_core_dims=[["height", "width"]],
            output_core_dims=[["height", "width"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[float],
        )

    def update_plot(self):
        """Render the RGB overlay of the three selected sessions as a panel."""
        Adict = {
            c: self.curA.sel(session=self.sess_rgb[c])
            .dropna("unit_id", how="all")
            .compute()
            for c in self.sess_rgb.keys()
        }
        map_sub = self.curmap["session"][list(self.sess_rgb.values())].dropna(how="all")
        # the same session may be chosen for several channels
        map_sub = map_sub.loc[:, ~map_sub.columns.duplicated()]
        ma_mask = map_sub.notnull().all(axis="columns")
        imdict = {
            c: np.zeros((A.sizes["height"], A.sizes["width"]))
            for c, A in Adict.items()
        }
        if self.show_ma:
            ma_map = map_sub.loc[ma_mask]
            for c, im in imdict.items():
                uids = ma_map[self.sess_rgb[c]].values
                imdict[c] = im + Adict[c].sel(unit_id=uids).sum("unit_id").compute()
        if self.show_uma:
            uma_map = map_sub.loc[~ma_mask]
            for c, im in imdict.items():
                uids = uma_map[self.sess_rgb[c]].dropna().values
                imdict[c] = im + Adict[c].sel(unit_id=uids).sum("unit_id").compute()
        cmaps = {
            "r": cc.m_linear_kryw_0_100_c71,
            "g": cc.m_linear_green_5_95_c69,
            "b": cc.m_linear_blue_5_95_c73,
        }
        for c, im in imdict.items():
            imdict[c] = cm.ScalarMappable(cmap=cmaps[c]).to_rgba(im)
        im_ovly = xr.DataArray(
            np.clip(imdict["r"] + imdict["g"] + imdict["b"] + self.brt_offset, 0, 1),
            dims=["height", "width", "rgb"],
            coords={
                "height": self.curA.coords["height"].values,
                "width": self.curA.coords["width"].values,
            },
        )
        im_opts = {
            "frame_height": self.curA.sizes["height"],
            "frame_width": self.curA.sizes["width"],
        }
        return pn.panel(
            hv.RGB(
                (
                    im_ovly.coords["width"],
                    im_ovly.coords["height"],
                    im_ovly[:, :, 0],
                    im_ovly[:, :, 1],
                    im_ovly[:, :, 2],
                    im_ovly[:, :, 3],
                ),
                kdims=["width", "height"],
            ).opts(**im_opts)
        )

    def update_meta(self):
        """Select footprints and mappings for the current metadata values."""
        if self.meta_dict:
            self.curA = self.dataA.sel(**self.meta).persist()
            self.curmap = (
                self.mappings.set_index([("meta", d) for d in self.meta.keys()])
                .loc[tuple(self.meta.values())]
                .reset_index()
            )
        else:
            self.curA = self.dataA.persist()
            self.curmap = self.mappings

    def cb_update_erd(self, val):
        self.erode = val.new
        self.processA()
        self.update_meta()
        self.plot.object = self.update_plot().object

    def cb_update_meta(self, dim, val):
        self.meta[dim] = val.new
        self.update_meta()
        self.plot.object = self.update_plot().object

    def cb_update_rgb(self, ch, ss):
        self.sess_rgb[ch] = ss.new
        self.plot.object = self.update_plot().object

    def cb_showma(self, val):
        self.show_ma = val.new
        self.plot.object = self.update_plot().object

    def cb_showuma(self, val):
        self.show_uma = val.new
        self.plot.object = self.update_plot().object

    def show(self) -> pn.layout.Row:
        """
        Return visualizations that can be directly displayed.

        Returns
        -------
        pn.layout.Row
            Resulting visualizations containing both plots and toolbars.
        """
        return pn.layout.Row(
            self.plot, pn.layout.Column(self.wgt_meta, self.wgt_rgb, self.wgt_opt)
        )
def write_vid_blk(arr, vpath, options):
    """
    Write a block of frames to a uniquely named mp4 file using scikit-video.

    Parameters
    ----------
    arr : np.ndarray
        Frames to write, iterated along the first axis. A 2d array is treated
        as a single frame.
    vpath : str
        Folder in which the video file is created.
    options : dict
        Output options for ffmpeg. Each key is prefixed with "-" and the
        result is passed as `outputdict` to :class:`skvideo.io.FFmpegWriter`.

    Returns
    -------
    fpath : str
        Path of the written file (`vpath` joined with a random "<uuid>.mp4").
    """
    fpath = os.path.join(vpath, "{}.mp4".format(uuid4()))
    if len(arr.shape) == 2:
        # a single frame: add a leading frame axis so the loop below works
        arr = np.expand_dims(arr, axis=0)
    writer = skvideo.io.FFmpegWriter(
        fpath, outputdict={"-" + k: v for k, v in options.items()}
    )
    try:
        for fm in arr:
            writer.writeFrame(fm)
    finally:
        # always release the ffmpeg subprocess, even if a frame fails to write
        writer.close()
    return fpath
def write_video(
    arr: xr.DataArray,
    vname: Optional[str] = None,
    vpath: Optional[str] = ".",
    norm=True,
    options=None,
) -> str:
    """
    Write a video from a movie array using `python-ffmpeg`.

    Parameters
    ----------
    arr : xr.DataArray
        Input movie array. Should have dimensions: ("frame", "height",
        "width") and should only be chunked along the "frame" dimension.
    vname : str, optional
        The name of output video. If `None` then a random one will be
        generated using :func:`uuid4.uuid`. By default `None`.
    vpath : str, optional
        The path to the folder containing the video. By default `"."`.
    norm : bool, optional
        Whether to normalize the values of the input array such that they span
        the full pixel depth range (0, 255). A constant array is mapped to 0
        instead of producing NaN. If `False`, values are only clipped to
        (0, 255). In both cases the data is cast to `uint8` before writing.
        By default `True`.
    options : dict, optional
        Optional output arguments passed to `ffmpeg`. If `None`, then
        `{"crf": "18", "preset": "ultrafast"}` is used. By default `None`.

    Returns
    -------
    fname : str
        Path to the video file, i.e. `vpath` joined with `vname`. This is
        relative if `vpath` is relative.

    See Also
    --------
    ffmpeg.output
    """
    if options is None:
        # avoid a shared mutable default argument
        options = {"crf": "18", "preset": "ultrafast"}
    if not vname:
        vname = "{}.mp4".format(uuid4())
    fname = os.path.join(vpath, vname)
    if norm:
        arr_opt = fct.partial(
            custom_arr_optimize, rename_dict={"rechunk": "merge_restricted"}
        )
        with dask.config.set(array_optimize=arr_opt):
            arr = arr.astype(np.float32)
            arr_max = arr.max().compute().values
            arr_min = arr.min().compute().values
        den = arr_max - arr_min
        arr -= arr_min
        if den > 0:
            # a constant movie would otherwise give 0 / 0 = NaN
            arr /= den
        arr *= 255
    arr = arr.clip(0, 255).astype(np.uint8)
    w, h = arr.sizes["width"], arr.sizes["height"]
    process = (
        ffmpeg.input("pipe:", format="rawvideo", pix_fmt="gray", s="{}x{}".format(w, h))
        # yuv420p requires even frame dimensions, so pad up to the next even size
        .filter("pad", int(np.ceil(w / 2) * 2), int(np.ceil(h / 2) * 2))
        .output(fname, pix_fmt="yuv420p", vcodec="libx264", r=30, **options)
        .overwrite_output()
        .run_async(pipe_stdin=True)
    )
    try:
        # stream one dask block (a run of frames) at a time as raw gray bytes
        for blk in arr.data.blocks:
            process.stdin.write(np.array(blk).tobytes())
    finally:
        # close the pipe and reap ffmpeg even if computing a block fails
        process.stdin.close()
        process.wait()
    return fname
def concat_video_recursive(vlist, vname=None):
    """
    Concatenate video files into a single file, deleting the inputs.

    With more than 256 inputs, they are first split into 256 groups. Each
    group is concatenated recursively into a randomly named file, and those
    files are then merged.

    Parameters
    ----------
    vlist : sequence of str
        Paths of the videos to concatenate, in order. A single path is
        returned unchanged (and is not renamed or deleted).
    vname : str, optional
        File name of the output, placed in the folder of the first input. If
        `None`, a random "<uuid>.mp4" name is used.

    Returns
    -------
    str
        Path of the concatenated video.
    """
    n_vids = len(vlist)
    if n_vids <= 1:
        return vlist[0]
    if n_vids > 256:
        # too many inputs for one ffmpeg call: merge in groups first
        groups = np.array_split(vlist, 256)
        vlist = [concat_video_recursive(list(grp)) for grp in groups]
    out_dir = os.path.dirname(vlist[0])
    out_name = vname if vname is not None else "{}.mp4".format(uuid4())
    out_path = os.path.join(out_dir, out_name)
    inputs = [ffmpeg.input(src) for src in vlist]
    ffmpeg.concat(*inputs).output(out_path).run(overwrite_output=True)
    for src in vlist:
        os.remove(src)
    return out_path
def generate_videos(
    varr: xr.DataArray,
    Y: xr.DataArray,
    A: Optional[xr.DataArray] = None,
    C: Optional[xr.DataArray] = None,
    AC: Optional[xr.DataArray] = None,
    nfm_norm: Optional[int] = None,
    gain=1.5,
    vpath=".",
    vname="minian.mp4",
    options=None,
) -> str:
    """
    Generate a video visualizing the result of minian pipeline.

    The resulting video contains four parts:
    - Top left is the original reference movie supplied as `varr`.
    - Top right is the input to the CNMF algorithm supplied as `Y`.
    - Bottom right is a movie `AC` representing cellular activities as
      computed by :func:`minian.cnmf.compute_AtC`.
    - Bottom left is a residual movie computed as the difference between `Y`
      and `AC`.

    The CNMF algorithm contains various arbitrary scaling processes. To make
    the numerical values of `Y` and `AC` match, a normalizing scalar can be
    computed with least squares using a subset of frames.

    Parameters
    ----------
    varr : xr.DataArray
        Input reference movie data. Should have dimensions ("frame",
        "height", "width"), and should only be chunked along "frame"
        dimension.
    Y : xr.DataArray
        Movie data representing input to CNMF algorithm. Should have
        dimensions ("frame", "height", "width"), and should only be chunked
        along "frame" dimension.
    A : xr.DataArray, optional
        Spatial footprints of cells. Only used if `AC` is `None`. By default
        `None`.
    C : xr.DataArray, optional
        Temporal activities of cells. Only used if `AC` is `None`. By default
        `None`.
    AC : xr.DataArray, optional
        Spatial-temporal activities of cells. Should have dimensions
        ("frame", "height", "width"), and should only be chunked along
        "frame" dimension. If `None` then both `A` and `C` should be supplied
        and :func:`minian.cnmf.compute_AtC` will be used to compute this
        variable. By default `None`.
    nfm_norm : int, optional
        Number of frames to randomly draw from `Y` and `AC` to compute the
        normalizing factor with least squares. Values larger than the number
        of frames are clipped to the number of frames. If `None`, `AC` is
        scaled by the same factor as `Y`. By default `None`.
    gain : float, optional
        A gain factor multiplied to `Y`. `Y` is first rescaled so that its
        maximum maps to 255, then multiplied by `gain`. Useful to make the
        results visually brighter. By default `1.5`.
    vpath : str, optional
        Desired folder containing the resulting video. By default `"."`.
    vname : str, optional
        Desired name of the video. By default `"minian.mp4"`.
    options : dict, optional
        Output options for `ffmpeg`, passed directly to :func:`write_video`.
        If `None` then `{"crf": "18", "preset": "ultrafast"}` is used. By
        default `None`.

    Returns
    -------
    fname : str
        Path of the resulting video, as returned by :func:`write_video`.

    Raises
    ------
    ValueError
        If `AC` is `None` and either `A` or `C` is missing.
    """
    if options is None:
        # avoid a shared mutable default argument
        options = {"crf": "18", "preset": "ultrafast"}
    if AC is None:
        if A is None or C is None:
            raise ValueError("Either `AC` or both `A` and `C` must be supplied.")
        print("generating traces")
        AC = compute_AtC(A, C)
    print("normalizing")
    # rescale Y so that its maximum maps to 255 * gain
    scale = 255 / Y.max().compute().values * gain
    Y = Y * scale
    if nfm_norm is not None:
        # import the subpackage explicitly; `scipy.sparse` does not guarantee
        # that `scipy.sparse.linalg` is loaded as an attribute
        from scipy.sparse.linalg import lsqr

        nfm = Y.sizes["frame"]
        norm_idx = np.sort(
            np.random.choice(np.arange(nfm), size=min(nfm_norm, nfm), replace=False)
        )
        Y_sub = Y.isel(frame=norm_idx).values.reshape(-1)
        AC_sub = scisps.csc_matrix(AC.isel(frame=norm_idx).values.reshape((-1, 1)))
        # single-column least squares: the scalar that best maps AC onto Y
        norm_factor = lsqr(AC_sub, Y_sub)[0].item()
        del Y_sub, AC_sub
    else:
        # no fit requested: scale AC the same way Y was scaled
        norm_factor = scale
    AC = AC * norm_factor
    res = Y - AC
    print("writing videos")
    vid = xr.concat(
        [
            xr.concat([varr, Y], "width", coords="minimal"),
            xr.concat([res, AC], "width", coords="minimal"),
        ],
        "height",
        coords="minimal",
    )
    return write_video(vid, vname, vpath, norm=False, options=options)
def datashade_ndcurve(
    ovly: hv.NdOverlay, kdim: Optional[Union[str, List[str]]] = None, spread=False
) -> hv.Overlay:
    """
    Apply datashading to an overlay of curves with legends.

    Each unique value of `kdim` gets a color from the 10-color
    `Category10_10` palette. With more than 10 categories the palette is
    cycled. Invisible-size points at the origin are overlaid so that a legend
    with those colors is shown.

    Parameters
    ----------
    ovly : hv.NdOverlay
        The input overlay of curves.
    kdim : Union[str, List[str]], optional
        Key dimensions of the overlay. If `None` then the first key dimension
        of `ovly` will be used. By default `None`.
    spread : bool, optional
        Whether to apply :func:`holoviews.operation.datashader.dynspread` to
        the result. By default `False`.

    Returns
    -------
    hvres : hv.Overlay
        Resulting overlay of datashaded curves and points (for legends).
    """
    if not kdim:
        kdim = ovly.kdims[0].name
    var = np.unique(ovly.dimension_values(kdim)).tolist()
    # cycle through the palette so more than 10 categories don't IndexError
    ncolor = len(Category10_10)
    color_key = [(v, Category10_10[iv % ncolor]) for iv, v in enumerate(var)]
    color_pts = hv.NdOverlay(
        {
            k: hv.Points([0, 0], label=str(k)).opts(style=dict(color=v))
            for k, v in color_key
        }
    )
    ds_ovly = datashade(
        ovly,
        aggregator=count_cat(kdim),
        color_key=dict(color_key),
        min_alpha=200,
        normalization="linear",
    )
    if spread:
        ds_ovly = dynspread(ds_ovly)
    return ds_ovly * color_pts
def construct_G(g: np.ndarray, T: int) -> np.ndarray:
    """
    Construct a convolving matrix from AR coefficients.

    The result is a lower-triangular Toeplitz matrix with ones on the
    diagonal and `-g[i]` on the `i + 1`-th sub-diagonal. Only the first
    `T - 1` coefficients fit in a `T` x `T` matrix; any further ones are
    ignored.

    Parameters
    ----------
    g : np.ndarray
        Input AR coefficients. A scalar is treated as a single coefficient.
    T : int
        Number of time samples of the AR process.

    Returns
    -------
    G : np.ndarray
        A `T` x `T` matrix that can be used to multiply with a timeseries to
        convolve the AR process.

    See Also
    --------
    minian.cnmf.update_temporal : for more background on the role of AR
        process in the pipeline
    """
    g = np.atleast_1d(g)
    cur_c, cur_r = np.zeros(T), np.zeros(T)
    cur_c[0] = 1
    cur_r[0] = 1
    # only T - 1 sub-diagonals exist: truncate g instead of failing to broadcast
    nlag = min(len(g), T - 1)
    cur_c[1 : nlag + 1] = -g[:nlag]
    return linalg.toeplitz(cur_c, cur_r)
def normalize(a: np.ndarray) -> np.ndarray:
    """
    Rescale an array linearly onto the range (0, 1) via :func:`numpy.interp`.

    The smallest non-NaN value of `a` is mapped to 0 and the largest to 1.
    Values in between are interpolated linearly.

    Parameters
    ----------
    a : np.ndarray
        Array to rescale.

    Returns
    -------
    a_norm : np.ndarray
        Rescaled array.
    """
    src_range = (np.nanmin(a), np.nanmax(a))
    dst_range = (0, +1)
    return np.interp(a, src_range, dst_range)
def norm(a: np.ndarray) -> np.ndarray:
    """
    Rescale an array onto the range (0, 1) without dividing by zero.

    NaN values are ignored when determining the minimum and maximum.

    Parameters
    ----------
    a : np.ndarray
        Array to rescale.

    Returns
    -------
    a_norm : np.ndarray
        Rescaled array. If the range of `a` is not positive (e.g. every
        value is identical), `a` itself is returned unchanged.
    """
    hi = np.nanmax(a)
    lo = np.nanmin(a)
    span = hi - lo
    if span > 0:
        return (a - lo) / span
    # degenerate range: nothing to rescale
    return a
def convolve_G(s: np.ndarray, g: np.ndarray) -> np.ndarray:
    """
    Convolve an AR process to input timeseries.

    Despite the name, only AR coefficients are needed as input. The
    convolving matrix `G` is computed with :func:`construct_G`, and the
    result is the solution `c` of `G @ c = s`.

    Parameters
    ----------
    s : np.ndarray
        The input timeseries, presumably representing spike signals.
    g : np.ndarray
        The AR coefficients.

    Returns
    -------
    c : np.ndarray
        Convolved timeseries, presumably representing calcium dynamics. If
        `G` is singular, a copy of `s` is returned instead.

    See Also
    --------
    minian.cnmf.update_temporal : for more background on the role of AR
        process in the pipeline
    """
    G = construct_G(g, len(s))
    try:
        # G is lower triangular, so a triangular solve gives G^-1 s without
        # forming the explicit inverse (O(T^2) instead of O(T^3)).
        # check_finite=False keeps the old behavior of propagating NaN.
        c = linalg.solve_triangular(G, s, lower=True, check_finite=False)
    except np.linalg.LinAlgError:
        c = s.copy()
    return c
def construct_pulse_response(
    g: np.ndarray, length=500
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Construct a model pulse response corresponding to certain AR coefficients.

    Parameters
    ----------
    g : np.ndarray
        The AR coefficients.
    length : int, optional
        Number of timepoints in output. By default `500`.

    Returns
    -------
    s : np.ndarray
        Model spike with shape `(length,)`, zero everywhere except the first
        timepoint.
    c : np.ndarray
        Model convolved calcium response, with same shape as `s`.

    See Also
    --------
    minian.cnmf.update_temporal : for more background on the role of AR
        process in the pipeline
    """
    s = np.zeros(length)
    # a single unit pulse at t = 0; a fixed stride of 500 would add extra
    # spikes whenever length > 500, contradicting the documented contract
    s[0] = 1
    c = convolve_G(s, g)
    return s, c
def centroid(A: xr.DataArray, verbose=False) -> pd.DataFrame:
    """
    Compute centroids of spatial footprint of each cell.

    The center of mass of each footprint is computed in (fractional) pixel
    indices. Those indices are then mapped onto the "height" / "width"
    coordinates of `A` by linear interpolation. Footprints that are entirely
    NaN are dropped.

    Parameters
    ----------
    A : xr.DataArray
        Input spatial footprints.
    verbose : bool, optional
        Whether to print message and progress bar. By default `False`.

    Returns
    -------
    cents_df : pd.DataFrame
        Centroid of spatial footprints for each cell. Has columns "unit_id",
        "height", "width" and any other additional metadata dimension.
    """

    def idx_cent(im):
        # center of mass of one footprint, in fractional pixel-index units
        im_nan = np.isnan(im)
        if im_nan.all():
            return np.array([np.nan, np.nan])
        if im_nan.any():
            im = np.nan_to_num(im)
        return np.array(center_of_mass(im))

    gu_idx_cent = da.gufunc(
        idx_cent,
        signature="(h,w)->(d)",
        output_dtypes=float,
        output_sizes=dict(d=2),
        vectorize=True,
    )
    cents = xr.apply_ufunc(
        gu_idx_cent,
        A.chunk(dict(height=-1, width=-1)),
        input_core_dims=[["height", "width"]],
        output_core_dims=[["dim"]],
        dask="allowed",
    ).assign_coords(dim=["height", "width"])
    if verbose:
        print("computing centroids")
        with ProgressBar():
            cents = cents.compute()
    cents_df = (
        cents.rename("cents")
        .to_series()
        .dropna()
        .unstack("dim")
        .rename_axis(None, axis="columns")
        .reset_index()
    )
    # Map pixel index i onto coords[i]. This is exact for evenly spaced
    # coordinates in either direction. The previous `cent / shape * (max -
    # min) + min` shrank positions by (n - 1) / n and assumed increasing
    # coordinates.
    for d in ["height", "width"]:
        crd = A.coords[d].values
        cents_df[d] = np.interp(cents_df[d].values, np.arange(len(crd)), crd)
    return cents_df
def visualize_preprocess(
    fm: xr.DataArray, fn: Optional[Callable] = None, include_org=True, **kwargs
) -> hv.HoloMap:
    """
    Generalized visualization of preprocessing functions.

    This function facilitates parameter exploration of preprocessing
    functions. It plots a single frame before and after the application of
    the function, along with a contour plot. All keyword arguments not listed
    below are passed directly to `fn`. Each keyword argument must be an
    iterable of candidate values. `fn` is evaluated on every combination of
    those values (cartesian product).

    Parameters
    ----------
    fm : xr.DataArray
        The input frame.
    fn : Callable, optional
        The function to apply. If `None` then the original frame is
        visualized unchanged. By default `None`.
    include_org : bool, optional
        Whether to include the original frame in the visualization. By
        default `True`.

    Returns
    -------
    hvres : hv.HoloMap
        The resulting visualization containing images and contour plots.

    See Also
    --------
    minian.preprocessing
    """
    fh, fw = fm.sizes["height"], fm.sizes["width"]
    asp = fw / fh
    opts_im = {
        "plot": {
            "frame_width": 500,
            "aspect": asp,
            "title": "Image {label} {group} {dimensions}",
        },
        "style": {"cmap": "viridis"},
    }
    opts_cnt = {
        "plot": {
            "frame_width": 500,
            "aspect": asp,
            "title": "Contours {label} {group} {dimensions}",
        },
        "style": {"cmap": "viridis"},
    }

    def _vis(f):
        # image of a frame plus its contour plot
        im = hv.Image(f, kdims=["width", "height"]).opts(**opts_im)
        cnt = hv.operation.contours(im).opts(**opts_cnt)
        return im, cnt

    if fn is None:
        im, cnt = _vis(fm)
        im = im.relabel("Before")
        cnt = cnt.relabel("Before")
        return im + cnt
    pkey = list(kwargs.keys())
    im_dict = dict()
    cnt_dict = dict()
    for params in itt.product(*kwargs.values()):
        fm_res = fn(fm, **dict(zip(pkey, params)))
        cur_im, cur_cnt = _vis(fm_res)
        # non-numeric parameter values are stringified to form the HoloMap key
        p_str = tuple(p if isinstance(p, (int, float)) else str(p) for p in params)
        im_dict[p_str] = cur_im.relabel("After")
        cnt_dict[p_str] = cur_cnt.relabel("After")
    hv_im = Dynamic(hv.HoloMap(im_dict, kdims=pkey).opts(**opts_im))
    hv_cnt = datashade(
        hv.HoloMap(cnt_dict, kdims=pkey), precompute=True, cmap=Viridis256
    ).opts(**opts_cnt)
    if not include_org:
        # the original frame is not requested: return only the processed plots
        # (previously this path referenced unbound `im` / `cnt`)
        return (hv_im + hv_cnt).cols(2)
    im, cnt = _vis(fm)
    im = im.relabel("Before").opts(**opts_im)
    cnt = (
        datashade(cnt, precompute=True, cmap=Viridis256)
        .relabel("Before")
        .opts(**opts_cnt)
    )
    return (im + cnt + hv_im + hv_cnt).cols(2)
def visualize_seeds(
    max_proj: xr.DataArray, seeds: pd.DataFrame, mask: Optional[str] = None
) -> hv.Overlay:
    """
    Plot seeds as points on top of a max projection.

    When `mask` names a boolean column of `seeds`, points are colored by that
    column: white where it is `True` and red where it is `False`. This is
    useful for inspecting which seeds a refining step filters out. Without a
    mask every seed is drawn in white. Point size follows the "seeds" column.

    Parameters
    ----------
    max_proj : xr.DataArray
        Max projection used as the background of the plot.
    seeds : pd.DataFrame
        The seed dataframe.
    mask : str, optional
        Name of the boolean column of `seeds` to color by. By default `None`.

    Returns
    -------
    hvres : hv.Overlay
        Overlay of the seed points on the max projection.

    See Also
    --------
    minian.initialization
    """
    aspect = max_proj.sizes["width"] / max_proj.sizes["height"]
    pts_style = dict(
        fill_alpha=0.8, line_alpha=0, cmap={True: "white", False: "red"}
    )
    if mask:
        pts_vdims = ["seeds", mask]
    else:
        pts_vdims = ["seeds"]
        pts_style["color"] = "white"
    pts_plot = dict(
        frame_width=600,
        aspect=aspect,
        size_index="seeds",
        color_index=mask,
        tools=["hover"],
    )
    background = hv.Image(max_proj, kdims=["width", "height"]).opts(
        plot=dict(frame_width=600, aspect=aspect), style=dict(cmap="Viridis")
    )
    points = hv.Points(seeds, kdims=["width", "height"], vdims=pts_vdims).opts(
        plot=pts_plot, style=pts_style
    )
    return background * points
def visualize_gmm_fit(
    values: np.ndarray, gmm: sklearn.mixture.GaussianMixture, bins: int
) -> hv.Overlay:
    """
    Visualization of the Gaussian mixture model fit.

    This function visualizes a GMM fit by plotting the fitted gaussian curves
    on top of the histogram of values. The histogram is rescaled to (0, 1)
    and each gaussian is unnormalized (peak value 1), so that their shapes
    can be compared.

    Parameters
    ----------
    values : np.ndarray
        The raw values to which GMM is fitted.
    gmm : sklearn.mixture.GaussianMixture
        The fitted GMM model object. Each component's mean and covariance
        must contain a single value (i.e. a 1d fit).
    bins : int
        Number of bins when plotting the histogram.

    Returns
    -------
    hvres : hv.Overlay
        The resulting visualization.

    See Also
    --------
    minian.initialization.gmm_refine
    """

    def gaussian(x, mu, sig):
        # unnormalized gaussian with peak value 1
        return np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))

    hist = np.histogram(values, bins=bins, density=True)
    gss_dict = dict()
    for igss, (mu, sig) in enumerate(zip(gmm.means_, gmm.covariances_)):
        # np.asscalar was removed in NumPy 1.23; .item() extracts the single
        # value from a size-1 array (or numpy scalar) the same way
        gss = gaussian(hist[1], np.asarray(mu).item(), np.sqrt(sig).item())
        gss_dict[igss] = hv.Curve((hist[1], gss))
    return (
        hv.Histogram(((hist[0] - hist[0].min()) / np.ptp(hist[0]), hist[1])).opts(
            style=dict(alpha=0.6, fill_color="gray")
        )
        * hv.NdOverlay(gss_dict)
    ).opts(plot=dict(height=350, width=500))
def visualize_spatial_update(
    A_dict: dict,
    C_dict: dict,
    kdims: Optional[Union[str, List[str]]] = None,
    norm=True,
    datashading=True,
) -> hv.HoloMap:
    """
    Visualization of spatial update.

    This function facilitates parameter exploration for spatial update. For
    each run it plots:
    - the resulting spatial footprints (summed over cells),
    - the binarized spatial footprints (count of cells with non-zero weight),
    - the centroid of every cell,
    - the corresponding temporal activities of the cells.

    Parameters
    ----------
    A_dict : dict
        A dictionary containing resulting spatial footprints from different
        runs of spatial update. Keys should be tuples containing the values of
        parameters that uniquely identify each run. Values should be spatial
        footprints of type `xr.DataArray`.
    C_dict : dict
        A dictionary containing temporal activities of each cell in the same
        format as `A_dict`. The temporal activities of cells are not expected
        to change across different runs of spatial update, except the number
        of cells may be different due to dropping of cells in the update
        process.
    kdims : Union[str, List[str]], optional
        Names of key dimensions identifying the parameter space. Should have
        same length as the keys in `A_dict` and `C_dict`. If `None`, then
        `A_dict` and `C_dict` are treated as a single run: each is wrapped
        under the key "dummy", so they should be passed directly as an
        `xr.DataArray` rather than a dictionary. This mode can be used to
        visualize results across cells. By default `None`.
    norm : bool, optional
        Whether to normalize the temporal activities of each cell to range
        (0, 1) for visualization. By default `True`.
    datashading : bool, optional
        Whether to apply datashading to temporal activities of cells. By
        default `True`.

    Returns
    -------
    hvres : hv.HoloMap
        Resulting visualization.

    See Also
    --------
    minian.cnmf.update_spatial
    """
    # NOTE: inside this function the `norm` parameter shadows the module-level
    # `norm()` helper; normalization below uses `normalize` instead.
    if not kdims:
        # single run: wrap the bare arrays so the loop below is uniform
        A_dict = dict(dummy=A_dict)
        C_dict = dict(dummy=C_dict)
    hv_pts_dict, hv_A_dict, hv_Ab_dict, hv_C_dict = (dict(), dict(), dict(), dict())
    for key, A in A_dict.items():
        A = A.compute()
        C = C_dict[key]
        if norm:
            # rescale each cell's trace to (0, 1) along "frame"
            C = xr.apply_ufunc(
                normalize,
                C.chunk(dict(frame=-1)),
                input_core_dims=[["frame"]],
                output_core_dims=[["frame"]],
                vectorize=True,
                dask="parallelized",
                output_dtypes=[C.dtype],
            )
        C = C.compute()
        # frame size of the last run is used for the aspect ratio below
        h, w = A.sizes["height"], A.sizes["width"]
        cents_df = centroid(A)
        hv_pts_dict[key] = hv.Points(
            cents_df, kdims=["width", "height"], vdims=["unit_id"]
        ).opts(
            plot=dict(tools=["hover"]), style=dict(fill_alpha=0.2, line_alpha=0, size=8)
        )
        hv_A_dict[key] = hv.Image(
            A.sum("unit_id").rename("A"), kdims=["width", "height"]
        )
        # number of cells with a non-zero weight at each pixel
        hv_Ab_dict[key] = hv.Image(
            (A > 0).sum("unit_id").rename("A_bin"), kdims=["width", "height"]
        )
        hv_C_dict[key] = hv.Dataset(C.rename("C")).to(hv.Curve, kdims="frame")
    hv_pts = Dynamic(hv.HoloMap(hv_pts_dict, kdims=kdims))
    hv_A = Dynamic(hv.HoloMap(hv_A_dict, kdims=kdims))
    hv_Ab = Dynamic(hv.HoloMap(hv_Ab_dict, kdims=kdims))
    # one curve panel per unit_id, laid out as a grid
    hv_C = (
        hv.HoloMap(hv_C_dict, kdims=kdims)
        .collate()
        .grid("unit_id")
        .add_dimension("time", 0, 0)
    )
    if datashading:
        hv_C = datashade(hv_C)
    else:
        hv_C = Dynamic(hv_C)
    hv_A = hv_A.opts(frame_width=400, aspect=w / h, colorbar=True, cmap="viridis")
    hv_Ab = hv_Ab.opts(frame_width=400, aspect=w / h, colorbar=True, cmap="viridis")
    # datashaded curves become hv.RGB elements, so size whichever type is present
    hv_C = hv_C.map(
        lambda cr: cr.opts(frame_width=500, frame_height=50),
        hv.RGB if datashading else hv.Curve,
    )
    return (
        hv.NdLayout(
            {"pseudo-color": (hv_pts * hv_A), "binary": (hv_pts * hv_Ab)},
            kdims="Spatial Matrix",
        ).cols(1)
        + hv_C.relabel("Temporal Components")
    )
def visualize_temporal_update(
    YA_dict: dict,
    C_dict: dict,
    S_dict: dict,
    g_dict: dict,
    sig_dict: dict,
    A_dict: dict,
    kdims: Optional[Union[str, List[str]]] = None,
    norm=True,
    datashading=True,
) -> hv.HoloMap:
    """
    Visualization of temporal update.

    This function facilitates parameter exploration for temporal update. For
    each cell across different runs of temporal update, it plots various
    temporal traces along with a model calcium response and the spatial
    footprint. Four traces are plotted:
    - "Raw Signal" corresponds to the `YrA` variable.
    - "Fitted Calcium Trace" corresponds to `C` after update.
    - "Fitted Spikes" corresponds to `S` after update.
    - "Fitted Signal" corresponds to `C + b0 + c0` after update.

    See :func:`minian.cnmf.update_temporal` for interpretation of each
    variable.

    Parameters
    ----------
    YA_dict : dict
        A dictionary containing the `YrA` variables in the same format as
        `C_dict`. The `YrA` variable is not updated and is not expected to be
        different across different runs of temporal update.
    C_dict : dict
        A dictionary containing resulting calcium traces (`C_new`) from
        different runs of temporal update. Keys should be tuples containing
        the values of parameters that uniquely identify each run. Values
        should be temporal traces of type `xr.DataArray`.
    S_dict : dict
        A dictionary containing resulting deconvolved spike traces (`S_new`)
        from different runs of temporal update, in the same format as
        `C_dict`.
    g_dict : dict
        A dictionary containing resulting AR coefficients (`g`) from
        different runs of temporal update, in the same format as `C_dict`.
    sig_dict : dict
        A dictionary containing resulting fitted signals (`C_new + b0_new +
        c0_new`) from different runs of temporal update, in the same format
        as `C_dict`.
    A_dict : dict
        A dictionary containing spatial footprints of cells in the same format
        as `C_dict`. The spatial footprints of cells are not expected to
        change across different runs of temporal update, except the number of
        cells may be different due to dropping of cells in the update process.
    kdims : Union[str, List[str]], optional
        Names of key dimensions identifying the parameter space. Should have
        same length as the keys in `C_dict` etc. If `None`, then every input
        is treated as a single run and wrapped under the key "dummy", so
        they should be passed directly as `xr.DataArray` rather than
        dictionaries. This mode can be used to visualize results across
        cells. By default `None`.
    norm : bool, optional
        Whether to normalize the temporal activities of each cell to range
        (0, 1) for visualization. By default `True`.
    datashading : bool, optional
        Whether to apply datashading to temporal activities of cells. By
        default `True`.

    Returns
    -------
    hvres : hv.HoloMap
        Resulting visualization.

    See Also
    --------
    minian.cnmf.update_temporal
    """
    # NOTE: inside this function the `norm` parameter shadows the module-level
    # `norm()` helper; normalization below uses `normalize` instead.
    # Order matters: `g` is last so `ins[:-1]` below skips it when normalizing.
    inputs = [YA_dict, C_dict, S_dict, sig_dict, g_dict]
    if not kdims:
        # single run: wrap the bare arrays so the loop below is uniform
        inputs = [dict(dummy=i) for i in inputs]
        A_dict = dict(dummy=A_dict)
    # run key -> [YrA, C, S, sig, g] for that run
    input_dict = {k: [i[k] for i in inputs] for k in inputs[0].keys()}
    hv_YA, hv_C, hv_S, hv_sig, hv_C_pul, hv_S_pul, hv_A = [dict() for _ in range(7)]
    for k, ins in input_dict.items():
        if norm:
            # rescale every trace (not the AR coefficients) to (0, 1)
            ins[:-1] = [
                xr.apply_ufunc(
                    normalize,
                    i.chunk(dict(frame=-1)),
                    input_core_dims=[["frame"]],
                    output_core_dims=[["frame"]],
                    vectorize=True,
                    dask="parallelized",
                    output_dtypes=[i.dtype],
                )
                for i in ins[:-1]
            ]
        ins[:] = [i.compute() for i in ins]
        ya, c, s, sig, g = ins
        f_crd = ya.coords["frame"]
        # the simulated pulse response spans at most the first 500 frames
        pul_crd = f_crd.values[:500]
        s_pul, c_pul = xr.apply_ufunc(
            construct_pulse_response,
            g,
            input_core_dims=[["lag"]],
            output_core_dims=[["t"], ["t"]],
            vectorize=True,
            kwargs=dict(length=len(pul_crd)),
            output_sizes=dict(t=len(pul_crd)),
        )
        s_pul, c_pul = (s_pul.assign_coords(t=pul_crd), c_pul.assign_coords(t=pul_crd))
        if norm:
            c_pul = xr.apply_ufunc(
                normalize,
                c_pul.chunk(dict(t=-1)),
                input_core_dims=[["t"]],
                output_core_dims=[["t"]],
                dask="parallelized",
                output_dtypes=[c_pul.dtype],
            ).compute()
        # initial display range of the pulse plot: from the first frame to
        # the midpoint of the frame range (last run's value is used below)
        pul_range = (
            f_crd.min(),
            int(np.around(f_crd.min() + (f_crd.max() - f_crd.min()) / 2)),
        )
        hv_S_pul[k], hv_C_pul[k] = [
            (hv.Dataset(tr.rename("Response (A.U.)")).to(hv.Curve, kdims=["t"]))
            for tr in [s_pul, c_pul]
        ]
        hv_YA[k] = hv.Dataset(ya.rename("Intensity (A.U.)")).to(
            hv.Curve, kdims=["frame"]
        )
        # runs that dropped every cell contribute no fitted traces
        if c.sizes["unit_id"] > 0:
            hv_C[k], hv_S[k], hv_sig[k] = [
                (
                    hv.Dataset(tr.rename("Intensity (A.U.)")).to(
                        hv.Curve, kdims=["frame"]
                    )
                )
                for tr in [c, s, sig]
            ]
        hv_A[k] = hv.Dataset(A_dict[k].rename("A")).to(
            hv.Image, kdims=["width", "height"]
        )
        # frame size of the last run is used for the aspect ratio below
        h, w = A_dict[k].sizes["height"], A_dict[k].sizes["width"]
    hvobjs = [hv_YA, hv_C, hv_S, hv_sig, hv_C_pul, hv_S_pul, hv_A]
    hvobjs[:] = [hv.HoloMap(hvobj, kdims=kdims).collate() for hvobj in hvobjs]
    hv_unit = {
        "Raw Signal": hvobjs[0],
        "Fitted Calcium Trace": hvobjs[1],
        "Fitted Spikes": hvobjs[2],
        "Fitted Signal": hvobjs[3],
    }
    hv_pul = {"Simulated Calcium": hvobjs[4], "Simulated Spike": hvobjs[5]}
    # overlay the traces of each kind into one plot per unit
    hv_unit = hv.HoloMap(hv_unit, kdims="traces").collate().overlay("traces")
    hv_pul = hv.HoloMap(hv_pul, kdims="traces").collate().overlay("traces")
    hv_A = Dynamic(hvobjs[6])
    if datashading:
        hv_unit = datashade_ndcurve(hv_unit, "traces")
    else:
        hv_unit = Dynamic(hv_unit)
    hv_pul = Dynamic(hv_pul)
    hv_unit = hv_unit.map(
        lambda p: p.opts(plot=dict(frame_height=400, frame_width=1000))
    )
    hv_pul = hv_pul.opts(plot=dict(frame_width=500, aspect=w / h)).redim(
        t=hv.Dimension("t", soft_range=pul_range)
    )
    hv_A = hv_A.opts(
        plot=dict(frame_width=500, aspect=w / h), style=dict(cmap="Viridis")
    )
    return (
        hv_unit.relabel("Current Unit: Temporal Traces")
        + hv.NdLayout(
            {"Simulated Pulse Response": hv_pul, "Spatial Footprint": hv_A},
            kdims="Current Unit",
        )
    ).cols(1)
def NNsort(cents: pd.DataFrame) -> pd.Series:
    """
    Sort centroids of cells into close-by groups.

    Walk through the centroids of cells using a nearest-neighbor tree, so
    that the resulting walk order can be used to sort cells into close-by
    groups. The walk starts at the centroid with the smallest
    `height + width`. At each step it moves to the nearest centroid not yet
    visited. The neighbor search widens (2, 4, 8, ... neighbors) until an
    unvisited centroid is found.

    Parameters
    ----------
    cents : pd.DataFrame
        Input centroids of cells. Should contain column "height" and "width".

    Returns
    -------
    result : pd.Series
        A series with same index as input `cents` whose values represent the
        order of nearest-neighbor walk. Empty if `cents` is empty.
    """
    result = pd.Series(0, index=cents.index)
    if len(cents) == 0:
        return result
    cents_hw = cents[["height", "width"]]
    kdtree = cKDTree(cents_hw)
    labels = cents.index
    # smallest kmax such that querying 2 ** kmax neighbors covers every point
    kmax = int(np.ceil(np.log2(len(result))))
    # a set gives O(1) membership/removal (a list made the walk O(n^2))
    remaining = set(labels)
    idu_next = cents_hw.sum(axis="columns").idxmin()
    NNord = 0
    while remaining:
        result.loc[idu_next] = NNord
        remaining.remove(idu_next)
        if not remaining:
            break
        for k in range(1, kmax + 1):
            dist, pos = kdtree.query(cents_hw.loc[idu_next], 2 ** k)
            # missing neighbors (k > n) come back with infinite distance
            candidates = labels[pos[np.isfinite(dist)]]
            # candidates are sorted by distance: take the first unvisited one
            unvisited = [lbl for lbl in candidates if lbl in remaining]
            if unvisited:
                idu_next = unvisited[0]
                NNord += 1
                break
    return result
def visualize_motion(motion: xr.DataArray) -> Union[hv.Layout, hv.NdOverlay]:
    """
    Visualize result of motion estimation.

    A 2d input is read as rigid shifts. Its "width" and "height" entries
    along `shift_dim` are drawn as two curves across time.

    An input with more than two dimensions is read as non-rigid motion: every
    frame has a grid of patches ("grid0" x "grid1"), each with its own shift.
    The patch grid is flattened into one "grid" axis. Shifts along "height"
    and "width" are then shown as two separate images across time, where
    columns are frames and color encodes the size of the shift.

    Parameters
    ----------
    motion : xr.DataArray
        Estimated motion.

    Returns
    -------
    Union[hv.Layout, hv.NdOverlay]
        An overlay of two curves for rigid shifts, otherwise a single-column
        layout of two images.
    """
    if motion.ndim <= 2:
        cv_opts = {"frame_width": 500, "tools": ["hover"], "aspect": 2}
        curves = {
            dim: hv.Curve(motion.sel(shift_dim=dim)).opts(**cv_opts)
            for dim in ("width", "height")
        }
        return hv.NdOverlay(curves)

    im_opts = {
        "frame_width": 500,
        "aspect": 3,
        "cmap": "RdBu",
        "symmetric": True,
        "colorbar": True,
    }

    def _motion_image(dim, name):
        # flatten the patch grid into one integer-labelled "grid" axis
        flat = motion.sel(shift_dim=dim).stack(grid=["grid0", "grid1"])
        flat = flat.assign_coords(grid=np.arange(flat.sizes["grid"]))
        return hv.Image(flat.rename(name), kdims=["frame", "grid"]).opts(
            title=name, **im_opts
        )

    layout = _motion_image("height", "height_motion") + _motion_image(
        "width", "width_motion"
    )
    return layout.cols(1).opts(show_title=True)