Source code for mindpype.kernels.baseline_correction

from ..core import MPEnums
from ..kernel import Kernel
from ..graph import Node, Parameter

import numpy as np


[docs] class BaselineCorrectionKernel(Kernel): """ Kernel to perform baseline correction. Input data is baseline corrected by subtracting the mean of the baseline period from the input data along a specified axis. .. note:: This kernel utilizes the numpy function :func:`mean <numpy:numpy.mean>`. Parameters ---------- graph : Graph Graph that the kernel should be added to inA : Tensor object Input data container outA : Tensor object Output data container axis : int, default = -1 Axis along which to perform baseline correction baseline_period : array-like (start, end) Baseline period where start and end are the start and end indices of the baseline period within the target axis. """ def __init__(self, graph, inA, outA, axis=-1, baseline_period=(0,-1)): """ Init """ super().__init__('BaselineCorrection', MPEnums.INIT_FROM_NONE, graph) self.inputs = [inA] self.outputs = [outA] self._baseline_period = np.asarray(baseline_period) self._axis = axis def _verify(self): """ Verify the kernel parameters """ d_in = self.inputs[0] d_out = self.outputs[0] # input and output must be a tensor if (d_in.mp_type != MPEnums.TENSOR or d_out.mp_type != MPEnums.TENSOR): raise TypeError("Input and output must be a tensor") # check the baseline period is valid if self._baseline_period.shape[0] != 2: raise ValueError("Baseline period must include start and end points") # if the axis or baseline indices are negative, convert them to positive if self._axis < 0: self._axis = self._axis + len(d_in.shape) for i, idx in enumerate(self._baseline_period): if idx < 0: self._baseline_period[i] = d_in.shape[self._axis] + idx # force the baseline period to be an integer self._baseline_period = np.array(self._baseline_period, dtype=int) # start index must be less than end index and both must be # within the range of the data if ((self._baseline_period[0] > self._baseline_period[1]) or (self._baseline_period[0] < 0) or (self._baseline_period[1] > d_in.shape[-1])): raise ValueError("Baseline period must be within the range of the data") # check the output dimensions are valid if d_out.virtual and len(d_out.shape) == 0: d_out.shape = d_in.shape def _initialize(self, init_inputs, init_outputs, labels=None): """ Initialize the kernel and compute initialization data output Parameters ---------- init_inputs : list of Tensors or Arrays Initialization input data container, list of length 1 init_outputs : list of Tensors or Arrays Initialization output data container, list of length 1 labels : Tensor Class labels for initialization data """ init_in = init_inputs[0] init_out = init_outputs[0] if init_out is not None and (init_in is not None and init_in.shape != ()): if init_in.mp_type != MPEnums.TENSOR: init_in = init_in.to_tensor() # Set the initialization output shape if init_out.virtual: output_shape = list(init_in.shape) init_out.shape = tuple(output_shape) # update the output shape, as needed axis_adjusted = False if (len(self.inputs[0].shape) != len(init_in.shape) and self._axis >= 0): self._axis += 1 # adjust axis assuming stacked data axis_adjusted = True self._process_data([init_in], init_outputs) # adjust the axis back to the original value if axis_adjusted: self._axis -= 1 def _process_data(self, inputs, outputs): """ Perform baseline correction on the input data and store the result in the output data. Compute the baseline along the specified axis between the start and end indices. Then subtract the mean of the baseline from the input data. Parameters ---------- inputs : list of Tensors or Arrays Input data container, list of length 1 outputs : list of Tensors or Arrays Output data container, list of length 1 """ inA = inputs[0] outA = outputs[0] # compute baseline baseline = np.mean(inA.data[..., self._baseline_period[0]:self._baseline_period[1]], axis=self._axis, keepdims=True) # remove baseline and assign to output outA.data = inA.data - baseline
[docs] @classmethod def add_to_graph(cls, graph, inputA, outputA, baseline_period=(0,-1), axis=-1, init_input=None, init_labels=None): """ Factory method to create a baseline correction kernel and add it to a graph Parameters ---------- graph : Graph Graph that the kernel should be added to inputA : Tensor Input data container outputA : Tensor Output data container baseline_period : array-like (start, end) Baseline period where start and end are the start and end indices of the baseline period within the target axis. axis : int, default = -1 Axis along which to perform baseline correction init_input : Tensor or Scalar data container, default=None MindPype data container with initialization data to be transformed and passed to downstream nodes during graph initialization init_labels : Tensor or Array data container, default=None MindPype data container with initialization labels to be passed to downstream nodes during graph initialization Returns ------- node : Node Node object containing the baseline correction kernel and parameters """ # create the kernel object k = cls(graph, inputA, outputA, axis, baseline_period) # create parameter objects for the input and output params = (Parameter(inputA, MPEnums.INPUT), Parameter(outputA, MPEnums.OUTPUT)) # add the kernel to a generic node object node = Node(graph, k, params) # add the node to the graph graph.add_node(node) # if initialization data is provided, then add it to the node if init_input is not None: node.add_initialization_data([init_input], init_labels) return node