import functools as fct
import os
import re
import shutil
import warnings
from copy import deepcopy
from os import listdir
from os.path import isdir, isfile
from os.path import join as pjoin
from pathlib import Path
from typing import Callable, List, Optional, Union
from uuid import uuid4
import _operator
import cv2
import dask as da
import dask.array as darr
import ffmpeg
import numpy as np
import pandas as pd
import rechunker
import xarray as xr
import zarr as zr
from dask.core import flatten
from dask.delayed import optimize as default_delay_optimize
from dask.optimization import cull, fuse, inline, inline_functions
from dask.utils import ensure_dict
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.scheduler import SchedulerState, cast
from natsort import natsorted
from scipy.ndimage.filters import median_filter
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import lsqr
from tifffile import TiffFile, imread
def load_videos(
    vpath: str,
    pattern=r"msCam[0-9]+\.avi$",
    dtype: Union[str, type] = np.float64,
    downsample: Optional[dict] = None,
    downsample_strategy="subset",
    post_process: Optional[Callable] = None,
) -> xr.DataArray:
    r"""
    Load all matching videos in a folder as one lazily evaluated array.

    Files in `vpath` whose names match the regular expression `pattern` are
    sorted with :func:`natsort.natsorted`, loaded lazily one by one and
    concatenated along the first axis. The result can optionally be cast,
    downsampled and handed to a user-supplied callable.

    Parameters
    ----------
    vpath : str
        Folder containing the videos.
    pattern : regexp, optional
        Regular expression applied to each filename with `re.search`. By
        default `r"msCam[0-9]+\.avi$"`, i.e. names ending in "msCam", one or
        more digits and ".avi".
    dtype : Union[str, type], optional
        Datatype the result is cast to. Skipped when falsy. By default
        `np.float64`.
    downsample : dict, optional
        Mapping from dimension name ("frame", "height" or "width") to an
        integer downsampling factor. By default `None`.
    downsample_strategy : str, optional
        Only used when `downsample` is given. `"subset"` keeps every n-th
        sample, `"mean"` averages over each window (trailing incomplete
        windows are trimmed). By default `"subset"`.
    post_process : Callable, optional
        Called as `post_process(varr, vpath, vlist, varr_list)` where `varr`
        is the array built so far, `vlist` the sorted list of matched paths
        and `varr_list` the per-file arrays before concatenation. Its return
        value replaces `varr`. By default `None`.

    Returns
    -------
    varr : xr.DataArray
        Array named "fluorescence" with dimensions
        ("frame", "height", "width").

    Raises
    ------
    FileNotFoundError
        If no file under `vpath` matches `pattern`.
    ValueError
        If the first matched file does not end in ".avi", ".mkv" or ".tif".
    NotImplementedError
        If `downsample_strategy` is neither "subset" nor "mean".
    """
    vpath = os.path.normpath(vpath)
    matched = [fn for fn in os.listdir(vpath) if re.search(pattern, fn)]
    vlist = natsorted([vpath + os.sep + fn for fn in matched])
    if not vlist:
        raise FileNotFoundError(
            "No data with pattern {}"
            " found in the specified folder {}".format(pattern, vpath)
        )
    print("loading {} videos in folder {}".format(len(vlist), vpath))

    # The loader is chosen from the extension of the first file only.
    loaders = {
        ".avi": load_avi_lazy,
        ".mkv": load_avi_lazy,
        ".tif": load_tif_lazy,
    }
    movie_load_func = loaders.get(os.path.splitext(vlist[0])[1])
    if movie_load_func is None:
        raise ValueError("Extension not supported.")

    varr_list = [movie_load_func(path) for path in vlist]
    stacked = darr.concatenate(varr_list, axis=0)
    dim_names = ["frame", "height", "width"]
    varr = xr.DataArray(
        stacked,
        dims=dim_names,
        coords={d: np.arange(n) for d, n in zip(dim_names, stacked.shape)},
    )
    if dtype:
        varr = varr.astype(dtype)
    if downsample:
        if downsample_strategy == "mean":
            varr = varr.coarsen(**downsample, boundary="trim", coord_func="min").mean()
        elif downsample_strategy == "subset":
            step_slices = {d: slice(None, None, w) for d, w in downsample.items()}
            varr = varr.isel(**step_slices)
        else:
            raise NotImplementedError("unrecognized downsampling strategy")
    varr = varr.rename("fluorescence")
    if post_process:
        varr = post_process(varr, vpath, vlist, varr_list)

    # Optimize the graph while preserving the whole-video loading tasks.
    arr_opt = fct.partial(custom_arr_optimize, keep_patterns=["^load_avi_ffmpeg"])
    with da.config.set(array_optimize=arr_opt):
        return da.optimize(varr)[0]
def load_tif_lazy(fname: str) -> darr.array:
    """
    Lazy load a tif stack of images.

    One delayed task is created per page of the stack. The first page is
    computed eagerly to learn the dtype and shape shared by all pages.

    Parameters
    ----------
    fname : str
        The filename of the tif stack to load.

    Returns
    -------
    arr : darr.array
        Dask array with the pages stacked along a new first axis.
    """
    # The file is opened here only to count pages; close it right away so the
    # handle does not stay open for the lifetime of the process.
    with TiffFile(fname) as tif:
        n_pages = len(tif.pages)
    read_page = da.delayed(load_tif_perframe)
    delayed_pages = [read_page(fname, i) for i in range(n_pages)]
    # Compute one page eagerly to obtain the metadata for `from_delayed`.
    sample = delayed_pages[0].compute()
    arr = [
        darr.from_delayed(pg, dtype=sample.dtype, shape=sample.shape)
        for pg in delayed_pages
    ]
    return darr.stack(arr, axis=0)
def load_tif_perframe(fname: str, fid: int) -> np.ndarray:
    """
    Read one image out of a tif stack.

    Parameters
    ----------
    fname : str
        Path of the tif stack.
    fid : int
        Index of the image inside the stack, forwarded to
        :func:`tifffile.imread` as `key`.

    Returns
    -------
    arr : np.ndarray
        The image as an array.
    """
    frame = imread(fname, key=fid)
    return frame
def load_avi_lazy_framewise(fname: str) -> darr.array:
    """
    Lazy load a video with one delayed task per frame.

    The frame count is read through OpenCV and every frame is loaded by
    :func:`load_avi_perframe`. The first frame is computed eagerly to learn
    the dtype and shape used for all frames.

    Parameters
    ----------
    fname : str
        The filename of the video to load.

    Returns
    -------
    arr : darr.array
        Dask array with the frames stacked along a new first axis.
    """
    cap = cv2.VideoCapture(fname)
    try:
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # The capture is only needed for the frame count; release the handle.
        cap.release()
    read_frame = da.delayed(load_avi_perframe)
    delayed_frames = [read_frame(fname, i) for i in range(n_frames)]
    sample = delayed_frames[0].compute()
    arr = [
        darr.from_delayed(fm, dtype=sample.dtype, shape=sample.shape)
        for fm in delayed_frames
    ]
    return darr.stack(arr, axis=0)
def load_avi_lazy(fname: str) -> darr.array:
    """
    Lazy load a video as a single delayed task.

    The width, height and frame count are taken from the first video stream
    reported by `ffmpeg.probe`; the pixel data itself is only read when the
    returned array is computed, through :func:`load_avi_ffmpeg`.

    Parameters
    ----------
    fname : str
        The filename of the video to load.

    Returns
    -------
    arr : darr.array
        `uint8` dask array of shape (frames, height, width).
    """
    streams = ffmpeg.probe(fname)["streams"]
    vstream = next(st for st in streams if st["codec_type"] == "video")
    width, height, n_frames = (
        int(vstream[field]) for field in ("width", "height", "nb_frames")
    )
    delayed_video = da.delayed(load_avi_ffmpeg)(fname, height, width, n_frames)
    return da.array.from_delayed(
        delayed_video, dtype=np.uint8, shape=(n_frames, height, width)
    )
def load_avi_ffmpeg(fname: str, h: int, w: int, f: int) -> np.ndarray:
    """
    Decode a whole video into memory through `ffmpeg`.

    `ffmpeg` is run via the `python-ffmpeg` wrapper, writing raw grayscale
    frames to its standard output, which is captured and reinterpreted as an
    array.

    Parameters
    ----------
    fname : str
        The filename of the video to load.
    h : int
        The height of the video.
    w : int
        The width of the video.
    f : int
        The number of frames in the video.

    Returns
    -------
    arr : np.ndarray
        `uint8` array of shape (`f`, `h`, `w`).
    """
    gray_stream = ffmpeg.input(fname).video.output(
        "pipe:", format="rawvideo", pix_fmt="gray"
    )
    raw_bytes, _ = gray_stream.run(capture_stdout=True)
    flat = np.frombuffer(raw_bytes, np.uint8)
    return flat.reshape(f, h, w)
def load_avi_perframe(fname: str, fid: int) -> np.ndarray:
    """
    Load a single frame from a video using OpenCV.

    Parameters
    ----------
    fname : str
        The filename of the video.
    fid : int
        The index of the frame to load.

    Returns
    -------
    arr : np.ndarray
        The grayscale frame flipped along its first axis. If the frame cannot
        be read, a message is printed and an all-zero array of shape
        (height, width) is returned instead.
    """
    cap = cv2.VideoCapture(fname)
    try:
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        cap.set(cv2.CAP_PROP_POS_FRAMES, fid)
        ret, fm = cap.read()
    finally:
        # A new capture is opened for every frame; release it so handles do
        # not accumulate across the many per-frame tasks.
        cap.release()
    if ret:
        # NOTE(review): OpenCV usually decodes to BGR while this converts with
        # COLOR_RGB2GRAY - kept as-is to preserve output; confirm intent.
        return np.flip(cv2.cvtColor(fm, cv2.COLOR_RGB2GRAY), axis=0)
    print("frame read failed for frame {}".format(fid))
    return np.zeros((h, w))
def open_minian(
    dpath: str, post_process: Optional[Callable] = None, return_dict=False
) -> Union[dict, xr.Dataset]:
    """
    Load an existing minian dataset.

    If `dpath` is a file it is opened with :func:`xarray.open_dataset` and
    chunked. If `dpath` is a directory, every sub-directory in it is opened
    with :func:`xarray.open_zarr`, the first variable of each store is taken
    and its data is re-read with `dask.array.from_zarr`. Every sub-directory
    must therefore be loadable this way. The arrays are then combined into
    either a `dict` or a `xr.Dataset`.

    Parameters
    ----------
    dpath : str
        The path to the minian dataset that should be loaded.
    post_process : Callable, optional
        Called as `post_process(ds, dpath)` on the resulting dataset; its
        return value is returned instead. Only used if `return_dict` is
        `False`. By default `None`.
    return_dict : bool, optional
        If `True` and `dpath` is a directory, return a `dict` keyed by the
        `.name` of each array. Otherwise the arrays are combined with
        `xr.merge(..., compat="no_conflicts")`, which aligns them over all
        dimensions, so coordinates should be compatible to avoid large
        NaN-padded results. Ignored for the return type when `dpath` is a
        file, in which case a `xr.Dataset` is always produced. By default
        `False`.

    Returns
    -------
    ds : Union[dict, xr.Dataset]
        The loaded dataset.

    Raises
    ------
    FileNotFoundError
        If `dpath` is neither an existing file nor an existing directory.

    See Also
    --------
    xarray.open_zarr : for how each directory will be loaded
    xarray.merge : for how the arrays will be merged
    """
    if isfile(dpath):
        ds = xr.open_dataset(dpath).chunk()
    elif isdir(dpath):
        dslist = []
        for entry in listdir(dpath):
            arr_path = pjoin(dpath, entry)
            if not isdir(arr_path):
                continue
            arr = list(xr.open_zarr(arr_path).values())[0]
            arr.data = darr.from_zarr(
                os.path.join(arr_path, arr.name), inline_array=True
            )
            dslist.append(arr)
        if return_dict:
            ds = {a.name: a for a in dslist}
        else:
            ds = xr.merge(dslist, compat="no_conflicts")
    else:
        # Previously fell through and failed on an unbound `ds`.
        raise FileNotFoundError(
            "minian dataset path does not exist: {}".format(dpath)
        )
    if (not return_dict) and post_process:
        ds = post_process(ds, dpath)
    return ds
def open_minian_mf(
    dpath: str,
    index_dims: List[str],
    result_format="xarray",
    pattern=r"minian$",
    sub_dirs: List[str] = [],
    exclude=True,
    **kwargs,
) -> Union[xr.Dataset, pd.DataFrame]:
    """
    Open several minian datasets found below a root directory.

    The tree under `dpath` is walked bottom-up. In every visited directory
    the file and directory names are matched against `pattern`; when there is
    a match the last one is opened with :func:`open_minian`. Each dataset is
    keyed by the string form of its values along `index_dims` and the
    datasets are finally combined.

    Parameters
    ----------
    dpath : str
        The root folder containing all datasets to be loaded.
    index_dims : List[str]
        Dimensions used to index and merge the datasets. Every loaded dataset
        should have a unique combination of coordinates in these dimensions.
    result_format : str, optional
        `"xarray"` concatenates the datasets recursively along `index_dims`
        (coordinates should be compatible to avoid large NaN-padded results).
        `"pandas"` returns a `pd.DataFrame` indexed by `index_dims` with one
        object column named "minian" holding the datasets. By default
        `"xarray"`.
    pattern : regexp, optional
        Pattern of minian dataset names. By default `r"minian$"`.
    sub_dirs : List[str], optional
        Directories used to restrict the walk. A visited directory counts as
        listed when its absolute path is in `sub_dirs` or one of its parents
        is. By default `[]`.
    exclude : bool, optional
        If `True`, listed directories are skipped; if `False`, **only**
        listed directories are considered. By default `True`.
    **kwargs
        Passed on to :func:`open_minian`.

    Returns
    -------
    ds : Union[xr.Dataset, pd.DataFrame]
        The combined datasets, in the form selected by `result_format`.

    Raises
    ------
    NotImplementedError
        If `result_format` is not "xarray" or "pandas".
    """
    datasets = {}
    for cur_dir, child_dirs, child_files in os.walk(dpath, topdown=False):
        cur_dir = os.path.abspath(cur_dir)
        ancestors = Path(cur_dir).parents
        listed = bool(
            any([Path(sd) in ancestors for sd in sub_dirs]) or cur_dir in sub_dirs
        )
        # exclude=True skips listed directories, exclude=False skips the rest.
        if exclude == listed:
            continue
        matches = [nm for nm in child_files + child_dirs if re.search(pattern, nm)]
        if not matches:
            continue
        print("opening dataset under {}".format(cur_dir))
        if len(matches) > 1:
            warnings.warn("multiple dataset found: {}".format(matches))
        chosen = matches[-1]
        print("opening {}".format(chosen))
        minian = open_minian(dpath=os.path.join(cur_dir, chosen), **kwargs)
        key = tuple([np.array_str(minian[d].values) for d in index_dims])
        datasets[key] = minian
        print(["{}: {}".format(d, v) for d, v in zip(index_dims, key)])

    if result_format == "xarray":
        return xrconcat_recursive(datasets, index_dims)
    if result_format == "pandas":
        minian_df = pd.Series(datasets).rename("minian")
        minian_df.index.set_names(index_dims, inplace=True)
        return minian_df.to_frame()
    raise NotImplementedError("format {} not understood".format(result_format))
def save_minian(
    var: xr.DataArray,
    dpath: str,
    meta_dict: Optional[dict] = None,
    overwrite=False,
    chunks: Optional[dict] = None,
    compute=True,
    mem_limit="500MB",
) -> xr.DataArray:
    """
    Save a `xr.DataArray` with `zarr` storage backend following minian
    conventions.

    The array is written to the folder `var.name + ".zarr"` under `dpath`
    (which is created if needed). Optionally metadata is taken from the
    directory hierarchy and added as coordinates, and the stored array is
    rechunked on disk with :func:`rechunker.rechunk`.

    Parameters
    ----------
    var : xr.DataArray
        The array to be saved. It must be named.
    dpath : str
        The path to the minian dataset directory.
    meta_dict : dict, optional
        Mapping from coordinate name to a directory level. Levels index the
        path components of the *parent* of `dpath`, so `-1` is the immediate
        parent directory of `dpath`, `-2` the one above, and so on. The name
        of the directory at that level becomes the coordinate value. By
        default `None`.
    overwrite : bool, optional
        Whether to overwrite the result on disk. If `True` any existing store
        is removed first. If `False` the store is opened in `"w-"` mode. By
        default `False`.
    chunks : dict, optional
        Desired chunk size per dimension for the on-disk rechunking. Values
        less than or equal to zero are replaced by the full size of the
        dimension. "auto" is not supported. Only used if `compute` is true.
        By default `None`.
    compute : bool, optional
        Whether to compute `var` and save it immediately. By default `True`.
    mem_limit : str, optional
        The memory limit passed to :func:`rechunker.rechunk`. Only used if
        `chunks` is not `None`. By default `"500MB"`.

    Returns
    -------
    var : xr.DataArray
        If `compute` is true, the array re-opened from the on-disk store, so
        it only contains the task of loading the `zarr` data. Otherwise the
        value returned by `Dataset.to_zarr(..., compute=False)` is returned
        unchanged (presumably a delayed write; see the xarray documentation).

    Examples
    --------
    The following will save the variable `var` to directory
    `/spatial_memory/alpha/learning1/minian/important_array.zarr`, with the
    additional coordinates: `{"session": "learning1", "animal": "alpha",
    "experiment": "spatial_memory"}`.

    >>> save_minian(
    ...     var.rename("important_array"),
    ...     "/spatial_memory/alpha/learning1/minian",
    ...     {"session": -1, "animal": -2, "experiment": -3},
    ... )  # doctest: +SKIP
    """
    dpath = os.path.normpath(dpath)
    Path(dpath).mkdir(parents=True, exist_ok=True)
    ds = var.to_dataset()
    if meta_dict is not None:
        # Path components of the parent of `dpath`; indexed by the levels.
        pathlist = os.path.split(os.path.abspath(dpath))[0].split(os.sep)
        ds = ds.assign_coords(
            **{name: pathlist[level] for name, level in meta_dict.items()}
        )
    md = "a" if overwrite else "w-"
    fp = os.path.join(dpath, var.name + ".zarr")
    if overwrite:
        try:
            shutil.rmtree(fp)
        except FileNotFoundError:
            pass
    arr = ds.to_zarr(fp, compute=compute, mode=md)
    if (chunks is not None) and compute:
        chunks = {d: var.sizes[d] if v <= 0 else v for d, v in chunks.items()}
        dst_path = os.path.join(dpath, str(uuid4()))
        temp_path = os.path.join(dpath, str(uuid4()))
        try:
            # Rechunk with dask's stock optimizers rather than any custom
            # ones that may be configured globally.
            with da.config.set(
                array_optimize=darr.optimization.optimize,
                delayed_optimize=default_delay_optimize,
            ):
                zstore = zr.open(fp)
                rechk = rechunker.rechunk(
                    zstore[var.name], chunks, mem_limit, dst_path, temp_store=temp_path
                )
                rechk.execute()
        except BaseException:
            # Do not leave stray uuid-named stores inside the dataset folder
            # when rechunking fails; they are not valid minian arrays.
            shutil.rmtree(temp_path, ignore_errors=True)
            shutil.rmtree(dst_path, ignore_errors=True)
            raise
        try:
            shutil.rmtree(temp_path)
        except FileNotFoundError:
            pass
        # Swap the stored chunks for the rechunked ones in place.
        arr_path = os.path.join(fp, var.name)
        for f in os.listdir(arr_path):
            os.remove(os.path.join(arr_path, f))
        for f in os.listdir(dst_path):
            os.rename(os.path.join(dst_path, f), os.path.join(arr_path, f))
        os.rmdir(dst_path)
    if compute:
        arr = xr.open_zarr(fp)[var.name]
        arr.data = darr.from_zarr(os.path.join(fp, var.name), inline_array=True)
    return arr
def xrconcat_recursive(var: Union[dict, list], dims: List[str]) -> xr.Dataset:
    """
    Recursively concatenate `xr.DataArray` over multiple dimensions.

    Parameters
    ----------
    var : Union[dict, list]
        Either a `dict` or a `list` of `xr.DataArray` to be concatenated. If a
        `dict` then keys should be `tuple`, with length same as the length of
        `dims` and values corresponding to the coordinates that uniquely
        identify each `xr.DataArray`. If a `list` then each `xr.DataArray`
        should contain a scalar coordinate for each dimension in `dims`.
    dims : List[str]
        Dimensions to be concatenated over.

    Returns
    -------
    ds : xr.Dataset
        The concatenated dataset.

    Raises
    ------
    NotImplementedError
        If more than one dimension is given and `var` is neither a `dict`
        nor a `list`.
    """
    if len(dims) > 1:
        if isinstance(var, dict):
            var_dict = var
        elif isinstance(var, list):
            # `.item()` replaces `np.asscalar`, which no longer exists in
            # current NumPy releases.
            var_dict = {tuple(v[d].item() for d in dims): v for v in var}
        else:
            raise NotImplementedError("type {} not supported".format(type(var)))
        try:
            var_dict = {k: v.to_dataset() for k, v in var_dict.items()}
        except AttributeError:
            # Values without `to_dataset` (presumably already datasets) are
            # used as they are.
            pass
        # Fill an object array element-wise so pandas stores each dataset as
        # a single cell instead of trying to unpack it.
        data = np.empty(len(var_dict), dtype=object)
        for iv, ds in enumerate(var_dict.values()):
            data[iv] = ds
        index = pd.MultiIndex.from_tuples(list(var_dict.keys()), names=dims)
        var_ps = pd.Series(data=data, index=index)
        xr_ls = []
        # Concatenate over the remaining dimensions within each group first,
        # then over the leading dimension.
        for _, grp in var_ps.groupby(level=dims[0]):
            grp.index = grp.index.droplevel(dims[0])
            xr_ls.append(xrconcat_recursive(grp.to_dict(), dims[1:]))
        return xr.concat(xr_ls, dim=dims[0])
    if isinstance(var, dict):
        var = list(var.values())
    return xr.concat(var, dim=dims[0])
def update_meta(dpath, pattern=r"^minian\.nc$", meta_dict=None, backend="netcdf"):
    """
    Update directory-derived coordinates of datasets found under `dpath`.

    The tree under `dpath` is walked. With the "netcdf" backend file names,
    and with the "zarr" backend directory names, are matched against
    `pattern`. For every match a new dataset carrying the attributes of the
    stored one plus the coordinates described by `meta_dict` is written back
    to the same path.

    Parameters
    ----------
    dpath : str
        Root directory to search.
    pattern : regexp, optional
        Pattern of dataset names. By default `r"^minian\\.nc$"`.
    meta_dict : dict, optional
        Mapping from coordinate name to an index into the path components of
        the directory containing the dataset (e.g. `-1` is that directory's
        own name). If empty or `None` nothing is updated. By default `None`.
    backend : str, optional
        Either "netcdf" or "zarr". By default "netcdf".

    Raises
    ------
    NotImplementedError
        If `backend` is neither "netcdf" nor "zarr".
    """
    if backend not in ("netcdf", "zarr"):
        raise NotImplementedError("backend {} not supported".format(backend))
    if not meta_dict:
        # No coordinates requested: leave the stored datasets untouched.
        return
    for dirpath, dirnames, fnames in os.walk(dpath):
        candidates = fnames if backend == "netcdf" else dirnames
        pathlist = os.path.normpath(dirpath).split(os.sep)
        for fname in candidates:
            if not re.search(pattern, fname):
                continue
            f_path = os.path.join(dirpath, fname)
            new_ds = xr.Dataset()
            # Only the attributes of the stored dataset are carried over.
            old_ds = open_minian(f_path)
            new_ds.attrs = deepcopy(old_ds.attrs)
            old_ds.close()
            new_ds = new_ds.assign_coords(
                **{cdname: pathlist[cdval] for cdname, cdval in meta_dict.items()}
            )
            if backend == "netcdf":
                new_ds.to_netcdf(f_path, mode="a")
            else:
                # NOTE(review): mode="w" looks like it replaces the whole
                # store rather than appending - confirm this is intended.
                new_ds.to_zarr(f_path, mode="w")
            print("updated: {}".format(f_path))
def get_chk(arr: xr.DataArray) -> dict:
    """
    Return the chunks of a `xr.DataArray` keyed by dimension.

    Parameters
    ----------
    arr : xr.DataArray
        The array whose `dims` and `chunks` are paired up.

    Returns
    -------
    chk : dict
        Dictionary mapping each dimension name to its entry in `arr.chunks`.
    """
    return dict(zip(arr.dims, arr.chunks))
def rechunk_like(x: xr.DataArray, y: xr.DataArray) -> xr.DataArray:
    """
    Rechunk `x` so that its chunks are compatible with those of `y`.

    For every dimension shared by both arrays, `x` is rechunked to the
    largest chunk length `y` uses along that dimension. If a `TypeError` is
    raised along the way (for instance when the chunks of `y` cannot be
    iterated), `x` is computed and returned instead.

    Parameters
    ----------
    x : xr.DataArray
        The array to be rechunked.
    y : xr.DataArray
        The array the chunk information is taken from.

    Returns
    -------
    x_chk : xr.DataArray
        The rechunked (or computed) `x`.
    """
    try:
        ref_chunks = get_chk(y)
        shared_dims = set(x.dims).intersection(set(ref_chunks))
        target = {}
        for dim in shared_dims:
            target[dim] = max(ref_chunks[dim])
        return x.chunk(target)
    except TypeError:
        return x.compute()
def get_optimal_chk(
    arr: xr.DataArray,
    dim_grp=[("frame",), ("height", "width")],
    csize=256,
    dtype: Optional[type] = None,
) -> tuple:
    """
    Compute chunk sizes for the input array using `dask` auto-chunking.

    Unlike asking for "auto" chunks directly, this takes into account which
    dimensions are chunked together (`dim_grp`), and supports a custom
    `dtype` and target chunk size.

    Parameters
    ----------
    arr : xr.DataArray
        The input array to estimate chunk sizes for.
    dim_grp : list, optional
        List of tuples of dimensions that are chunked together during
        computation. For each tuple, only its dimensions are auto-chunked
        while all other dimensions of `arr` are kept in a single chunk. Each
        dimension of `arr` should appear exactly once across the list. If
        empty, every dimension forms its own group. By default
        `[("frame",), ("height", "width")]`.
    csize : int, optional
        The desired size of each chunk in MiB. By default `256`.
    dtype : type, optional
        The datatype used during actual computation, in case it differs from
        the current `arr.dtype`. By default `None`.

    Returns
    -------
    chk_compute : dict
        Chunk size per dimension obtained group by group as described for
        `dim_grp`.
    chk_store : dict
        Chunk size per dimension obtained by auto-chunking all dimensions at
        once.
    """
    if dtype is not None:
        arr = arr.astype(dtype)
    dims = arr.dims
    if not dim_grp:
        dim_grp = [(d,) for d in dims]
    chk_compute = dict()
    with da.config.set({"array.chunk-size": "{}MiB".format(csize)}):
        for dg in dim_grp:
            # Auto-chunk the group, keep every other dimension whole.
            spec = {d: "auto" for d in dg}
            spec.update({d: -1 for d in set(dims) - set(dg)})
            grp_chk = get_chunksize(arr.chunk(spec))
            chk_compute.update({d: grp_chk[d] for d in dg})
        chk_store_da = get_chunksize(arr.chunk({d: "auto" for d in dims}))
    return chk_compute, chk_store_da
def get_chunksize(arr: xr.DataArray) -> dict:
    """
    Return the chunk size of a `xr.DataArray` keyed by dimension.

    Parameters
    ----------
    arr : xr.DataArray
        The input array; `arr.data.chunksize` is paired with `arr.dims`.

    Returns
    -------
    chk : dict
        Dictionary mapping dimension names to chunk sizes.
    """
    return dict(zip(arr.dims, arr.data.chunksize))
def factors(x: int) -> List[int]:
    """
    List every positive divisor of an integer in ascending order.

    Parameters
    ----------
    x : int
        The integer to factor. Values below 1 yield an empty list.

    Returns
    -------
    factors : List[int]
        All `i` in `1..x` that divide `x` evenly.
    """
    divisors = []
    for candidate in range(1, x + 1):
        if x % candidate == 0:
            divisors.append(candidate)
    return divisors
ANNOTATIONS = {
    task_pattern: {"resources": {"MEM": mem_units}}
    for task_pattern, mem_units in (
        ("from-zarr-store", 1),
        ("load_avi_ffmpeg", 1),
        ("est_motion_chunk", 1),
        ("transform_perframe", 0.5),
        ("pnr_perseed", 0.5),
        ("ks_perseed", 0.5),
        ("smooth_corr", 1),
        ("vectorize_noise_fft", 1),
        ("vectorize_noise_welch", 1),
        ("update_spatial_block", 1),
        ("tensordot_restricted", 1),
        ("update_temporal_block", 1),
        ("merge_restricted", 1),
    )
}
"""
Dask annotations that should be applied to each task.

This is a `dict` mapping task names (used as patterns matched against task
keys) to a `dict` of dask annotations for the matching tasks. Every entry
requests an amount of the "MEM" worker resource, which is mainly used to
constrain the number of tasks that can be concurrently in memory for each
worker.

See Also
--------
:doc:`distributed:resources`
"""
# Callables the array optimizer may treat as cheap. This list is the default
# `fast_funcs` of `custom_arr_optimize`, which forwards it as
# `fast_functions` to `dask.array.optimization.optimize`.
# NOTE(review): several entries are underscore-prefixed dask internals -
# presumably tied to a specific dask version; confirm when upgrading.
FAST_FUNCTIONS = [
    darr.core.getter_inline,
    darr.core.getter,
    _operator.getitem,
    zr.core.Array,
    darr.chunk.astype,
    darr.core.concatenate_axes,
    darr.core._vindex_slice,
    darr.core._vindex_merge,
    darr.core._vindex_transpose,
]
"""
List of fast functions that should be inlined during optimization.
See Also
-------
:doc:`dask:optimize`
"""
class TaskAnnotation(SchedulerPlugin):
    """
    Custom `SchedulerPlugin` that implements per-task annotation.

    Annotations are looked up in the module constant :const:`ANNOTATIONS`,
    whose keys are matched as patterns against task keys.
    """

    def __init__(self) -> None:
        super().__init__()
        self.annt_dict = ANNOTATIONS

    def update_graph(self, scheduler, client, tasks, **kwargs):
        """
        Apply the configured annotations to the given tasks.

        For every task key matching a pattern, the "resources" annotation (if
        any) is stored as the resource restrictions of the task state, and
        the "priority" annotation (if any) replaces the first element of the
        task priority with its negated value.
        """
        state = cast(SchedulerState, scheduler)
        for task_key in tasks.keys():
            for pattern, annotation in self.annt_dict.items():
                if not re.search(pattern, task_key):
                    continue
                ts = state._tasks.get(task_key)
                resources = annotation.get("resources", None)
                if resources:
                    ts._resource_restrictions = resources
                priority = annotation.get("priority", None)
                if priority:
                    updated = list(ts._priority)
                    updated[0] = -priority
                    ts._priority = tuple(updated)
def custom_arr_optimize(
    dsk: dict,
    keys: list,
    fast_funcs: list = FAST_FUNCTIONS,
    inline_patterns=[],
    rename_dict: Optional[dict] = None,
    rewrite_dict: Optional[dict] = None,
    keep_patterns=[],
) -> dict:
    """
    Customized array optimization built on `dask.array.optimization.optimize`.

    Parameters
    ----------
    dsk : dict
        Input dask task graph.
    keys : list
        Output task keys.
    fast_funcs : list, optional
        Functions passed to the dask optimizer as `fast_functions`. By
        default :const:`FAST_FUNCTIONS`.
    inline_patterns : list, optional
        Patterns of task keys to be inlined after the dask optimizer has
        run. By default `[]`.
    rename_dict : dict, optional
        Passed to :func:`custom_fused_keys_renamer` to rename keys when tasks
        are fused. By default `None`.
    rewrite_dict : dict, optional
        Mapping from task key patterns to replacements, applied to all task
        keys at the end of optimization. By default `None`.
    keep_patterns : list, optional
        Patterns of task keys that are handed to the dask optimizer as
        `fuse_keys` so that they are preserved. By default `[]`.

    Returns
    -------
    dsk : dict
        Optimized dask graph.

    See Also
    --------
    :doc:`dask:optimize`
    `dask.array.optimization.optimize`
    """
    # Background on inlining many array operations:
    # https://github.com/dask/dask/issues/6668
    key_renamer = (
        fct.partial(custom_fused_keys_renamer, rename_dict=rename_dict)
        if rename_dict
        else custom_fused_keys_renamer
    )
    keep_keys = []
    if keep_patterns:
        all_keys = list(dsk.keys())
        for pat in keep_patterns:
            keep_keys.extend([k for k in all_keys if check_key(k, pat)])
    dsk = darr.optimization.optimize(
        dsk,
        keys,
        fuse_keys=keep_keys,
        fast_functions=fast_funcs,
        rename_fused_keys=key_renamer,
    )
    if inline_patterns:
        dsk = inline_pattern(dsk, inline_patterns, inline_constants=False)
    if rewrite_dict:
        # Iterate over a snapshot because the graph is modified in the loop.
        for old_key, task in dsk.copy().items():
            new_key = rewrite_key(old_key, rewrite_dict)
            if new_key != old_key:
                # The task moves to the new key; the old key becomes an alias.
                dsk[new_key] = task
                dsk[old_key] = new_key
    return dsk
def rewrite_key(key: Union[str, tuple], rwdict: dict) -> Union[str, tuple]:
    """
    Rewrite a task key according to `rwdict`.

    Parameters
    ----------
    key : Union[str, tuple]
        Input task key. For a `tuple` only the first element is rewritten.
    rwdict : dict
        Dictionary mapping regular expression patterns to replacements. Each
        pattern is applied in turn with `re.sub`.

    Returns
    -------
    key : Union[str, tuple]
        The new key, of the same kind as the input: a `str` for a `str` key,
        or a `tuple` with the first element rewritten and the remaining
        elements unchanged.

    Raises
    ------
    ValueError
        If input `key` is neither `str` nor `tuple`.
    """
    is_tuple = isinstance(key, tuple)
    if is_tuple:
        name = key[0]
    elif isinstance(key, str):
        name = key
    else:
        raise ValueError("key must be either str or tuple: {}".format(key))
    for pat, repl in rwdict.items():
        name = re.sub(pat, repl, name)
    if is_tuple:
        return (name,) + tuple(key[1:])
    return name
def custom_fused_keys_renamer(
    keys: list, max_fused_key_length=120, rename_dict: Optional[dict] = None
) -> str:
    """
    Build the key for a fused task from the keys being fused.

    Works like the default renamer used by `dask.optimization.fuse`, but
    names are derived with the custom :func:`split_key`. The name of the last
    key in `keys` is kept in full at the end; the distinct short names of the
    other keys are sorted and put in front of it, joined by "-".

    Parameters
    ----------
    keys : list
        List of task keys that should be fused together.
    max_fused_key_length : int, optional
        Limit on the length of the fused name. Longer names are truncated
        and suffixed with a short hash. If falsy, there is no limit. By
        default `120`.
    rename_dict : dict, optional
        Dictionary used to rename keys during fuse. By default `None`.

    Returns
    -------
    fused_key : str
        The fused name if the last key is a `str`. If the last key is a
        non-empty `tuple` starting with a `str`, a `tuple` with the fused
        name followed by the remaining elements of that key. `None` for any
        other kind of key.

    See Also
    --------
    split_key
    dask.optimization.fuse
    """
    remaining = reversed(keys)
    head = next(remaining)
    head_type = type(head)
    plain = head_type is str
    named_tuple = head_type is tuple and len(head) > 0 and isinstance(head[0], str)
    if not (plain or named_tuple):
        return None

    # Reserve room for the "-xxxx" hash suffix added on truncation.
    limit = max_fused_key_length - 5 if max_fused_key_length else max_fused_key_length

    head_short = split_key(head, rename_dict=rename_dict)
    other_shorts = {split_key(k, rename_dict=rename_dict) for k in remaining}
    other_shorts.discard(head_short)
    parts = sorted(other_shorts)
    parts.append(head if plain else head[0])
    fused = "-".join(parts)
    if limit and len(fused) > limit:
        name_hash = f"{hash(fused):x}"[:4]
        fused = f"{fused[:limit]}-{name_hash}"
    if plain:
        return fused
    return (fused,) + head[1:]
def split_key(key: Union[tuple, str], rename_dict: Optional[dict] = None) -> str:
    """
    Split, rename and filter task keys.

    The name of the key is split on "-", each part is optionally renamed, and
    only the parts found among the keys of :const:`ANNOTATIONS` are kept.

    Parameters
    ----------
    key : Union[tuple, str]
        The input task key. If a `tuple`, its first element is used.
    rename_dict : dict, optional
        Dictionary used to rename the individual "-"-separated parts of the
        key. Parts without an entry are left as they are. By default `None`.

    Returns
    -------
    new_key : str
        The kept parts joined with "-". If no part is found in
        :const:`ANNOTATIONS`, the first (possibly renamed) part is returned.
    """
    name = key[0] if type(key) is tuple else key
    parts = name.split("-")
    if rename_dict:
        parts = [rename_dict.get(part, part) for part in parts]
    known = ANNOTATIONS.keys()
    kept = [part for part in parts if part in known]
    return "-".join(kept) if kept else parts[0]
def check_key(key: Union[str, tuple], pat: str) -> bool:
    """
    Check whether `key` contains pattern.

    Parameters
    ----------
    key : Union[str, tuple]
        Input key. If a `tuple` then the first element will be used to check.
    pat : str
        Pattern to check, searched with :func:`re.search`.

    Returns
    -------
    bool
        Whether `key` contains pattern.
    """
    try:
        found = re.search(pat, key)
    except TypeError:
        # `key` is not string-like (e.g. a tuple key): search its first element.
        found = re.search(pat, key[0])
    return found is not None
def check_pat(key: Union[str, tuple], pat_ls: List[str]) -> bool:
    """
    Check whether `key` contains any pattern in a list.

    Parameters
    ----------
    key : Union[str, tuple]
        Input key. If a `tuple` then the first element will be used to check.
    pat_ls : List[str]
        List of pattern to check.

    Returns
    -------
    bool
        Whether `key` contains any pattern in the list. `False` for an empty
        list.
    """
    # `any` stops at the first matching pattern.
    return any(check_key(key, pat) for pat in pat_ls)
def inline_pattern(dsk: dict, pat_ls: List[str], inline_constants: bool) -> dict:
    """
    Inline tasks whose keys match certain patterns.

    Parameters
    ----------
    dsk : dict
        Input dask graph.
    pat_ls : List[str]
        List of patterns to check against the task keys.
    inline_constants : bool
        Whether to inline constants, passed on to
        :func:`dask.optimization.inline`.

    Returns
    -------
    dsk : dict
        Dask graph with the matching keys inlined and removed. If no key
        matches, the input graph is returned as is.

    See Also
    -------
    dask.optimization.inline
    """
    matched = [k for k in dsk.keys() if check_pat(k, pat_ls)]
    if not matched:
        return dsk
    dsk = inline(dsk, matched, inline_constants=inline_constants)
    # The matched tasks now live inside their dependents; drop the originals.
    for k in matched:
        del dsk[k]
    if inline_constants:
        # NOTE(review): the cull targets are the keys deleted just above --
        # confirm this is the intended set of output keys.
        dsk, _ = cull(dsk, set(flatten(matched)))
    return dsk
def custom_delay_optimize(
    dsk: dict, keys: list, fast_functions=[], inline_patterns=[], **kwargs
) -> dict:
    """
    Custom optimization functions for delayed tasks.

    By default only fusing of tasks will be carried out, with fused keys named
    by :func:`custom_fused_keys_renamer`.

    Parameters
    ----------
    dsk : dict
        Input dask task graph.
    keys : list
        Output task keys. Accepted for signature compatibility; not used in
        the body of this function.
    fast_functions : list, optional
        List of fast functions to be inlined. By default `[]`.
    inline_patterns : list, optional
        List of patterns of task keys to be inlined. By default `[]`.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dsk : dict
        Optimized dask graph.
    """
    # `fuse` returns a (graph, dependencies) pair; only the graph is needed.
    graph = fuse(ensure_dict(dsk), rename_keys=custom_fused_keys_renamer)[0]
    if inline_patterns:
        graph = inline_pattern(graph, inline_patterns, inline_constants=False)
    if not fast_functions:
        return graph
    return inline_functions(graph, [], fast_functions=fast_functions)
def unique_keys(keys: list) -> np.ndarray:
    """
    Returns only unique keys in a list of task keys.

    Dask task keys regarding arrays are usually tuples representing chunked
    operations. This function ignores the different chunks: every tuple key is
    represented by its first element prefixed with "chunked-". String keys are
    kept as they are, and keys of any other type are dropped.

    Parameters
    ----------
    keys : list
        List of dask keys.

    Returns
    -------
    unique : np.ndarray
        Sorted unique keys, as returned by :func:`numpy.unique`.
    """
    labels = [
        "chunked-" + k[0] if isinstance(k, tuple) else k
        for k in keys
        if isinstance(k, (tuple, str))
    ]
    return np.unique(labels)
def get_keys_pat(pat: str, keys: list, return_all=False) -> Union[list, str]:
    """
    Filter a list of task keys by pattern.

    Parameters
    ----------
    pat : str
        Pattern to check.
    keys : list
        List of keys to be filtered.
    return_all : bool, optional
        Whether to return all keys matching `pat`. If `False` then only the
        first match will be returned. By default `False`.

    Returns
    -------
    keys : Union[list, str]
        If `return_all` is `True` then a list of keys will be returned.
        Otherwise only one key will be returned.

    Raises
    ------
    IndexError
        If `return_all` is `False` and no key matches `pat`.
    """
    matches = [k for k in keys if check_key(k, pat)]
    return matches if return_all else matches[0]
def optimize_chunk(arr: xr.DataArray, chk: dict) -> xr.DataArray:
    """
    Rechunk a `xr.DataArray` with constrained "rechunk-merge" tasks.

    The array is rechunked, then its dask graph is optimized with
    :func:`custom_arr_optimize` configured to rewrite "rechunk-merge" in task
    keys to "merge_restricted".

    Parameters
    ----------
    arr : xr.DataArray
        The array to be rechunked.
    chk : dict
        The desired chunk size.

    Returns
    -------
    arr_chk : xr.DataArray
        The rechunked array.
    """
    inlinable = FAST_FUNCTIONS + [darr.core.concatenate3]
    rechunked = arr.chunk(chk)
    optimizer = fct.partial(
        custom_arr_optimize,
        fast_funcs=inlinable,
        rewrite_dict={"rechunk-merge": "merge_restricted"},
    )
    # Swap in the custom array optimizer only for this optimization pass.
    with da.config.set(array_optimize=optimizer):
        (optimized,) = da.optimize(rechunked.data)
        rechunked.data = optimized
    return rechunked
def local_extreme(fm: np.ndarray, k: np.ndarray, etype="max", diff=0) -> np.ndarray:
    """
    Find local extreme of a 2d array.

    A pixel is reported when it equals the dilation (for "max") or the erosion
    (for "min") of `fm` with `k`, and the difference between dilation and
    erosion at that pixel exceeds `diff`.

    Parameters
    ----------
    fm : np.ndarray
        The input 2d array.
    k : np.ndarray
        Structuring element defining the locality of the result, passed as
        `kernel` to :func:`cv2.erode` and :func:`cv2.dilate`.
    etype : str, optional
        Type of local extreme. Either `"min"` or `"max"`. By default `"max"`.
    diff : int, optional
        Threshold of difference between local extreme and its neighbours. By
        default `0`.

    Returns
    -------
    fm_ext : np.ndarray
        The returned 2d `uint8` array whose non-zero elements represent the
        location of local extremes.

    Raises
    ------
    ValueError
        If `etype` is not "min" or "max".
    """
    neighbourhood_max = cv2.dilate(fm, k)
    neighbourhood_min = cv2.erode(fm, k)
    # Mask of pixels whose neighbourhood spans more than `diff`.
    contrast = ((neighbourhood_max - neighbourhood_min) > diff).astype(np.uint8)
    if etype == "max":
        reference = neighbourhood_max
    elif etype == "min":
        reference = neighbourhood_min
    else:
        raise ValueError("Don't understand {}".format(etype))
    is_extreme = (fm == reference).astype(np.uint8)
    return cv2.bitwise_and(is_extreme, contrast).astype(np.uint8)
def med_baseline(a: np.ndarray, wnd: int) -> np.ndarray:
    """
    Subtract baseline from a timeseries as estimated by median-filtering the
    timeseries.

    The input array is left untouched; the result is computed into a new
    array, so read-only inputs are accepted as well.

    Parameters
    ----------
    a : np.ndarray
        Input timeseries.
    wnd : int
        Window size of the median filter. This parameter is passed as `size`
        to :func:`scipy.ndimage.filters.median_filter`.

    Returns
    -------
    a : np.ndarray
        Timeseries with baseline subtracted and negative values clipped to
        zero.
    """
    base = median_filter(a, size=wnd)
    # Subtract out of place: an in-place `a -= base` would overwrite the
    # caller's data with unclipped values and fail on read-only arrays.
    return (a - base).clip(0, None)
@darr.as_gufunc(signature="(m,n),(m)->(n)", output_dtypes=float)
def sps_lstsq(a: csc_matrix, b: np.ndarray, **kwargs):
    """
    Solve one least-squares problem per row of `b` against `a`.

    Wrapped as a dask generalized ufunc with signature "(m,n),(m)->(n)".

    Parameters
    ----------
    a : csc_matrix
        Coefficient matrix handed to :func:`scipy.sparse.linalg.lsqr`.
    b : np.ndarray
        Right-hand sides, one per row. Indexed as `b[i, :]`, so at least two
        dimensions are expected inside the function.
        NOTE(review): presumably the gufunc layer supplies the leading loop
        dimension -- confirm against callers.
    **kwargs
        Forwarded unchanged to :func:`scipy.sparse.linalg.lsqr`.

    Returns
    -------
    out : np.ndarray
        Array of shape `(b.shape[0], a.shape[1])` whose row `i` holds the
        solution returned by `lsqr` for `b[i, :]`.
    """
    n_problems = b.shape[0]
    out = np.zeros((n_problems, a.shape[1]))
    for i in range(n_problems):
        rhs = b[i, :].squeeze()
        # `lsqr` returns a tuple; its first element is the solution vector.
        solution = lsqr(a, rhs, **kwargs)[0]
        out[i, :] = solution
    return out