Source code for jnormcorre.utils.registrationarrays

import jnormcorre.motion_correction
from jnormcorre.utils.lazy_array import lazy_data_loader
import tifffile
import numpy as np
import h5py
from typing import *


class FilteredArray(lazy_data_loader):
    """Lazy loader that applies a spatial filter to frames of another loader."""

    def __init__(
        self,
        raw_data_loader: lazy_data_loader,
        filter_function: Callable,
        batching: int = 100,
    ):
        """
        Class for loading and filtering data; this is broadly useful because we
        often want to spatially filter data to expose salient signals. We use
        this filtered version of the data to estimate shifts.

        Args:
            raw_data_loader (lazy_data_loader): An object that supports the
                lazy_data_loader interface. This can be for e.g. a custom
                object that reads data from disk, an array in RAM (like a numpy
                ndarray) or anything else.
            filter_function (Callable): A function that applies a spatial
                filter to every frame of a data array. It takes an input movie
                of shape (frames, fov dim 1, fov dim 2) and returns a filtered
                movie of the same shape. The output is cast to a numpy array
                by this class.
            batching (int): Max number of frames passed to ``filter_function``
                at a time (used to avoid OOM errors on the GPU).
        """
        self._raw_data_loader = raw_data_loader
        self._filter = filter_function
        self._batching = batching

    @property
    def raw_data_loader(self) -> lazy_data_loader:
        """The wrapped loader that supplies unfiltered frames."""
        return self._raw_data_loader

    @property
    def filter_function(self) -> Callable:
        """The callable applied to each batch of frames."""
        return self._filter

    @property
    def batching(self):
        """Maximum number of frames filtered in a single call."""
        return self._batching

    @batching.setter
    def batching(self, new_batch: int):
        self._batching = new_batch

    @property
    def dtype(self) -> str:
        """
        data type (as reported by the wrapped loader)
        """
        return self.raw_data_loader.dtype

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Array shape (n_frames, dims_x, dims_y)
        """
        return self.raw_data_loader.shape

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self.shape)

    def _compute_at_indices(self, indices: Union[list, int, slice]) -> np.ndarray:
        """
        Lazy computation logic goes here to return frames. Slices the array
        over time (dimension 0) at the desired indices.

        Parameters
        ----------
        indices: Union[list, int, slice]
            the user's desired way of picking frames, either an int, list of
            ints, or slice, i.e. slice object or int passed from
            `__getitem__()`

        Returns
        -------
        np.ndarray
            filtered frames with a leading time axis; the dtype is whatever
            ``filter_function`` produces, in both the batched and unbatched
            paths

        Raises
        ------
        ValueError
            if more than one batch is needed and ``batching`` is smaller
            than 1
        """
        frames = self.raw_data_loader[indices]
        if frames.ndim == 2:
            # A single frame was loaded; the filter expects a (frames, x, y) movie.
            frames = frames[None, :, :]

        num_frames = frames.shape[0]
        batch_size = self.batching
        if num_frames <= batch_size:
            return np.array(self.filter_function(frames))

        if batch_size < 1:
            raise ValueError("batching must be a positive integer")

        # Allocate the output from the first filtered batch rather than from the
        # raw frames, so filtered values are never cast back to the raw dtype.
        output = None
        for start in range(0, num_frames, batch_size):
            end = min(num_frames, start + batch_size)
            filtered = np.array(self.filter_function(frames[start:end]))
            if output is None:
                output = np.empty(
                    (num_frames,) + filtered.shape[1:], dtype=filtered.dtype
                )
            output[start:end] = filtered
        return output
class TiffArray(lazy_data_loader):
    """Lazy loader for multipage tiff files; frames are returned as float32."""

    def __init__(self, filename):
        """
        TiffArray data loading object. Supports loading data from multipage
        tiff files.

        Args:
            filename (str): Path to file
        """
        self.filename = filename

    @property
    def dtype(self) -> type:
        """
        Data type that loaded frames are cast to (always ``np.float32``)
        """
        return np.float32

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Tuple[int]
            (n_frames, dims_x, dims_y)

        The file is opened on every access. The frame dimensions are taken
        from the first page's metadata; no pixel data is decoded.
        """
        with tifffile.TiffFile(self.filename) as tffl:
            num_frames = len(tffl.pages)
            x, y = tffl.pages[0].shape
        return num_frames, x, y

    @property
    def ndim(self) -> int:
        """
        int
            Number of dimensions
        """
        return len(self.shape)

    def _compute_at_indices(self, indices: Union[list, int, slice]) -> np.ndarray:
        """
        Read the requested pages from the tiff file.

        Parameters
        ----------
        indices: Union[list, int, slice]
            frame selection along the time axis. Slices follow normal Python
            semantics (negative bounds, negative steps, clamping to the number
            of frames). Integer types from numpy are accepted as well.

        Returns
        -------
        np.ndarray
            the selected frames cast to ``self.dtype``. Singleton axes are
            squeezed out, so a one-frame selection comes back as a 2D array.
            An empty selection yields an array of shape (0, dims_x, dims_y).
        """
        if isinstance(indices, slice):
            # slice.indices resolves None/negative/out-of-range bounds; the old
            # "stop or n_frames" shortcut turned arr[0:0] into "all frames".
            page_keys = list(range(*indices.indices(self.shape[0])))
        elif isinstance(indices, (int, np.integer)):
            page_keys = [int(indices)]
        else:
            page_keys = [int(k) for k in indices]

        if not page_keys:
            # Nothing selected: return an empty stack instead of asking
            # tifffile to read zero pages.
            return np.empty((0,) + tuple(self.shape[1:]), dtype=self.dtype)

        data = tifffile.imread(self.filename, key=page_keys).squeeze()
        return data.astype(self.dtype)
class Hdf5Array(lazy_data_loader):
    """Lazy loader for a (T, x, y) dataset stored in an Hdf5 file."""

    def __init__(self, filename: str, field: str) -> None:
        """
        Generic lazy loader for Hdf5 video files, where data is stored as
        (T, x, y). T is number of frames, x and y are the field of view
        dimensions (height and width).

        Args:
            filename (str): Path to filename
            field (str): Field of hdf5 file containing data

        Raises:
            ValueError: If ``field`` is not a string.
        """
        if not isinstance(field, str):
            raise ValueError("Field must be a string")
        self.filename = filename
        self.field = field
        # The shape is read once here and cached; later reads reopen the file.
        with h5py.File(self.filename, "r") as file:
            self._shape = file[self.field].shape

    @property
    def dtype(self) -> type:
        """
        Data type that loaded frames are cast to (always ``np.float32``)
        """
        return np.float32

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Tuple[int]
            (n_frames, dims_x, dims_y)
        """
        return self._shape

    @property
    def ndim(self) -> int:
        """
        int
            Number of dimensions
        """
        return len(self.shape)

    def _compute_at_indices(self, indices: Union[list, int, slice]) -> np.ndarray:
        """
        Read the requested frames from the dataset.

        Parameters
        ----------
        indices: Union[list, int, slice]
            frame selection along the time axis. Slices follow normal Python
            semantics (negative bounds, negative steps, clamping to the number
            of frames). Any other index is handed to h5py unchanged, so h5py's
            own restrictions on index lists apply.

        Returns
        -------
        np.ndarray
            the selected frames cast to ``self.dtype``. Singleton axes are
            squeezed out, so a one-frame selection comes back as a 2D array.
        """
        with h5py.File(self.filename, "r") as file:
            field_dataset = file[self.field]

            if isinstance(indices, slice):
                # slice.indices resolves None/negative/out-of-range bounds; the
                # old "stop or n_frames" shortcut turned arr[0:0] into "all
                # frames".
                start, stop, step = indices.indices(self.shape[0])
                frame_range = range(start, stop, step)
                if len(frame_range) == 0:
                    data = np.empty((0, *self.shape[1:]))
                elif step > 0:
                    data = field_dataset[start:stop:step, :, :]
                else:
                    # Read the same frames in ascending order, then flip, since
                    # h5py is not given a negative step here.
                    ascending = frame_range[::-1]
                    data = field_dataset[
                        ascending.start : ascending.stop : ascending.step, :, :
                    ][::-1]
            elif isinstance(indices, (int, np.integer)):
                data = field_dataset[int(indices), :, :]
            else:
                data = field_dataset[indices, :, :]

            data = data.squeeze()
        return data.astype(self.dtype)
class RegistrationArray(lazy_data_loader):
    """Lazy loader that registers frames of another loader on demand."""

    def __init__(
        self,
        registration_obj: jnormcorre.motion_correction.FrameCorrector,
        data_to_register: jnormcorre.utils.lazy_array.lazy_data_loader,
        pw_rigid=False,
        reference_data: Optional[jnormcorre.utils.lazy_array.lazy_data_loader] = None,
    ):
        """
        Class for registering 2D functional imaging data on the fly. Useful
        for visualization libraries etc.

        Args:
            registration_obj (jnormcorre.motion_correction.FrameCorrector):
                Object which can perform registration
            data_to_register (jnormcorre.utils.lazy_array.lazy_data_loader):
                Data loading object
            pw_rigid (bool): Indicates whether we apply rigid or piecewise
                rigid registration to frames
            reference_data (Optional[jnormcorre.utils.lazy_array.lazy_data_loader]):
                A reference stack. If provided, the algorithm will find optimal
                alignment between template and each frame of this stack. It
                will then apply these shifts to "data_to_register"

        Raises:
            ValueError: If ``reference_data`` is given and its shape differs
                from that of ``data_to_register``, or if the FOV dimensions of
                ``data_to_register`` do not match the first two dimensions of
                the registration template.
        """
        self.reference_data = reference_data
        self.data_loader = data_to_register
        if self.reference_data is not None:
            if self.reference_data.shape != self.data_loader.shape:
                raise ValueError(
                    "The data to register and the reference data stack do not have the same shape."
                )
        self.registration_obj = registration_obj
        self._pw_rigid = pw_rigid

        # Verify that the data and registration info align properly
        data_dims = data_to_register.shape[1:]
        template_dims = registration_obj.template.shape
        if data_dims[0] != template_dims[0] or data_dims[1] != template_dims[1]:
            # Built only on failure, and written as one literal so the message
            # no longer carries a run of source indentation in the middle.
            raise ValueError(
                "Dimension mismatch: FOV dims of dataset {}, "
                "FOV dims of template {}".format(data_dims, template_dims)
            )

    @property
    def dtype(self):
        """Data type reported by the wrapped data loader."""
        return self.data_loader.dtype

    @property
    def shape(self):
        """Shape reported by the wrapped data loader."""
        return self.data_loader.shape

    @property
    def ndim(self):
        """Number of dimensions reported by the wrapped data loader."""
        return self.data_loader.ndim

    @property
    def batching(self):
        """Batch size stored on the registration object."""
        return self.registration_obj.batching

    @batching.setter
    def batching(self, new_batch: int):
        self.registration_obj.batching = new_batch

    @property
    def template(self) -> np.ndarray:
        """
        The template used for registration
        """
        return self.registration_obj.template

    def _compute_at_indices(self, indices: Union[list, int, slice]) -> np.ndarray:
        """
        Load the requested frames and return them registered.

        Parameters
        ----------
        indices: Union[list, int, slice]
            frame selection along the time axis

        Returns
        -------
        np.ndarray
            the output of ``register_frames`` (squeezed) when no reference
            stack is set, otherwise the output of
            ``register_frames_and_transfer``
        """
        # Use data loader to load the frames
        frames = self.data_loader[indices, :, :]
        if len(frames.shape) == 2:
            # This means we loaded 1 frame only
            frames = frames[None, :, :]

        # Register the data
        if self.reference_data is None:
            registered = self.registration_obj.register_frames(
                frames, pw_rigid=self._pw_rigid
            )
            return registered.squeeze()

        reference_frames = self.reference_data[indices, :, :]
        if len(reference_frames.shape) == 2:
            reference_frames = reference_frames[None, :, :]
        # NOTE(review): unlike the branch above, this result is not squeezed and
        # pw_rigid is not forwarded -- confirm against
        # FrameCorrector.register_frames_and_transfer whether that is intended.
        return self.registration_obj.register_frames_and_transfer(
            frames, reference_frames
        )