
The first time I used Python's pattern matching

·7 mins
Python
Pierre-Antoine Comby
PhD Student on Deep Learning for fMRI Imaging

Python's structural pattern matching has been available for a while now (it landed with Python 3.10, released in October 2021), yet as a mostly scientific programmer, I had not had the chance to find a good reason to use it (after all, anything it does can also be done with if/else statements). But today I found a good reason to use it, and I am happy to share it with you.

The problem: Lazy evaluation of an array.

For $Project, I was looking for a way to evaluate the elements of an array lazily, where each element is computed from a base (non-lazy) array. Think, for instance, of a time series of images that cannot all be stored in memory at once. The idea is to have a lazy array that computes an element only when it is needed, and stores it in memory for later use. We also want to access the elements of the lazy array with the usual NumPy-style indexing, such as lazy_array[5] or lazy_array[a_time_frame_index, a_slice_in_base_array].
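The "compute only when needed, keep it for later" part can be sketched on its own. The class below is a minimal, hypothetical illustration (CachedFrames, compute_frame and cache_size are names made up for this sketch, not part of the lazy array built in this post): each frame is computed on its first access and kept in a bounded functools.lru_cache, so memory use stays under control when the full series does not fit in memory.

```python
import functools
from collections.abc import Callable

import numpy as np


class CachedFrames:
    """Hypothetical lazy sequence of frames.

    Frame i is computed by compute_frame(i) the first time it is requested,
    and kept in a bounded LRU cache afterwards.
    """

    def __init__(self, n_frames: int, compute_frame: Callable[[int], np.ndarray], cache_size: int = 8):
        self._n_frames = n_frames
        # lru_cache keeps at most `cache_size` results, evicting the least recently used one
        self._cached_compute = functools.lru_cache(maxsize=cache_size)(compute_frame)

    def __len__(self) -> int:
        return self._n_frames

    def __getitem__(self, idx: int) -> np.ndarray:
        if not 0 <= idx < self._n_frames:
            raise IndexError(f"frame index {idx} out of range")
        return self._cached_compute(idx)  # computed on the first call, cached afterwards

    @property
    def n_computed(self) -> int:
        """Number of frames actually computed so far (i.e. cache misses)."""
        return self._cached_compute.cache_info().misses


base = np.ones((4, 4))
frames = CachedFrames(100, lambda t: base * t)  # 100 frames declared, none computed yet
first = frames[5]  # computed now
again = frames[5]  # served from the cache, no recomputation
print(frames.n_computed)  # 1
```

Since cached frames are returned by reference, modifying a returned frame in place would also modify the cached copy; returning a copy trades some memory for safety.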

Of course, some libraries already do this (the closest to my use case would be lazyarray, or PyKeOps for heavier workloads), but my use case is simple enough that I can write it myself. Basically, we are going to emulate a NumPy array, with a __getitem__ method that evaluates the lazy array (and returns a real NumPy array) by applying a sequence of operations to the base array. We know a few things about our lazy array: its dimensions (the time series has a finite length), and the operations to apply to the base array to get each element of the lazy array. So we can write something like:

import numpy as np
from collections.abc import Callable, Sequence


class LazyArray(Sequence):
    """
    A lazy array, that evaluates an element only when it is needed.

    The lazy array is defined by a base array, and a sequence of operations to apply to the base array to get the elements of the lazy array.
    Operations are callables that take the current array as first argument and the requested frame index as second argument, and return a numpy array.
    """
    def __init__(self, base_array: np.ndarray, n_time_points: int, operations: list[Callable]):
        self._base_array = base_array
        self._operations = operations
        self._n_time_points = n_time_points

    # Make it behave (almost) like a numpy array
    def __len__(self):
        return self._n_time_points

    def __iter__(self):
        return (self[i] for i in range(len(self)))

    @property
    def shape(self):
        # assumes the operations preserve the shape of the base array
        return (self._n_time_points, *self._base_array.shape)

    @property
    def dtype(self):
        return self._base_array.dtype

    def _get_frame(self, frame_idx: int) -> np.ndarray:
        cur = self._base_array
        for op in self._operations:
            cur = op(cur, frame_idx)

        return cur

    def __getitem__(self, addr):
        # TODO
        # We apply the operations to the base array
        # and return the result at the specified address.
        # Basically dispatching the call to _get_frame and do some pre/post processing.
        raise NotImplementedError  # placeholder, so that the class is valid Python
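To make the operation convention concrete, here is a standalone sketch of the loop in _get_frame with two hypothetical operations (add_drift and double are invented for this example). It shows that operations are applied in registration order, so their order matters, and that the base array itself is never modified as long as each operation returns a new array.

```python
import numpy as np


# Hypothetical operations following the (array, frame_idx) -> array convention
# described in the LazyArray docstring.
def add_drift(arr: np.ndarray, frame_idx: int) -> np.ndarray:
    return arr + 0.1 * frame_idx  # frame-dependent offset; returns a new array


def double(arr: np.ndarray, frame_idx: int) -> np.ndarray:
    return 2 * arr  # ignores frame_idx


def get_frame(base: np.ndarray, operations, frame_idx: int) -> np.ndarray:
    """Same loop as LazyArray._get_frame: apply each operation in order."""
    cur = base
    for op in operations:
        cur = op(cur, frame_idx)
    return cur


base = np.zeros((2, 2))
print(get_frame(base, [add_drift, double], 3))  # every entry is 2 * (0 + 0.3), i.e. about 0.6
print(get_frame(base, [double, add_drift], 3))  # every entry is 2 * 0 + 0.3, i.e. about 0.3
```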

Pattern matching to the rescue

If we were not using pattern matching

We would have to use if/else statements, and it would be much less readable:

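# Method of LazyArray; assumes `from collections.abc import Iterable` at module level.
def __getitem__(self, addr):
    if isinstance(addr, int):
        return self._get_frame(addr)
    if not isinstance(addr, tuple):
        raise TypeError(f"Invalid type for frame slicer: {type(addr)}")

    frames_slicer, *slicer = addr
    if isinstance(frames_slicer, int):
        return self._get_frame(frames_slicer)[tuple(slicer)]
    if isinstance(frames_slicer, slice):
        # slice attributes default to None; test for None explicitly,
        # otherwise an explicit stop=0 would become len(self)
        start = 0 if frames_slicer.start is None else frames_slicer.start
        stop = len(self) if frames_slicer.stop is None else frames_slicer.stop
        step = 1 if frames_slicer.step is None else frames_slicer.step
        frames = range(start, stop, step)
    elif isinstance(frames_slicer, Iterable):
        frames = frames_slicer
    else:
        raise TypeError(f"Invalid type for frame slicer: {type(frames_slicer)}")
    # (np.newaxis, *slicer) builds the index tuple; `[np.newaxis, *slicer]` only parses on Python >= 3.11
    return np.concatenate([self._get_frame(i)[(np.newaxis, *slicer)] for i in frames])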

Some criticisms of the code above:

  • It is not very readable: the chain of thought is not obvious.
  • Adding a new type of address (say, indexing by a boolean array) is tiresome (it is left as an exercise to the reader…)
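Part of the boilerplate in the slice branch is resolving the None defaults of a slice by hand, which is easy to get subtly wrong: with stop or n_frames, an explicit stop of 0 silently becomes n_frames, and negative bounds are passed through unchanged. Python's built-in slice.indices(length) performs that resolution the same way list slicing does. The sketch below (both helper names are hypothetical) compares the two approaches against plain list slicing.

```python
def frames_with_or_defaults(s: slice, n_frames: int) -> range:
    # resolves the None defaults with `or`
    return range(s.start or 0, s.stop or n_frames, s.step or 1)


def frames_with_indices(s: slice, n_frames: int) -> range:
    # slice.indices(n) returns (start, stop, step) with None, negative and
    # out-of-range values resolved exactly as list slicing resolves them
    return range(*s.indices(n_frames))


n = 10
frames = list(range(n))
for s in [slice(None, 0), slice(-3, None), slice(None, None, -1), slice(2, 50, 3)]:
    print(s, list(frames_with_or_defaults(s, n)), list(frames_with_indices(s, n)), frames[s])
```

Only the slice.indices version agrees with list slicing on all four slices; the or version returns all ten frames for [:0], negative indices for [-3:], nothing for [::-1] and indices past the end for [2:50:3].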

With pattern matching

Here is the solution I came up with, using pattern matching:

# Method of LazyArray; assumes `from collections.abc import Iterable` at module level.
def __getitem__(self, addr):
    match addr:
        case int(): # a single frame
            return self._get_frame(addr)
        case (int() as frame_idx, *slicer): # a frame and a slice
            return self._get_frame(frame_idx)[tuple(slicer)]
        case (slice(start=start, stop=stop, step=step), *slicer): # a slice of frames
            # by default slice values are None, so we need to set them to something
            # (test for None explicitly: `stop or len(self)` would turn stop=0 into len(self)).
            start = 0 if start is None else start
            stop = len(self) if stop is None else stop
            step = 1 if step is None else step
            frames = range(start, stop, step)
            return np.concatenate([self._get_frame(i)[(np.newaxis, *slicer)] for i in frames])
        case (Iterable() as frames, *slicer): # a list of frames and a slice
            # nothing to do here, we have caught everything we need :)
            # Iterable() with parentheses is a class pattern (an isinstance check).
            return np.concatenate([self._get_frame(i)[(np.newaxis, *slicer)] for i in frames])
        case _:
            raise TypeError(f"No pattern matching for address: {type(addr)}")

And adding the boolean array indexing is as simple as adding a new case:

match addr:
    # [...]
    # this case must come before the Iterable() case: an ndarray is also iterable
    case (np.ndarray() as frames_mask, *slicer) if frames_mask.dtype == bool: # a boolean mask and a slice
        return np.concatenate([self._get_frame(i)[(np.newaxis, *slicer)] for i in np.flatnonzero(frames_mask)])

The guard if frames_mask.dtype == bool shows the power of aliasing in pattern matching: the name bound with as frames_mask can be used directly in the guard of the same case statement (saving a few lines in the process).

Conclusion

This approach has its limits. As the patterns grow more complex (say, we also want to check that the iterable only contains ints), match statements start to get in the way, and the good old if/elif/else statements are more versatile. At some point, trading readability for versatility might be the better choice.

The full code for the lazy array

Here is the full code for my lazy array, with some extra bonuses for managing operations and adding extra parameters to operations. Also, this was a good opportunity to play with type annotations.

"""Lazy Simulation Array Module.

very close (i.e. based on some original copy pasting) to what is done in lazyarray

https://github.com/NeuralEnsemble/lazyarray/blob/master/lazyarray.py

"""
from __future__ import annotations
import operator
from collections.abc import Sequence
from copy import copy
from functools import wraps
from typing import Any, Callable, Mapping, TypeVar
import numpy as np
from numpy.typing import ArrayLike

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


def reverse(func: Callable[[T, U], V]) -> Callable[[U, T], V]:
    """Flip argument of function f(a, b) ->  f(b, a)."""

    @wraps(func)
    def reversed_func(a: T, b: U) -> V:
        return func(b, a)

    reversed_func.__doc__ = "Reversed argument form of %s" % func.__doc__
    reversed_func.__name__ = "reversed %s" % func.__name__
    return reversed_func


def _shallow_copy(arr: LazySimArray) -> LazySimArray:
    """Copy a LazySimArray without copying its (possibly huge) base array.

    Only the list of registered operations is duplicated, so that adding an
    operation to the copy does not affect the original.
    """
    new_map = copy(arr)
    new_map._operations = list(arr._operations)
    return new_map


def lazy_inplace_operation(name: str) -> Callable[[LazySimArray, ArrayLike], LazySimArray]:
    """Create a lazy inplace operation on a LazySimArray."""

    def op(self: LazySimArray, val: ArrayLike) -> LazySimArray:
        # registers the out-of-place operator (e.g. operator.add, not operator.iadd),
        # so the base array itself is never modified
        self.apply(getattr(operator, name), val)
        return self

    return op


def lazy_operation(
    name: str, reversed: bool = False
) -> Callable[[LazySimArray, ArrayLike], LazySimArray]:
    """Create a lazy operation on a LazySimArray."""

    def op(self: LazySimArray, val: ArrayLike) -> LazySimArray:
        new_map = _shallow_copy(self)  # deepcopy would also duplicate the base array
        f = getattr(operator, name)
        if reversed:
            f = reverse(f)
        new_map.apply(f, val)
        return new_map

    return op


def lazy_unary_operation(name: str) -> Callable[[LazySimArray], LazySimArray]:
    """Create a lazy unary operation on a LazySimArray."""

    def op(self: LazySimArray) -> LazySimArray:
        new_map = _shallow_copy(self)
        # go through apply() so the entry has the (op, args, kwargs) layout _get_frame expects
        new_map.apply(getattr(operator, name))
        return new_map

    return op


class LazySimArray(Sequence):
    """A lazy array for the simulation of the data.

    The simulation data is acquired frame wise. The idea is thus to register all
    the operations required to produce each frame.

    This is very close to what is done in the lazyarray [1]_ library, but evaluation will
    always be considered frame wise.

    .. [1] https://github.com/NeuralEnsemble/lazyarray/tree/master

    """

    # Make `ndarray + LazySimArray` defer to our reflected operators (__radd__, ...)
    # instead of letting NumPy treat the lazy array as a sequence and evaluate every frame.
    __array_ufunc__ = None

    def __init__(self, base_array: np.ndarray | LazySimArray | None = None, n_frames: int = 1):
        self._base_array = base_array
        # each entry is (operation, positional args, keyword args)
        self._operations: list[
            tuple[Callable[..., np.ndarray], tuple[Any, ...], dict[str, Any]]
        ] = []
        self._n_frames = n_frames

    @property
    def shape(self) -> tuple[int, ...]:
        """Get shape (assumes the operations preserve the frame shape)."""
        frame_shape = self._base_array.shape
        if isinstance(self._base_array, LazySimArray):
            frame_shape = frame_shape[1:]  # drop the frame axis of the lazy base
        return (len(self), *frame_shape)

    @property
    def dtype(self) -> np.dtype:
        """Get dtype."""
        return self._base_array.dtype

    @property
    def ndim(self) -> int:
        """Get number of dimensions."""
        return len(self.shape)  # shape already includes the frame axis

    def __len__(self) -> int:
        """Get length."""
        return self._n_frames

    def __getitem__(self, addr: int | tuple[slice | int, ...]) -> np.ndarray:
        """Get frame idx by applying all the operations in order.

        If an operation requires the frame index, (i.e. has `frame_idx=None` in signature)
        it will be provided here.

        TODO add support for slicing:
        - transform the slice of frame in range of frame idx
        - extract the rest of the slice from the modified base array and return it.
        This would allow stuff like larray[:, mask] to work.
        """
        match addr:
            case int():
                return self._get_frame(addr)
            case (int() as frame_idx, *slicer):
                return self._get_frame(frame_idx)[tuple(slicer)]
            case (slice(start=start, stop=stop, step=step), *slicer):
                start = 0 if start is None else start
                stop = len(self) if stop is None else stop
                step = 1 if step is None else step
                return np.concatenate(
                    [
                        self._get_frame(i)[(np.newaxis, *slicer)]
                        for i in range(start, stop, step)
                    ]
                )
            case _:
                raise TypeError(f"No pattern matching for address: {type(addr)}")

    def _get_frame(self, frame_idx: int) -> np.ndarray:
        if isinstance(self._base_array, LazySimArray):
            cur = self._base_array[frame_idx]
        else:
            cur = self._base_array
        for op, args, kwargs in self._operations:
            # builtins such as operator.add have no __code__ attribute
            code = getattr(op, "__code__", None)
            if code is not None and "frame_idx" in code.co_varnames[: code.co_argcount]:
                # build a new dict instead of mutating the stored kwargs
                kwargs = {**kwargs, "frame_idx": frame_idx}
            cur = op(cur, *args, **kwargs)

        return cur

    def __iter__(self):
        return (self[i] for i in range(len(self)))

    def apply(
        self,
        op: Callable[..., np.ndarray],
        *args: Any,
        **op_kwargs: Any,
    ) -> None:
        """Register an operation to apply."""
        self._operations.append((op, args, op_kwargs))

    def pop_op(
        self, idx: int
    ) -> tuple[Callable[..., np.ndarray], tuple[Any, ...], dict[str, Any]]:
        """Pop an operation."""
        op = self._operations.pop(idx)
        return op

    # define standard operations
    # (Python 3 has no operator.div: `/` and `/=` map to truediv)
    __iadd__ = lazy_inplace_operation("add")
    __isub__ = lazy_inplace_operation("sub")
    __imul__ = lazy_inplace_operation("mul")
    __itruediv__ = lazy_inplace_operation("truediv")
    __ipow__ = lazy_inplace_operation("pow")

    __add__ = lazy_operation("add")
    __radd__ = __add__
    __sub__ = lazy_operation("sub")
    __rsub__ = lazy_operation("sub", reversed=True)
    __mul__ = lazy_operation("mul")
    __rmul__ = __mul__
    __truediv__ = lazy_operation("truediv")
    __rtruediv__ = lazy_operation("truediv", reversed=True)
    __pow__ = lazy_operation("pow")

    __neg__ = lazy_unary_operation("neg")
    __pos__ = lazy_unary_operation("pos")
    __abs__ = lazy_unary_operation("abs")
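The core trick of the listing above, arithmetic special methods that record an operation instead of computing it, can be isolated in a few lines. The TinyLazy class below is a hypothetical, stripped-down sketch rather than the author's implementation: it stores (function, value) pairs, swaps the arguments with a lambda for the reflected subtraction instead of using the reverse helper, and sets __array_ufunc__ = None so that NumPy hands an expression such as ndarray + lazy over to __radd__.

```python
from __future__ import annotations

import operator
from typing import Any, Callable

import numpy as np


class TinyLazy:
    """Hypothetical, stripped-down lazy array: arithmetic only records operations."""

    __array_ufunc__ = None  # make `ndarray + TinyLazy` defer to TinyLazy.__radd__

    def __init__(self, base: np.ndarray, ops: list[tuple[Callable[..., Any], Any]] | None = None):
        self.base = base
        self.ops = ops or []

    def _with(self, func: Callable[..., Any], val: Any) -> TinyLazy:
        # new object sharing the base array; only the list of operations is copied
        return TinyLazy(self.base, [*self.ops, (func, val)])

    def __add__(self, val: Any) -> TinyLazy:
        return self._with(operator.add, val)

    __radd__ = __add__  # fine because addition is commutative

    def __sub__(self, val: Any) -> TinyLazy:
        return self._with(operator.sub, val)

    def __rsub__(self, val: Any) -> TinyLazy:
        # `val - self`: swap the arguments at evaluation time
        return self._with(lambda cur, v: operator.sub(v, cur), val)

    def evaluate(self) -> np.ndarray:
        cur = self.base
        for func, val in self.ops:
            cur = func(cur, val)
        return cur


x = TinyLazy(np.array([1.0, 2.0, 3.0]))
y = 10 - (x + 1)    # nothing is computed yet, two operations are recorded
print(len(y.ops))    # 2
print(y.evaluate())  # [8. 7. 6.]
```

Sharing the base array between the original and the derived object, and copying only the short list of operations, keeps each new lazy expression cheap even when the base array is large.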