resampler

resampler: fast differentiable resizing and warping of arbitrary grids.

Hugues Hoppe    2026

The notebook resampler_notebook.ipynb demonstrates the resampler library and contains documentation, usage examples, unit tests, and experiments.

Overview

The resampler library enables fast differentiable resizing and warping of arbitrary grids. It supports:

  • grids of any dimension (e.g., 1D, 2D images, 3D video, 4D batches of videos), containing

  • samples of any shape (e.g., scalars, colors, motion vectors, Jacobian matrices) and

  • any numeric type (e.g., uint8, float64, complex128)

  • within several array libraries (numpy, tensorflow, torch, and jax);

  • either 'dual' ("half-integer") or 'primal' grid-type for each dimension;

  • many boundary rules, specified per dimension, extensible via subclassing;

  • an extensible set of filter kernels, selectable per dimension;

  • optional gamma transfer functions for correct linear-space filtering;

  • prefiltering for accurate antialiasing when downsampling with resize;

  • efficient backpropagation of gradients for tensorflow, torch, and jax;

  • few dependencies (only numpy and scipy) and no C extension code, yet

  • faster resizing than C++ implementations in tf.image and torch.nn.

A key strategy is to leverage existing sparse matrix representations and operations.
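
As a rough illustration of that strategy (a minimal sketch, not the library's actual implementation), a 1D resize can be expressed as a sparse matrix with a few nonzero weights per output sample, which is then multiplied with the source array. The helper name `linear_resize_matrix`, the linear kernel, and the clamped boundary are all assumptions made for this example.

```python
import numpy as np
import scipy.sparse


def linear_resize_matrix(src_size: int, dst_size: int) -> scipy.sparse.csr_matrix:
  """Hypothetical helper: sparse (dst_size, src_size) matrix for linear interpolation."""
  # Dual-grid sample centers: destination sample i lies at (i + 0.5) / dst_size in [0, 1].
  dst_pos = (np.arange(dst_size) + 0.5) / dst_size
  # Continuous source index of each destination sample.
  src_index = dst_pos * src_size - 0.5
  left = np.floor(src_index).astype(int)
  frac = src_index - left
  # Two nonzero weights per row; clamp out-of-range indices to the border samples.
  rows = np.repeat(np.arange(dst_size), 2)
  cols = np.clip(np.stack([left, left + 1], axis=1), 0, src_size - 1).ravel()
  data = np.stack([1.0 - frac, frac], axis=1).ravel()
  # Duplicate (row, col) entries created by clamping are summed during construction.
  return scipy.sparse.csr_matrix((data, (rows, cols)), shape=(dst_size, src_size))


matrix = linear_resize_matrix(4, 32)
array = np.array([[3.0], [5.0], [8.0], [7.0]])  # 4 source samples, 1 channel.
upsampled = matrix @ array  # Sparse-times-dense product; result has shape (32, 1).
```

Because each row holds only two weights, the cost of the product grows with the number of output samples rather than with the product of source and destination sizes.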

Example usage

!pip install -q mediapy numpy resampler

import mediapy as media
import numpy as np
import resampler
array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
upsampled = resampler.resize(array, (128, 192))  # To 128x192 resolution.
media.show_images({'4x6': array, '128x192': upsampled}, height=128)

image = media.read_image('https://github.com/hhoppe/data/raw/main/image.png')
downsampled = resampler.resize(image, (32, 32))
media.show_images({'128x128': image, '32x32': downsampled}, height=128)

import matplotlib.pyplot as plt
array = [3.0, 5.0, 8.0, 7.0]  # 4 source samples in 1D.
new_dual = resampler.resize(array, (32,))  # (default gridtype='dual') 8x resolution.
new_primal = resampler.resize(array, (25,), gridtype='primal')  # 8x resolution.

_, axs = plt.subplots(1, 2, figsize=(7, 1.5))
axs[0].set_title("gridtype='dual'")
axs[0].plot((np.arange(len(array)) + 0.5) / len(array), array, 'o')
axs[0].plot((np.arange(len(new_dual)) + 0.5) / len(new_dual), new_dual, '.')
axs[1].set_title("gridtype='primal'")
axs[1].plot(np.arange(len(array)) / (len(array) - 1), array, 'o')
axs[1].plot(np.arange(len(new_primal)) / (len(new_primal) - 1), new_primal, '.')
plt.show()
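
The two grid types differ only in where the samples sit within the unit domain, as the plotting code above implies. A small sketch of those positions follows; the helper `sample_positions` is a hypothetical name and not part of the resampler API.

```python
import numpy as np


def sample_positions(size: int, gridtype: str) -> np.ndarray:
  """Hypothetical helper: sample locations in the unit domain [0, 1] for `size` samples."""
  index = np.arange(size)
  if gridtype == 'dual':
    return (index + 0.5) / size  # Samples at the centers of `size` equal cells.
  if gridtype == 'primal':
    return index / (size - 1)  # Samples at cell endpoints, including 0.0 and 1.0.
  raise ValueError(f'Unknown gridtype {gridtype!r}.')


print(sample_positions(4, 'dual'))  # [0.125 0.375 0.625 0.875]
print(sample_positions(4, 'primal'))  # Includes the endpoints 0.0 and 1.0.
```

This also explains the sizes used above: the sample spacing is 1/size for 'dual' and 1/(size - 1) for 'primal', so going from 4 to 32 dual samples or from 4 to 25 primal samples both reduce the spacing by a factor of 8.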

batch_size = 4
batch_of_images = media.moving_circle((16, 16), batch_size)
upsampled = resampler.resize(batch_of_images, (batch_size, 64, 64))
media.show_videos({'original': batch_of_images, 'upsampled': upsampled}, fps=1)

Most examples above use the default resize() settings:

  • gridtype='dual' for both source and destination arrays,
  • boundary='auto' which uses 'reflect' for upsampling and 'clamp' for downsampling,
  • filter='lanczos3' (a Lanczos kernel with radius 3),
  • gamma=None which by default uses the 'power2' transfer function for the uint8 image in the second example,
  • scale=1.0, translate=0.0 (no domain transformation),
  • default precision and output dtype.
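
To see why a transfer function matters, here is a minimal numpy sketch of filtering in linear space. It assumes that 'power2' decodes by squaring the normalized value and encodes by taking the square root; the helper names are hypothetical, and the 2-sample box average merely stands in for the real filter.

```python
import numpy as np


def decode_power2(encoded: np.ndarray) -> np.ndarray:
  """Hypothetical helper: map uint8 values to linear-space floats in [0, 1] by squaring."""
  return np.square(encoded.astype(np.float32) / 255.0)


def encode_power2(linear: np.ndarray) -> np.ndarray:
  """Hypothetical helper: map linear-space floats in [0, 1] back to uint8 via a square root."""
  return (np.sqrt(linear) * 255.0 + 0.5).clip(0, 255).astype(np.uint8)


row = np.array([0, 255, 0, 255], np.uint8)  # Alternating black and white pixels.
# Averaging the encoded values directly gives 127 for each pair.
naive = row.reshape(2, 2).mean(axis=1).astype(np.uint8)
# Averaging in linear space and re-encoding gives 180 for each pair.
linear = encode_power2(decode_power2(row).reshape(2, 2).mean(axis=1))
```

The linear-space result is noticeably brighter than the naive average, which is the kind of difference the gamma setting is meant to account for.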

Advanced usage:

Map an image to a wider grid using custom scale and translate vectors, with horizontal 'reflect' and vertical 'natural' boundary rules, providing a constant value for the exterior, using different filters (Lanczos and O-MOMS) in the two dimensions, disabling gamma correction, performing computations in double-precision, and returning an output array in single-precision:

new = resampler.resize(
    image, (128, 512), boundary=('natural', 'reflect'), cval=(0.2, 0.7, 0.3),
    filter=('lanczos3', 'omoms5'), gamma='identity', scale=(0.8, 0.25),
    translate=(0.1, 0.35), precision='float64', dtype='float32')
media.show_images({'image': image, 'new': new})

Warp an image by transforming it using polar coordinates:

shape = image.shape[:2]
yx = ((np.indices(shape).T + 0.5) / shape - 0.5).T  # [-0.5, 0.5]^2
radius, angle = np.linalg.norm(yx, axis=0), np.arctan2(*yx)
angle += (0.8 - radius).clip(0, 1) * 2.0 - 0.6
coords = np.dstack((np.sin(angle) * radius, np.cos(angle) * radius)) + 0.5
resampled = resampler.resample(image, coords, boundary='constant')
media.show_images({'image': image, 'resampled': resampled})
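
One way to sanity-check the coordinate construction in the warp above is to skip the angle perturbation: the polar round trip should then reproduce the pixel-center coordinates in the unit square, i.e. an identity map. This numpy-only sketch assumes, as the example suggests, that the last axis of `coords` holds (y, x) in [0, 1]; the small `shape` is chosen only for illustration.

```python
import numpy as np

shape = (4, 6)
# Pixel centers relative to the image center; array shape (2, 4, 6), values in (-0.5, 0.5).
yx = ((np.indices(shape).T + 0.5) / shape - 0.5).T
radius, angle = np.linalg.norm(yx, axis=0), np.arctan2(*yx)
# With `angle` left unchanged, converting back from polar form recovers (y, x).
coords = np.dstack((np.sin(angle) * radius, np.cos(angle) * radius)) + 0.5
# Reference: the same pixel centers, moved to shape (4, 6, 2) and shifted into [0, 1].
expected = np.moveaxis(yx, 0, -1) + 0.5
assert np.allclose(coords, expected)
```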

Limitations:

  • Filters are assumed to be separable.
  • Although resize implements prefiltering, resample does not yet have it (and therefore may have aliased results if downsampling).
  • Differentiability is only with respect to the grid values, not with respect to the resize shape, scale, translation, or the resampling coordinates.
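
To make the separability assumption concrete, the following numpy sketch (with arbitrary random weight matrices standing in for per-dimension resize operators, not the library's actual weights) shows that applying a 1D operator along each dimension in turn matches a single large 2D operator built as a Kronecker product.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((4, 6))
# Arbitrary 1D weight matrices standing in for per-dimension resize operators.
matrix_y = rng.random((8, 4))  # Maps 4 rows to 8 rows.
matrix_x = rng.random((3, 6))  # Maps 6 columns to 3 columns.

# Separable evaluation: one dimension at a time.
separable = matrix_y @ image @ matrix_x.T  # Shape (8, 3).

# Equivalent evaluation with one (24, 24) operator acting on the flattened image.
full = (np.kron(matrix_y, matrix_x) @ image.ravel()).reshape(8, 3)
assert np.allclose(separable, full)
```

The separable form only ever needs the small per-dimension matrices, which is why a non-separable kernel would not fit this scheme.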

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   2
   3.. include:: ../README.md
   4"""
   5
   6# Note that pdoc uses the module docstring in both __init__.py (for section headings) and
   7# __init__.pyi (for the actual content)!
   8
   9__docformat__ = 'google'
  10__version__ = '1.0.3'
  11__version_info__ = tuple(int(num) for num in __version__.split('.'))
  12
  13from collections.abc import Callable, Iterable, Sequence
  14import abc
  15import dataclasses
  16import functools
  17import importlib
  18import itertools
  19import math
  20import os
  21import sys
  22import types
  23import typing
  24from typing import Any, Generic, Literal, TypeAlias, TypeVar, Union
  25
  26import numpy as np
  27import numpy.typing
  28import scipy.interpolate
  29import scipy.linalg
  30import scipy.ndimage
  31import scipy.sparse
  32import scipy.sparse.linalg
  33
  34
  35def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  36  """Return function decorated with no-operation; invocable with or without args."""
  37  if len(args) != 1 or not callable(args[0]) or kwargs:
  38    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  39  func: Callable[..., Any] = args[0]
  40  return func
  41
  42
  43try:
  44  import numba
  45except ModuleNotFoundError:
  46  numba = sys.modules['numba'] = types.ModuleType('numba')
  47  numba.njit = _noop_decorator  # type: ignore[attr-defined]
  48_USING_NUMBA = hasattr(numba, 'jit')
  49
  50if typing.TYPE_CHECKING:
  51  import jax.numpy
  52  import tensorflow as tf
  53  import torch
  54
  55  _DType: TypeAlias = np.dtype[Any]
  56  _NDArray: TypeAlias = numpy.typing.NDArray[Any]
  57  _DTypeLike: TypeAlias = numpy.typing.DTypeLike
  58  _ArrayLike: TypeAlias = numpy.typing.ArrayLike
  59  _TensorflowTensor: TypeAlias = tf.Tensor
  60  _TorchTensor: TypeAlias = torch.Tensor
  61  _JaxArray: TypeAlias = jax.numpy.ndarray
  62
  63else:
  64  # Typically, create named types for use in the `pdoc` documentation.
  65  # But here, these are superseded by the declarations in __init__.pyi!
  66  _DType: TypeAlias = Any
  67  _NDArray: TypeAlias = Any
  68  _DTypeLike: TypeAlias = Any
  69  _ArrayLike: TypeAlias = Any
  70  _TensorflowTensor: TypeAlias = Any
  71  _TorchTensor: TypeAlias = Any
  72  _JaxArray: TypeAlias = Any
  73
  74_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  75_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  76
  77
  78def _check_eq(a: Any, b: Any, /) -> None:
  79  """If the two values or arrays are not equal, raise an exception with a useful message."""
  80  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  81  if not are_equal:
  82    raise AssertionError(f'{a!r} == {b!r}')
  83
  84
  85def _real_precision(dtype: _DTypeLike, /) -> _DType:
  86  """Return the type of the real part of a complex number."""
  87  return np.array([], dtype).real.dtype
  88
  89
  90def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  91  """Return a complex type to represent a non-complex type."""
  92  return np.result_type(dtype, np.complex64)
  93
  94
  95def _get_precision(
  96    precision: _DTypeLike | None, dtypes: list[_DType], weight_dtypes: list[_DType], /
  97) -> _DType:
  98  """Return dtype based on desired precision or on data and weight types."""
  99  precision2 = np.dtype(
 100      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
 101  )
 102  if not np.issubdtype(precision2, np.inexact):
 103    raise ValueError(f'Precision {precision2} is not floating or complex.')
 104  check_complex = [precision2, *dtypes]
 105  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 106  if len(set(is_complex)) != 1:
 107    s_types = ','.join(str(dtype) for dtype in check_complex)
 108    raise ValueError(f'Types {s_types} must be all real or all complex.')
 109  return precision2
 110
 111
 112def _sinc(x: _ArrayLike, /) -> _NDArray:
 113  """Return the value `np.sinc(x)` but improved to:
 114  (1) ignore underflow that occurs at 0.0 for np.float32, and
 115  (2) output exact zero for integer input values.
 116
 117  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 118  array([0., 0., 0., 1.], dtype=float32)
 119
 120  >>> _sinc(np.array([-3, -2, -1, 0]))
 121  array([0., 0., 0., 1.])
 122
 123  >>> _sinc(0)
 124  1.0
 125  """
 126  x = np.asarray(x)
 127  x_is_scalar = x.ndim == 0
 128  with np.errstate(under='ignore'):
 129    result = np.sinc(np.atleast_1d(x))
 130    result[x == np.floor(x)] = 0.0
 131    result[x == 0] = 1.0
 132    return result.item() if x_is_scalar else result
 133
 134
 135def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 136  """Return True if the sparse matrix is symmetric."""
 137  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 138  return norm <= tol
 139
 140
 141def _cache_sampled_1d_function(
 142    xmin: float,
 143    xmax: float,
 144    *,
 145    num_samples: int = 3_600,
 146    enable: bool = True,
 147) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 148  """Function decorator to linearly interpolate cached function values."""
 149  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 150
 151  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 152    if not enable:
 153      return func
 154
 155    dx = (xmax - xmin) / num_samples
 156    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 157    samples_func = func(x)
 158    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 159
 160    @functools.wraps(func)
 161    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 162      x = np.asarray(x)
 163      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 164      index = index_float.astype(np.int64)
 165      frac = np.subtract(index_float, index, dtype=np.float32)
 166      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 167
 168    return interpolate_using_cached_samples
 169
 170  return wrap_it
 171
 172
 173class _DownsampleIn2dUsingBoxFilter:
 174  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 175
 176  def __init__(self) -> None:
 177    # Downsampling function for params (dtype, block_height, block_width, ch).
 178    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 179
 180  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 181    assert _USING_NUMBA
 182    assert array.ndim in (2, 3), array.ndim
 183    _check_eq(len(shape), 2)
 184    dtype = array.dtype
 185    a = array[..., None] if array.ndim == 2 else array
 186    height, width, ch = a.shape
 187    new_height, new_width = shape
 188    if height % new_height != 0 or width % new_width != 0:
 189      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 190    block_height, block_width = height // new_height, width // new_width
 191
 192    def func(array: _NDArray) -> _NDArray:
 193      new_height = array.shape[0] // block_height
 194      new_width = array.shape[1] // block_width
 195      result = np.empty((new_height, new_width, ch), dtype)
 196      totals = np.empty(ch, dtype)
 197      factor = dtype.type(1.0 / (block_height * block_width))
 198      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 199        for x in range(new_width):
 200          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 201          if ch == 1:  # All the branches involve compile-time constants.
 202            total = dtype.type(0.0)
 203            for yy in range(block_height):
 204              for xx in range(block_width):
 205                total += array[y * block_height + yy, x * block_width + xx, 0]
 206            result[y, x, 0] = total * factor
 207          elif ch == 3:
 208            total0 = total1 = total2 = dtype.type(0.0)
 209            for yy in range(block_height):
 210              for xx in range(block_width):
 211                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 212                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 213                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 214            result[y, x, 0] = total0 * factor
 215            result[y, x, 1] = total1 * factor
 216            result[y, x, 2] = total2 * factor
 217          elif block_height * block_width >= 9:
 218            for c in range(ch):
 219              totals[c] = 0.0
 220            for yy in range(block_height):
 221              for xx in range(block_width):
 222                for c in range(ch):
 223                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 224            for c in range(ch):
 225              result[y, x, c] = totals[c] * factor
 226          else:
 227            for c in range(ch):
 228              total = dtype.type(0.0)
 229              for yy in range(block_height):
 230                for xx in range(block_width):
 231                  total += array[y * block_height + yy, x * block_width + xx, c]
 232              result[y, x, c] = total * factor
 233      return result
 234
 235    signature = dtype, block_height, block_width, ch
 236    jitted_function = self._jitted_function.get(signature)
 237    if not jitted_function:
 238      if 0:
 239        print(f'Creating numba jit-wrapper for {signature}.')
 240      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 241      self._jitted_function[signature] = jitted_function
 242
 243    try:
 244      result = jitted_function(a)
 245    except RuntimeError:
 246      message = (
 247          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 248          ' try deleting that directory.'
 249      )
 250      print(message, file=sys.stdout, flush=True)
 251      print(message, file=sys.stderr, flush=True)
 252      raise
 253
 254    return result[..., 0] if array.ndim == 2 else result
 255
 256
 257_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 258
 259
 260@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[untyped-decorator]
 261def _numba_serial_csr_dense_mult(
 262    indptr: _NDArray,
 263    indices: _NDArray,
 264    data: _NDArray,
 265    src: _NDArray,
 266    dst: _NDArray,
 267) -> None:
 268  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 269
 270  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 271  """
 272  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 273  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 274  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 275  for i in range(dst.shape[0]):
 276    acc[:] = 0
 277    for jj in range(indptr[i], indptr[i + 1]):
 278      j = indices[jj]
 279      acc += data[jj] * src[j]
 280    dst[i] = acc
 281
 282
 283# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
 284#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
 285#                setitem=False, stencil=False, fusion=False)
 286@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[untyped-decorator]
 287def _numba_parallel_csr_dense_mult(
 288    indptr: _NDArray,
 289    indices: _NDArray,
 290    data: _NDArray,
 291    src: _NDArray,
 292    dst: _NDArray,
 293) -> None:
 294  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 295
 296  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 297  The introduction of parallel omp threads provides another 2-4x speedup.
 298  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.
 299  """
 300  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 301  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 302  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 303  # Default is static scheduling, which is fine.
 304  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
 305    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
 306    for jj in range(indptr[i], indptr[i + 1]):
 307      j = indices[jj]
 308      acc += data[jj] * src[j]
 309    dst[i] = acc
 310
 311
 312@dataclasses.dataclass
 313class _Arraylib(abc.ABC, Generic[_Array]):
 314  """Abstract base class for abstraction of array libraries."""
 315
 316  arraylib: str
 317  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""
 318
 319  array: _Array
 320
 321  @staticmethod
 322  @abc.abstractmethod
 323  def recognize(array: Any) -> bool:
 324    """Return True if `array` is recognized by this _Arraylib."""
 325
 326  @abc.abstractmethod
 327  def numpy(self) -> _NDArray:
 328    """Return a `numpy` version of `self.array`."""
 329
 330  @abc.abstractmethod
 331  def dtype(self) -> _DType:
 332    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""
 333
 334  @abc.abstractmethod
 335  def astype(self, dtype: _DTypeLike) -> _Array:
 336    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""
 337
 338  def reshape(self, shape: tuple[int, ...]) -> _Array:
 339    """Return the equivalent of `self.array.reshape(shape)`."""
 340    return self.array.reshape(shape)
 341
 342  def possibly_make_contiguous(self) -> _Array:
 343    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous."""
 344    return self.array
 345
 346  @abc.abstractmethod
 347  def clip(self, low: Any, high: Any, dtype: _DTypeLike | None = None) -> _Array:
 348    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""
 349
 350  @abc.abstractmethod
 351  def square(self) -> _Array:
 352    """Return the equivalent of `np.square(self.array)`."""
 353
 354  @abc.abstractmethod
 355  def sqrt(self) -> _Array:
 356    """Return the equivalent of `np.sqrt(self.array)`."""
 357
 358  def getitem(self, indices: Any) -> _Array:
 359    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
 360    return self.array[indices]
 361
 362  @abc.abstractmethod
 363  def where(self, if_true: Any, if_false: Any) -> _Array:
 364    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""
 365
 366  @abc.abstractmethod
 367  def transpose(self, axes: Sequence[int]) -> _Array:
 368    """Return the equivalent of `np.transpose(self.array, axes)`."""
 369
 370  @abc.abstractmethod
 371  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 372    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""
 373
 374  @abc.abstractmethod
 375  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
 376    """Return the multiplication of the `sparse` matrix and `self.array`."""
 377
 378  @staticmethod
 379  @abc.abstractmethod
 380  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
 381    """Return the equivalent of `np.concatenate(arrays, axis)`."""
 382
 383  @staticmethod
 384  @abc.abstractmethod
 385  def einsum(subscripts: str, *operands: _Array) -> _Array:
 386    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 387
 388  @staticmethod
 389  @abc.abstractmethod
 390  def make_sparse_matrix(
 391      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 392  ) -> _Array:
 393    """Return the equivalent of `scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)`.
 394    However, the indices must be ordered and unique."""
 395
 396
 397class _NumpyArraylib(_Arraylib[_NDArray]):
 398  """Numpy implementation of the array abstraction."""
 399
 400  # pylint: disable=missing-function-docstring
 401
 402  def __init__(self, array: _NDArray) -> None:
 403    super().__init__(arraylib='numpy', array=np.asarray(array))
 404
 405  @staticmethod
 406  def recognize(array: Any) -> bool:
 407    return isinstance(array, (np.ndarray, np.number))
 408
 409  def numpy(self) -> _NDArray:
 410    return self.array
 411
 412  def dtype(self) -> _DType:
 413    dtype: _DType = self.array.dtype
 414    return dtype
 415
 416  def astype(self, dtype: _DTypeLike) -> _NDArray:
 417    return self.array.astype(dtype, copy=False)
 418
 419  def clip(self, low: Any, high: Any, dtype: _DTypeLike | None = None) -> _NDArray:
 420    return self.array.clip(low, high, dtype=dtype)
 421
 422  def square(self) -> _NDArray:
 423    return np.square(self.array)
 424
 425  def sqrt(self) -> _NDArray:
 426    return np.sqrt(self.array)
 427
 428  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 429    condition = self.array
 430    return np.where(condition, if_true, if_false)
 431
 432  def transpose(self, axes: Sequence[int]) -> _NDArray:
 433    return np.transpose(self.array, tuple(axes))
 434
 435  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 436    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 437    # and (2) timings show preference to resizing dimensions with larger strides first.
 438    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 439    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 440    # np.einsum() does not support the sparse multiplications that we require here.)
 441    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 442    strides: Sequence[int] = self.array.strides
 443    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 444
 445    def priority(dim: int) -> float:
 446      scaling = dst_shape[dim] / src_shape[dim]
 447      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 448
 449    return sorted(range(len(src_shape)), key=priority)
 450
 451  def premult_with_sparse(
 452      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 453  ) -> _NDArray:
 454    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 455    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 456    if _USING_NUMBA:
 457      num_threads2 = min(6, os.cpu_count() or 1) if num_threads == 'auto' else num_threads
 458      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 459      dtype = np.result_type(sparse.dtype, src.dtype)
 460      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 461      num_scalar_multiplies = len(sparse.data) * src.shape[1]
 462      is_small_size = num_scalar_multiplies < 200_000
 463      if is_small_size or num_threads2 == 1:
 464        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 465      else:
 466        numba.set_num_threads(num_threads2)
 467        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 468      return dst
 469
 470    # Note that scipy.sparse does not use multithreading.  The "@" operation
 471    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 472    # scipy.sparse._sparsetools.csr_matvecs() in
 473    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/csr.h
 474    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 475    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/dense.h
 476    return sparse @ self.array
 477
 478  @staticmethod
 479  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 480    return np.concatenate(arrays, axis)
 481
 482  @staticmethod
 483  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 484    return np.einsum(subscripts, *operands, optimize=True)
 485
 486  @staticmethod
 487  def make_sparse_matrix(
 488      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 489  ) -> _NDArray:
 490    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 491
 492
 493class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
 494  """Tensorflow implementation of the array abstraction."""
 495
 496  def __init__(self, array: _NDArray) -> None:
 497    import tensorflow
 498
 499    self.tf = tensorflow
 500    super().__init__(arraylib='tensorflow', array=self.tf.convert_to_tensor(array))
 501
 502  @staticmethod
 503  def recognize(array: Any) -> bool:
 504    # Eager: tensorflow.python.framework.ops.Tensor
 505    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
 506    return type(array).__module__.startswith('tensorflow.')
 507
 508  def numpy(self) -> _NDArray:
 509    return self.array.numpy()
 510
 511  def dtype(self) -> _DType:
 512    return np.dtype(self.array.dtype.as_numpy_dtype)
 513
 514  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
 515    return self.tf.cast(self.array, dtype)
 516
 517  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
 518    return self.tf.reshape(self.array, shape)
 519
 520  def clip(self, low: Any, high: Any, dtype: _DTypeLike | None = None) -> _TensorflowTensor:
 521    array = self.array
 522    if dtype is not None:
 523      array = self.tf.cast(array, dtype)
 524    return self.tf.clip_by_value(array, low, high)
 525
 526  def square(self) -> _TensorflowTensor:
 527    return self.tf.square(self.array)
 528
 529  def sqrt(self) -> _TensorflowTensor:
 530    return self.tf.sqrt(self.array)
 531
 532  def getitem(self, indices: Any) -> _TensorflowTensor:
 533    if isinstance(indices, tuple):
 534      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
 535      if not basic:
 536        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
 537        assert all(isinstance(a, np.ndarray) for a in indices)
 538        assert all(a.ndim == indices[0].ndim for a in indices)
 539        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
 540        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
 541        return self.tf.gather_nd(self.array, indices_array)
 542    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
 543      indices = self.tf.cast(indices, np.int32)
 544    return self.tf.gather(self.array, indices)
 545
 546  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
 547    condition = self.array
 548    return self.tf.where(condition, if_true, if_false)
 549
 550  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
 551    return self.tf.transpose(self.array, tuple(axes))
 552
 553  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 554    # Note that a tensorflow.Tensor does not have strides.
 555    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
 556    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 557    dims = list(range(len(src_shape)))
 558    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 559      dims[:2] = [1, 0]
 560    return dims
 561
 562  def premult_with_sparse(
 563      self, sparse: 'tf.sparse.SparseTensor', num_threads: int | Literal['auto']
 564  ) -> _TensorflowTensor:
 565    import tensorflow as tf
 566
 567    del num_threads
 568    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 569      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
 570    return tf.sparse.sparse_dense_matmul(sparse, self.array)
 571
 572  @staticmethod
 573  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
 574    import tensorflow as tf
 575
 576    return tf.concat(arrays, axis)
 577
 578  @staticmethod
 579  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
 580    import tensorflow as tf
 581
 582    return tf.einsum(subscripts, *operands, optimize='greedy')
 583
 584  @staticmethod
 585  def make_sparse_matrix(
 586      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 587  ) -> _TensorflowTensor:
 588    import tensorflow as tf
 589
 590    indices = np.vstack((row_ind, col_ind)).T
 591    return tf.sparse.SparseTensor(indices, data, shape)
 592
 593
 594class _TorchArraylib(_Arraylib[_TorchTensor]):
 595  """Torch implementation of the array abstraction."""
 596
 597  # pylint: disable=missing-function-docstring
 598
 599  def __init__(self, array: _NDArray) -> None:
 600    import torch
 601
 602    self.torch = torch
 603    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))
 604
 605  @staticmethod
 606  def recognize(array: Any) -> bool:
 607    return type(array).__module__ == 'torch'
 608
 609  def numpy(self) -> _NDArray:
 610    return self.array.numpy()
 611
 612  def dtype(self) -> _DType:
 613    numpy_type = {
 614        self.torch.float32: np.float32,
 615        self.torch.float64: np.float64,
 616        self.torch.complex64: np.complex64,
 617        self.torch.complex128: np.complex128,
 618        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
 619        self.torch.int16: np.int16,
 620        self.torch.int32: np.int32,
 621        self.torch.int64: np.int64,
 622    }[self.array.dtype]
 623    return np.dtype(numpy_type)
 624
 625  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
 626    torch_type = {
 627        np.float32: self.torch.float32,
 628        np.float64: self.torch.float64,
 629        np.complex64: self.torch.complex64,
 630        np.complex128: self.torch.complex128,
 631        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
 632        np.int16: self.torch.int16,
 633        np.int32: self.torch.int32,
 634        np.int64: self.torch.int64,
 635    }[np.dtype(dtype).type]
 636    return self.array.type(torch_type)
 637
 638  def possibly_make_contiguous(self) -> _TorchTensor:
 639    return self.array.contiguous()
 640
 641  def clip(self, low: Any, high: Any, dtype: _DTypeLike | None = None) -> _TorchTensor:
 642    array = self.array
 643    array = _arr_astype(array, dtype) if dtype is not None else array
 644    return array.clip(low, high)
 645
 646  def square(self) -> _TorchTensor:
 647    return self.array.square()
 648
 649  def sqrt(self) -> _TorchTensor:
 650    return self.array.sqrt()
 651
 652  def getitem(self, indices: Any) -> _TorchTensor:
 653    if not isinstance(indices, tuple):
 654      indices = indices.type(self.torch.int64)
 655    return self.array[indices]  # pylint: disable=unsubscriptable-object
 656
 657  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
 658    condition = self.array
 659    return if_true.where(condition, if_false)
 660
 661  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
 662    return self.torch.permute(self.array, tuple(axes))
 663
 664  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 665    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
 666    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 667    strides: Sequence[int] = self.array.stride()
 668    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 669
 670    def priority(dim: int) -> float:
 671      scaling = dst_shape[dim] / src_shape[dim]
 672      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 673
 674    return sorted(range(len(src_shape)), key=priority)
 675
 676  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
 677    del num_threads
 678    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 679      sparse = _arr_astype(sparse, _arr_dtype(self.array))
 680    return sparse @ self.array  # Calls torch.sparse.mm().
 681
 682  @staticmethod
 683  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
 684    import torch
 685
 686    return torch.cat(tuple(arrays), axis)
 687
 688  @staticmethod
 689  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
 690    import torch
 691
 692    operands = tuple(torch.as_tensor(operand) for operand in operands)
 693    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
 694      operands = tuple(
 695          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
 696      )
 697    return torch.einsum(subscripts, *operands)
 698
 699  @staticmethod
 700  def make_sparse_matrix(
 701      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 702  ) -> _TorchTensor:
 703    import torch
 704
 705    indices = np.vstack((row_ind, col_ind))
 706    with torch.sparse.check_sparse_tensor_invariants(enable=False):
 707      return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
 708    # .coalesce() is unnecessary because indices/data are already merged.
 709
 710
 711class _JaxArraylib(_Arraylib[_JaxArray]):
 712  """Jax implementation of the array abstraction."""
 713
 714  def __init__(self, array: _NDArray) -> None:
 715    import jax.numpy
 716
 717    self.jnp = jax.numpy
 718    super().__init__(arraylib='jax', array=self.jnp.asarray(array))
 719
 720  @staticmethod
 721  def recognize(array: Any) -> bool:
 722    # e.g., jaxlib.xla_extension.DeviceArray, jax.interpreters.ad.JVPTracer
 723    return type(array).__module__.startswith(('jaxlib.', 'jax.'))
 724
 725  def numpy(self) -> _NDArray:
 726    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
 727    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
 728    # np.array(array) always returns a writable array but the copy may be more costly.
 729    # return self.array.to_py()
 730    return np.asarray(self.array)
 731
 732  def dtype(self) -> _DType:
 733    return np.dtype(self.array.dtype)
 734
 735  def astype(self, dtype: _DTypeLike) -> _JaxArray:
 736    return self.array.astype(dtype)  # (copy=False is unavailable)
 737
 738  def possibly_make_contiguous(self) -> _JaxArray:
 739    return self.array.copy()
 740
 741  def clip(self, low: Any, high: Any, dtype: _DTypeLike | None = None) -> _JaxArray:
 742    array = self.array
 743    if dtype is not None:
 744      array = array.astype(dtype)  # (copy=False is unavailable)
 745    return self.jnp.clip(array, low, high)
 746
 747  def square(self) -> _JaxArray:
 748    return self.jnp.square(self.array)
 749
 750  def sqrt(self) -> _JaxArray:
 751    return self.jnp.sqrt(self.array)
 752
 753  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
 754    condition = self.array
 755    return self.jnp.where(condition, if_true, if_false)
 756
 757  def transpose(self, axes: Sequence[int]) -> _JaxArray:
 758    return self.jnp.transpose(self.array, tuple(axes))
 759
 760  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 761    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
 762    # https://github.com/google/jax/discussions/7544#discussioncomment-1197038.
 763    # We use a heuristic similar to `_TensorflowArraylib`.
 764    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 765    dims = list(range(len(src_shape)))
 766    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 767      dims[:2] = [1, 0]
 768    return dims
 769
 770  def premult_with_sparse(
 771      self, sparse: 'jax.experimental.sparse.BCOO', num_threads: int | Literal['auto']
 772  ) -> _JaxArray:
 773    del num_threads
 774    return sparse @ self.array  # Calls jax.experimental.sparse.bcoo_dot_general().
 775
 776  @staticmethod
 777  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
 778    import jax.numpy as jnp
 779
 780    return jnp.concatenate(arrays, axis)
 781
 782  @staticmethod
 783  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
 784    import jax.numpy as jnp
 785
 786    return jnp.einsum(subscripts, *operands, optimize='greedy')
 787
 788  @staticmethod
 789  def make_sparse_matrix(
 790      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 791  ) -> _JaxArray:
 792    # https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
 793    import jax.experimental.sparse
 794
 795    indices = np.vstack((row_ind, col_ind)).T
 796    return jax.experimental.sparse.BCOO(
 797        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
 798    )
 799
 800
 801_CANDIDATE_ARRAYLIBS = {
 802    'numpy': _NumpyArraylib,
 803    'tensorflow': _TensorflowArraylib,
 804    'torch': _TorchArraylib,
 805    'jax': _JaxArraylib,
 806}
 807
 808
 809def _is_available(arraylib: str) -> bool:
 810  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 811  # Faster than trying to import it.
 812  return importlib.util.find_spec(arraylib) is not None  # type: ignore[attr-defined]
 813
 814
 815_DICT_ARRAYLIBS = {
 816    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if _is_available(arraylib)
 817}
 818
 819ARRAYLIBS = list(_DICT_ARRAYLIBS)
 820"""Array libraries supported automatically in the resize and resampling operations.
 821
 822- The library is selected automatically based on the type of the `array` function parameter.
 823
 824- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 825
 826- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 827"""
 828
 829
 830def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 831  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 832  for cls in _DICT_ARRAYLIBS.values():
 833    if cls.recognize(array):
 834      return cls(array)  # type: ignore[abstract]
 835  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 836
 837
 838def _arr_arraylib(array: _Array, /) -> str:
 839  """Return the name of the `_Arraylib` representing `array`."""
 840  return _as_arr(array).arraylib
 841
 842
 843def _arr_numpy(array: _Array, /) -> _NDArray:
 844  """Return a `numpy` version of `array`."""
 845  return _as_arr(array).numpy()
 846
 847
 848def _arr_dtype(array: _Array, /) -> _DType:
 849  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 850  return _as_arr(array).dtype()
 851
 852
 853def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 854  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 855  return _as_arr(array).astype(dtype)
 856
 857
 858def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 859  """Return the equivalent of `array.reshape(shape)`."""
 860  return _as_arr(array).reshape(shape)
 861
 862
 863def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 864  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 865  return _as_arr(array).possibly_make_contiguous()
 866
 867
 868def _arr_clip(
 869    array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike | None = None
 870) -> _Array:
 871  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 872  return _as_arr(array).clip(low, high, dtype)
 873
 874
 875def _arr_square(array: _Array, /) -> _Array:
 876  """Return the equivalent of `np.square(array)`."""
 877  return _as_arr(array).square()
 878
 879
 880def _arr_sqrt(array: _Array, /) -> _Array:
 881  """Return the equivalent of `np.sqrt(array)`."""
 882  return _as_arr(array).sqrt()
 883
 884
 885def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 886  """Return the equivalent of `array[indices]`."""
 887  return _as_arr(array).getitem(indices)
 888
 889
 890def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 891  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 892  return _as_arr(condition).where(if_true, if_false)
 893
 894
 895def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 896  """Return the equivalent of `np.transpose(array, axes)`."""
 897  return _as_arr(array).transpose(axes)
 898
 899
 900def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 901  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 902  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 903
 904
 905def _arr_matmul_sparse_dense(
 906    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 907) -> _Array:
 908  """Return the multiplication of the `sparse` and `dense` matrices."""
 909  assert num_threads == 'auto' or num_threads >= 1
 910  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 911
 912
 913def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 914  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 915  arraylib = _arr_arraylib(arrays[0])
 916  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)
 917
 918
 919def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 920  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 921  arraylib = _arr_arraylib(operands[0])
 922  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)
 923
 924
 925def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 926  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 927  ndim = len(array.shape)
 928  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 929  axes = list(range(ndim))
 930  axes[axis1] = axis2
 931  axes[axis2] = axis1
 932  return _arr_transpose(array, axes)
 933
 934
 935def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 936  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 937  ndim = len(array.shape)
 938  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 939  axes = [n for n in range(ndim) if n != source]
 940  axes.insert(destination, source)
 941  return _arr_transpose(array, axes)
 942
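# Illustrative sketch, not part of the library: the permutation built in `_arr_moveaxis` can be
# compared against `np.moveaxis` on a plain numpy array.  The helper `_sketch_moveaxis_axes` is a
# hypothetical name used only for this example.
import numpy as np


def _sketch_moveaxis_axes(ndim: int, source: int, destination: int) -> list[int]:
  """Return the axes permutation that moves axis `source` to position `destination`."""
  axes = [n for n in range(ndim) if n != source]  # All axes except `source`, in order.
  axes.insert(destination, source)  # Re-insert `source` at its new position.
  return axes


# Usage: for a 4D array `a`, _sketch_moveaxis_axes(a.ndim, 0, 2) is [1, 2, 0, 3], and
# np.transpose(a, [1, 2, 0, 3]) matches np.moveaxis(a, 0, 2).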
 943
 944def _make_sparse_matrix(
 945    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 946) -> Any:
 947  """Return the equivalent of `scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)`.
 948  However, indices must be ordered and unique."""
 949  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 950
 951
 952def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 953  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 954  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
 955
 956
 957# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
 958# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
 959# e.g. dim=1 in an RGB image.
 960#
 961# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 962# tensor.  However, tf.reshape() is always efficient.
 963
 964
 965def _block_shape_with_min_size(
 966    shape: tuple[int, ...], min_size: int, compact: bool = True
 967) -> tuple[int, ...]:
 968  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 969  if math.prod(shape) < min_size:
 970    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 971  if compact:
 972    root = int(math.ceil(min_size ** (1 / len(shape))))
 973    block_shape = np.minimum(shape, root)
 974    for dim in range(len(shape)):
 975      if block_shape[dim] == 2 and block_shape.prod() >= min_size * 2:
 976        block_shape[dim] = 1
 977    for dim in range(len(shape) - 1, -1, -1):
 978      if block_shape.prod() < min_size:
 979        block_shape[dim] = shape[dim]
 980  else:
 981    block_shape = np.ones_like(shape)
 982    for dim in range(len(shape) - 1, -1, -1):
 983      if block_shape.prod() < min_size:
 984        block_shape[dim] = min(shape[dim], math.ceil(min_size / block_shape.prod()))
 985  return tuple(block_shape)
 986
 987
 988def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 989  """Split `array` into `num_sections` along `axis`."""
 990  assert 0 <= axis < len(array.shape)
 991  assert 1 <= num_sections <= array.shape[axis]
 992
 993  if 0:
 994    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 995
 996  else:
 997    # Adapted from https://github.com/numpy/numpy/blob/main/numpy/lib/shape_base.py#L739-L792.
 998    num_total = array.shape[axis]
 999    num_each, num_extra = divmod(num_total, num_sections)
1000    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
1001    div_points = np.array(section_sizes).cumsum()
1002    split = []
1003    tmp = _arr_swapaxes(array, axis, 0)
1004    for i in range(num_sections):
1005      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
1006
1007  return split
1008
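# Illustrative sketch, not part of the library: the section boundaries computed in `_array_split`
# follow the `np.array_split` convention, in which the first `num_total % num_sections` sections
# receive one extra element.  The helper `_sketch_section_bounds` is a hypothetical name.
import numpy as np


def _sketch_section_bounds(num_total: int, num_sections: int) -> list[tuple[int, int]]:
  """Return the (start, stop) index pair of each section along the split axis."""
  num_each, num_extra = divmod(num_total, num_sections)
  section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
  div_points = np.array(section_sizes).cumsum()  # Running totals give the section boundaries.
  return [(int(div_points[i]), int(div_points[i + 1])) for i in range(num_sections)]


# Usage: _sketch_section_bounds(10, 3) returns [(0, 4), (4, 7), (7, 10)].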
1009
1010def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
1011  """Split `array` into nested lists of blocks of size at most `block_shape`."""
1012  # See https://stackoverflow.com/a/50305924.  (If the block_shape is known to
1013  # exactly partition the array, see https://stackoverflow.com/a/16858283.)
1014  if len(block_shape) > len(array.shape):
1015    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
1016  if start_axis == len(block_shape):
1017    return array
1018
1019  num_sections = math.ceil(array.shape[start_axis] / block_shape[start_axis])
1020  split = _array_split(array, start_axis, num_sections)
1021  return [_split_array_into_blocks(split_a, block_shape, start_axis + 1) for split_a in split]
1022
1023
1024def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
1025  """Apply `func` to each block in the nested lists of `blocks`."""
1026  if isinstance(blocks, list):
1027    return [_map_function_over_blocks(block, func) for block in blocks]
1028  return func(blocks)
1029
1030
1031def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
1032  """Merge an array from the nested lists of array blocks in `blocks`."""
1033  # More general than np.block() because the blocks can have additional dims.
1034  if isinstance(blocks, list):
1035    new_blocks = [_merge_array_from_blocks(block, axis + 1) for block in blocks]
1036    return _arr_concatenate(new_blocks, axis)
1037  return blocks
1038
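# Illustrative sketch, not part of the library: a numpy-only version of the split/merge round
# trip performed by `_split_array_into_blocks` and `_merge_array_from_blocks`.  It assumes plain
# `np.ndarray` inputs and uses `np.array_split` and `np.concatenate` directly; the names
# `_sketch_split_blocks` and `_sketch_merge_blocks` are hypothetical.
import math

import numpy as np


def _sketch_split_blocks(array: np.ndarray, block_shape: tuple[int, ...], axis: int = 0):
  """Split `array` into nested lists of blocks of size at most `block_shape`."""
  if axis == len(block_shape):
    return array  # All blocked dimensions are processed; any trailing dims stay intact.
  num_sections = math.ceil(array.shape[axis] / block_shape[axis])
  parts = np.array_split(array, num_sections, axis=axis)
  return [_sketch_split_blocks(part, block_shape, axis + 1) for part in parts]


def _sketch_merge_blocks(blocks, axis: int = 0) -> np.ndarray:
  """Merge the nested lists of blocks back into a single array."""
  if isinstance(blocks, list):
    return np.concatenate([_sketch_merge_blocks(block, axis + 1) for block in blocks], axis)
  return blocks


# Usage: a (5, 7, 3) array split with block_shape (2, 3) yields a 3x3 nesting of blocks (the
# trailing dimension of size 3 is never split), and merging the blocks restores the array.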
1039
1040@dataclasses.dataclass(frozen=True)
1041class Gridtype(abc.ABC):
1042  """Abstract base class for grid-types such as `'dual'` and `'primal'`.
1043
1044  In resampling operations, the grid-type may be specified separately as `src_gridtype` for the
1045  source domain and `dst_gridtype` for the destination domain.  Moreover, the grid-type may be
1046  specified per domain dimension.
1047
1048  Examples:
1049    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.
1050
1051    `resize(source, shape, src_gridtype=['dual', 'primal'],
1052            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
1053  """
1054
1055  name: str
1056  """Gridtype name."""
1057
1058  @abc.abstractmethod
1059  def min_size(self) -> int:
1060    """Return the necessary minimum number of grid samples."""
1061
1062  @abc.abstractmethod
1063  def size_in_samples(self, size: int, /) -> int:
1064    """Return the domain size in units of inter-sample spacing."""
1065
1066  @abc.abstractmethod
1067  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
1068    """Return [0.0, 1.0] coordinates given [0, size - 1] indices."""
1069
1070  @abc.abstractmethod
1071  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
1072    """Return location x given coordinates [0.0, 1.0], where x == 0.0 is the first grid sample
1073    and x == size - 1.0 is the last grid sample."""
1074
1075  @abc.abstractmethod
1076  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
1077    """Map integer sample indices to interior ones using boundary reflection."""
1078
1079  @abc.abstractmethod
1080  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
1081    """Map integer sample indices to interior ones using wrapping."""
1082
1083  @abc.abstractmethod
1084  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1085    """Map integer sample indices to interior ones using reflect-clamp."""
1086
1087
1088class DualGridtype(Gridtype):
1089  """Samples are at the center of cells in a uniform partition of the domain.
1090
1091  For a unit-domain dimension with N samples, each sample 0 <= i < N has position (i + 0.5) / N,
1092  e.g., [0.125, 0.375, 0.625, 0.875] for N = 4.
1093  """
1094
1095  def __init__(self) -> None:
1096    super().__init__(name='dual')
1097
1098  def min_size(self) -> int:
1099    return 1
1100
1101  def size_in_samples(self, size: int, /) -> int:
1102    return size
1103
1104  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
1105    return (index + 0.5) / size
1106
1107  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
1108    return point * size - 0.5
1109
1110  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
1111    index = np.mod(index, size * 2)
1112    return np.where(index < size, index, 2 * size - 1 - index)
1113
1114  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
1115    return np.mod(index, size)
1116
1117  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1118    return np.minimum(np.where(index < 0, -1 - index, index), size - 1)
1119
1120
1121class PrimalGridtype(Gridtype):
1122  """Samples are at the vertices of cells in a uniform partition of the domain.
1123
1124  For a unit-domain dimension with N samples, each sample 0 <= i < N has position i / (N - 1),
1125  e.g., [0, 1/3, 2/3, 1] for N = 4.
1126  """
1127
1128  def __init__(self) -> None:
1129    super().__init__(name='primal')
1130
1131  def min_size(self) -> int:
1132    return 2
1133
1134  def size_in_samples(self, size: int, /) -> int:
1135    return size - 1
1136
1137  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
1138    return index / (size - 1)
1139
1140  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
1141    return point * (size - 1)
1142
1143  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
1144    index = np.mod(index, size * 2 - 2)
1145    return np.where(index < size, index, 2 * size - 2 - index)
1146
1147  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
1148    return np.mod(index, size - 1)
1149
1150  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1151    return np.minimum(np.abs(index), size - 1)
1152
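# Illustrative sketch, not part of the library: sample positions and boundary reflection for the
# two grid types, restated with plain numpy so the formulas in `DualGridtype` and
# `PrimalGridtype` can be checked side by side.  The helper names are hypothetical.
import numpy as np


def _sketch_grid_positions(size: int) -> tuple[np.ndarray, np.ndarray]:
  """Return the (dual, primal) sample positions in [0, 1] for `size` samples (size >= 2)."""
  index = np.arange(size)
  dual = (index + 0.5) / size  # Cell centers.
  primal = index / (size - 1)  # Cell vertices, including both domain endpoints.
  return dual, primal


def _sketch_reflect(index: np.ndarray, size: int, primal: bool) -> np.ndarray:
  """Map integer indices to interior ones by reflection across the domain boundaries."""
  index = np.asarray(index)
  if primal:
    index = np.mod(index, size * 2 - 2)  # Period 2N - 2: boundary samples are not repeated.
    return np.where(index < size, index, 2 * size - 2 - index)
  index = np.mod(index, size * 2)  # Period 2N: boundary samples are repeated.
  return np.where(index < size, index, 2 * size - 1 - index)


# Usage with size=4 and indices [-2, -1, 0, 3, 4, 5]:
#   dual reflection gives [1, 0, 0, 3, 3, 2]; primal reflection gives [2, 1, 0, 3, 2, 1].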
1153
1154_DICT_GRIDTYPES = {
1155    'dual': DualGridtype(),
1156    'primal': PrimalGridtype(),
1157}
1158
1159GRIDTYPES = list(_DICT_GRIDTYPES)
1160r"""Shortcut names for the two predefined grid types (specified per dimension):
1161
1162| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
1163| --- |:---:|:---:|
1164| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual](https://github.com/hhoppe/resampler/raw/main/media/dual_grid_small.png) | ![Primal](https://github.com/hhoppe/resampler/raw/main/media/primal_grid_small.png) |
1165| Nesting of samples across resolutions | The sample positions do *not* nest. | The *even* samples remain at coarser scale. |
1166| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
1167| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
1168| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |
1169
1170See the source code for extensibility.
1171"""
1172
1173
1174def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
1175  """Return a `Gridtype`, which can be specified as a name in `GRIDTYPES`."""
1176  return gridtype if isinstance(gridtype, Gridtype) else _DICT_GRIDTYPES[gridtype]
1177
1178
1179def _get_gridtypes(
1180    gridtype: str | Gridtype | None,
1181    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
1182    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
1183    src_ndim: int,
1184    dst_ndim: int,
1185) -> tuple[list[Gridtype], list[Gridtype]]:
1186  """Return per-dim source and destination grid types given all parameters."""
1187  if gridtype is None and src_gridtype is None and dst_gridtype is None:
1188    gridtype = 'dual'
1189  if gridtype is not None:
1190    if src_gridtype is not None:
1191      raise ValueError('Cannot have both gridtype and src_gridtype.')
1192    if dst_gridtype is not None:
1193      raise ValueError('Cannot have both gridtype and dst_gridtype.')
1194    src_gridtype = dst_gridtype = gridtype
1195  src_gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(src_gridtype), src_ndim)]
1196  dst_gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(dst_gridtype), dst_ndim)]
1197  return src_gridtype2, dst_gridtype2
1198
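# Illustrative sketch, not part of the library: `_get_gridtypes` relies on `np.broadcast_to` to
# expand a single grid-type name to one entry per dimension while passing a per-dimension list
# through unchanged.  This sketch handles names only (not `Gridtype` instances), and the helper
# `_sketch_broadcast_names` is a hypothetical name.
import numpy as np


def _sketch_broadcast_names(names, ndim: int) -> list[str]:
  """Return one grid-type name per dimension from a single name or a length-`ndim` sequence."""
  return [str(name) for name in np.broadcast_to(np.array(names), ndim)]


# Usage: _sketch_broadcast_names('dual', 2) returns ['dual', 'dual'], whereas
# _sketch_broadcast_names(['dual', 'primal'], 2) returns ['dual', 'primal'].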
1199
1200@dataclasses.dataclass(frozen=True)
1201class RemapCoordinates(abc.ABC):
1202  """Abstract base class for modifying the specified coordinates prior to evaluating the
1203  reconstruction kernels."""
1204
1205  @abc.abstractmethod
1206  def __call__(self, point: _NDArray, /) -> _NDArray:
1207    ...
1208
1209
1210class NoRemapCoordinates(RemapCoordinates):
1211  """The coordinates are not remapped."""
1212
1213  def __call__(self, point: _NDArray, /) -> _NDArray:
1214    return point
1215
1216
1217class MirrorRemapCoordinates(RemapCoordinates):
1218  """The coordinates are reflected across the domain boundaries so that they lie in the unit
1219  interval.  The resulting function is continuous but not smooth across the boundaries."""
1220
1221  def __call__(self, point: _NDArray, /) -> _NDArray:
1222    point = np.mod(point, 2.0)
1223    return np.where(point >= 1.0, 2.0 - point, point)
1224
1225
1226class TileRemapCoordinates(RemapCoordinates):
1227  """The coordinates are mapped to the unit interval using a "modulo 1.0" operation.  The resulting
1228  function is generally discontinuous across the domain boundaries."""
1229
1230  def __call__(self, point: _NDArray, /) -> _NDArray:
1231    return np.mod(point, 1.0)
1232
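# Illustrative sketch, not part of the library: the mirror and tile remappings applied to a few
# coordinates outside the unit interval.  The two functions restate the `__call__` bodies of
# `MirrorRemapCoordinates` and `TileRemapCoordinates`; their names are hypothetical.
import numpy as np


def _sketch_mirror(point: np.ndarray) -> np.ndarray:
  """Reflect coordinates across the boundaries at 0.0 and 1.0 (period 2.0)."""
  point = np.mod(point, 2.0)
  return np.where(point >= 1.0, 2.0 - point, point)


def _sketch_tile(point: np.ndarray) -> np.ndarray:
  """Wrap coordinates into [0.0, 1.0) (period 1.0)."""
  return np.mod(point, 1.0)


# Usage with points [-0.25, 0.25, 1.25, 2.25]:
#   mirror gives [0.25, 0.25, 0.75, 0.25]; tile gives [0.75, 0.25, 0.25, 0.25].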
1233
1234@dataclasses.dataclass(frozen=True)
1235class ExtendSamples(abc.ABC):
1236  """Abstract base class for replacing references to grid samples exterior to the unit domain by
1237  affine combinations of interior sample(s) and possibly the constant value (`cval`)."""
1238
1239  uses_cval: bool = False
1240  """True if some exterior samples are defined in terms of `cval`, i.e., if the computed weight
1241  is non-affine."""
1242
1243  @abc.abstractmethod
1244  def __call__(
1245      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1246  ) -> tuple[_NDArray, _NDArray]:
1247    """Detect references to exterior samples, i.e., entries of `index` that lie outside the
1248    interval [0, size), and update these indices (and possibly their associated weights) to
1249    reference only interior samples.  Return `new_index, new_weight`."""
1250
1251
1252class ReflectExtendSamples(ExtendSamples):
1253  """Find the interior sample by reflecting across domain boundaries."""
1254
1255  def __call__(
1256      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1257  ) -> tuple[_NDArray, _NDArray]:
1258    index = gridtype.reflect(index, size)
1259    return index, weight
1260
1261
1262class WrapExtendSamples(ExtendSamples):
1263  """Wrap the interior samples periodically.  For a `'primal'` grid, the last
1264  sample is ignored as its value is replaced by the first sample."""
1265
1266  def __call__(
1267      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1268  ) -> tuple[_NDArray, _NDArray]:
1269    index = gridtype.wrap(index, size)
1270    return index, weight
1271
1272
1273class ClampExtendSamples(ExtendSamples):
1274  """Use the nearest interior sample."""
1275
1276  def __call__(
1277      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1278  ) -> tuple[_NDArray, _NDArray]:
1279    index = index.clip(0, size - 1)
1280    return index, weight
1281
1282
1283class ReflectClampExtendSamples(ExtendSamples):
1284  """Extend the grid samples from [0, 1] into [-1, 0] using reflection and then set grid
1285  samples outside [-1, 1] to the value of the nearest sample."""
1286
1287  def __call__(
1288      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1289  ) -> tuple[_NDArray, _NDArray]:
1290    index = gridtype.reflect_clamp(index, size)
1291    return index, weight
1292
1293
1294class BorderExtendSamples(ExtendSamples):
1295  """Let all exterior samples have the constant value (`cval`)."""
1296
1297  def __init__(self) -> None:
1298    super().__init__(uses_cval=True)
1299
1300  def __call__(
1301      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1302  ) -> tuple[_NDArray, _NDArray]:
1303    low = index < 0
1304    weight[low] = 0.0
1305    index[low] = 0
1306    high = index >= size
1307    weight[high] = 0.0
1308    index[high] = size - 1
1309    return index, weight
1310
1311
1312class ValidExtendSamples(ExtendSamples):
1313  """Assign all domain samples weight 1 and all outside samples weight 0.
1314  Compute a weighted reconstruction and divide by the reconstructed weight."""
1315
1316  def __init__(self) -> None:
1317    super().__init__(uses_cval=True)
1318
1319  def __call__(
1320      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1321  ) -> tuple[_NDArray, _NDArray]:
1322    low = index < 0
1323    weight[low] = 0.0
1324    index[low] = 0
1325    high = index >= size
1326    weight[high] = 0.0
1327    index[high] = size - 1
1328    sum_weight = weight.sum(axis=-1)
1329    nonzero_sum = sum_weight != 0.0
1330    np.divide(weight, sum_weight[..., None], out=weight, where=nonzero_sum[..., None])
1331    return index, weight
1332
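# Illustrative sketch, not part of the library: the renormalization performed by
# `ValidExtendSamples`.  Exterior samples get weight zero and the remaining weights in each row
# are rescaled to sum to one; a row with no interior samples keeps all-zero weights.  Unlike the
# class above, this sketch copies its inputs; `_sketch_valid_weights` is a hypothetical name.
import numpy as np


def _sketch_valid_weights(index, weight, size: int) -> tuple[np.ndarray, np.ndarray]:
  """Return (new_index, new_weight) that reference only interior samples in [0, size)."""
  index = np.array(index)
  weight = np.array(weight, dtype=float)
  exterior = (index < 0) | (index >= size)
  weight[exterior] = 0.0
  index = index.clip(0, size - 1)  # Any interior index works because the weight is zero.
  sum_weight = weight.sum(axis=-1)
  nonzero_sum = sum_weight != 0.0
  np.divide(weight, sum_weight[..., None], out=weight, where=nonzero_sum[..., None])
  return index, weight


# Usage: with size=3, index [[-1, 0, 1, 2]] and weight [[0.1, 0.4, 0.4, 0.1]] become
# index [[0, 0, 1, 2]] and weight [[0, 4/9, 4/9, 1/9]].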
1333
1334class LinearExtendSamples(ExtendSamples):
1335  """Linearly extrapolate beyond boundary samples."""
1336
1337  def __call__(
1338      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1339  ) -> tuple[_NDArray, _NDArray]:
1340    if size < 2:
1341      index = gridtype.reflect(index, size)
1342      return index, weight
1343    # For each boundary, define new columns in index and weight arrays to represent the last and
1344    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
1345    # duplicate index entries.
1346    low = index < 0
1347    high = index >= size
1348    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
1349    x = index
1350    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
1351    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
1352    x = (size - 1) - index
1353    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
1354    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
1355    weight[low] = 0.0
1356    index[low] = 0
1357    weight[high] = 0.0
1358    index[high] = size - 1
1359    w[..., :-4] = weight
1360    weight = w
1361    new_index = np.empty(w.shape, index.dtype)
1362    new_index[..., :-4] = index
1363    # Let matrix (including zero values) be banded.
1364    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
1365    index = new_index
1366    return index, weight
1367
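# Illustrative sketch, not part of the library: the per-sample weights that `LinearExtendSamples`
# accumulates for a single exterior index.  Applying them to the two samples nearest the boundary
# reproduces linear extrapolation.  The helper `_sketch_linear_weights` is a hypothetical name
# and handles one scalar index rather than the vectorized arrays used above.
import numpy as np


def _sketch_linear_weights(index: int, size: int) -> dict[int, float]:
  """Return {interior_index: weight} expressing the exterior sample `index` (size >= 2)."""
  if index < 0:
    x = index
    return {0: 1.0 - x, 1: float(x)}
  if index >= size:
    x = (size - 1) - index
    return {size - 1: 1.0 - x, size - 2: float(x)}
  return {index: 1.0}  # Interior samples reference themselves.


# Usage: for samples f(i) = 3 * i + 2 with size=5, the weights for index -2 are {0: 3.0, 1: -2.0},
# giving 3.0 * f(0) - 2.0 * f(1) = -4.0, which equals f(-2).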
1368
1369class QuadraticExtendSamples(ExtendSamples):
1370  """Quadratically extrapolate beyond boundary samples."""
1371
1372  def __call__(
1373      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
1374  ) -> tuple[_NDArray, _NDArray]:
1375    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
1376    # but it seems just quadratic.
1377    if size < 3:
1378      index = gridtype.reflect(index, size)
1379      return index, weight
1380    low = index < 0
1381    high = index >= size
1382    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
1383    x = index
1384    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
1385    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
1386    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
1387    x = (size - 1) - index
1388    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
1389    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
1390    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
1391    weight[low] = 0.0
1392    index[low] = 0
1393    weight[high] = 0.0
1394    index[high] = size - 1
1395    w[..., :-6] = weight
1396    weight = w
1397    new_index = np.empty(w.shape, index.dtype)
1398    new_index[..., :-6] = index
1399    # Let matrix (including zero values) be banded.
1400    new_index[..., -6:] = np.where(
1401        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
1402    )
1403    index = new_index
1404    return index, weight
1405
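# Illustrative sketch, not part of the library: the three polynomials used by
# `QuadraticExtendSamples` are the Lagrange basis functions for samples at positions 0, 1, 2.
# At x = -1 they give weights (3, -3, 1), i.e. the [Keys 1981] formula quoted above.  These
# weights reproduce quadratics exactly but not cubics.  `_sketch_quadratic_weights` is a
# hypothetical name.
import numpy as np


def _sketch_quadratic_weights(x: float) -> tuple[float, float, float]:
  """Return the weights of samples 0, 1, 2 for quadratic extrapolation at location `x`."""
  w0 = (0.5 * x - 1.5) * x + 1  # (x - 1) * (x - 2) / 2
  w1 = (-x + 2) * x  # -x * (x - 2)
  w2 = (0.5 * x - 0.5) * x  # x * (x - 1) / 2
  return w0, w1, w2


# Usage: _sketch_quadratic_weights(-1) returns (3.0, -3.0, 1.0).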
1406
1407@dataclasses.dataclass(frozen=True)
1408class OverrideExteriorValue:
1409  """Abstract base class to set the value outside some domain extent to a
1410  constant value (`cval`)."""
1411
1412  boundary_antialiasing: bool = True
1413  """Antialias the pixel values adjacent to the boundary of the extent."""
1414
1415  uses_cval: bool = False
1416  """Modify some weights to introduce references to the `cval` constant value."""
1417
1418  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
1419    """For all `point` outside some extent, modify the weight to be zero."""
1420
1421  def override_using_signed_distance(
1422      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
1423  ) -> None:
1424    """Reduce sample weights for "outside" values based on the signed distance function,
1425    to effectively assign the constant value `cval`."""
1426    all_points_inside_domain = np.all(signed_distance <= 0.0)
1427    if all_points_inside_domain:
1428      return
1429    if self.boundary_antialiasing and min(point.shape) >= 2:
1430      # For discontinuous coordinate mappings, we may need to somehow ignore
1431      # the large finite differences computed across the map discontinuities.
1432      gradient = np.gradient(point)
1433      gradient_norm = np.linalg.norm(np.atleast_2d(gradient), axis=0)
1434      signed_distance_in_samples = signed_distance / (gradient_norm + 1e-20)
1435      # Opacity is in linear space, which is correct if Gamma is set.
1436      opacity = (0.5 - signed_distance_in_samples).clip(0.0, 1.0)
1437      weight *= opacity[..., None]
1438    else:
1439      is_outside = signed_distance > 0.0
1440      weight[is_outside, :] = 0.0
1441
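As a rough scalar illustration of `override_using_signed_distance`, the sketch below converts a signed distance into a per-sample opacity. It assumes a 1D mapping with uniform spacing, so the `spacing` argument stands in for the gradient norm computed with `np.gradient` in the method; both function names are hypothetical.

```python
def unit_domain_signed_distance(point: float) -> float:
  """Positive outside [0, 1], negative inside, zero on the boundary."""
  return abs(point - 0.5) - 0.5


def boundary_opacity(signed_distance: float, spacing: float) -> float:
  """Fraction of one output sample that is covered by the domain interior."""
  distance_in_samples = signed_distance / (spacing + 1e-20)
  return min(max(0.5 - distance_in_samples, 0.0), 1.0)


spacing = 0.25  # Assumed: adjacent output samples are 0.25 apart in domain units.
points = [0.75, 0.9375, 1.0, 1.125, 1.25]
opacities = [boundary_opacity(unit_domain_signed_distance(p), spacing) for p in points]
print(opacities)
```

A sample centered exactly on the boundary gets opacity 0.5, and the opacity ramps linearly over one sample width, which is what antialiases the edge of the extent.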
1442
1443class NoOverrideExteriorValue(OverrideExteriorValue):
1444  """The function value is not overridden."""
1445
1446  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
1447    pass
1448
1449
1450class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
1451  """Values outside the unit interval [0, 1] are replaced by the constant `cval`."""
1452
1453  def __init__(self, **kwargs: Any) -> None:
1454    super().__init__(uses_cval=True, **kwargs)
1455
1456  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
1457    signed_distance = abs(point - 0.5) - 0.5  # Boundaries at 0.0 and 1.0.
1458    self.override_using_signed_distance(weight, point, signed_distance)
1459
1460
1461class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
1462  """Values outside the interval [-1, 1] are replaced by the constant `cval`."""
1463
1464  def __init__(self, **kwargs: Any) -> None:
1465    super().__init__(uses_cval=True, **kwargs)
1466
1467  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
1468    signed_distance = abs(point) - 1.0  # Boundaries at -1.0 and 1.0.
1469    self.override_using_signed_distance(weight, point, signed_distance)
1470
1471
1472@dataclasses.dataclass(frozen=True)
1473class Boundary:
1474  """Domain boundary rules.  These define the reconstruction over the source domain near and beyond
1475  the domain boundaries.  The rules may be specified separately for each domain dimension."""
1476
1477  name: str = ''
1478  """Boundary rule name."""
1479
1480  coord_remap: RemapCoordinates = NoRemapCoordinates()
1481  """Modify specified coordinates prior to evaluating the reconstruction kernels."""
1482
1483  extend_samples: ExtendSamples = ReflectExtendSamples()
1484  """Define the value of each grid sample outside the unit domain as an affine combination of
1485  interior sample(s) and possibly the constant value (`cval`)."""
1486
1487  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
1488  """Set the value outside some extent to a constant value (`cval`)."""
1489
1490  @property
1491  def uses_cval(self) -> bool:
1492    """True if weights may be non-affine, involving the constant value (`cval`)."""
1493    return self.extend_samples.uses_cval or self.override_value.uses_cval
1494
1495  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
1496    """Modify coordinates prior to evaluating the filter kernels."""
1497    # Antialiasing across the tile boundaries may be feasible but seems hard.
1498    point = self.coord_remap(point)
1499    return point
1500
1501  def apply(
1502      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
1503  ) -> tuple[_NDArray, _NDArray]:
1504    """Replace exterior samples by combinations of interior samples."""
1505    index, weight = self.extend_samples(index, weight, size, gridtype)
1506    self.override_reconstruction(weight, point)
1507    return index, weight
1508
1509  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
1510    """For points outside an extent, modify weight to zero to assign `cval`."""
1511    self.override_value(weight, point)
1512
1513
1514_DICT_BOUNDARIES = {
1515    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
1516    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
1517    'tile': Boundary(
1518        'tile', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
1519    ),
1520    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
1521    'border': Boundary('border', extend_samples=BorderExtendSamples()),
1522    'natural': Boundary(
1523        'natural',
1524        extend_samples=ValidExtendSamples(),
1525        override_value=UnitDomainOverrideExteriorValue(),
1526    ),
1527    'linear_constant': Boundary(
1528        'linear_constant',
1529        extend_samples=LinearExtendSamples(),
1530        override_value=UnitDomainOverrideExteriorValue(),
1531    ),
1532    'quadratic_constant': Boundary(
1533        'quadratic_constant',
1534        extend_samples=QuadraticExtendSamples(),
1535        override_value=UnitDomainOverrideExteriorValue(),
1536    ),
1537    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
1538    'constant': Boundary(
1539        'constant',
1540        extend_samples=ReflectExtendSamples(),
1541        override_value=UnitDomainOverrideExteriorValue(),
1542    ),
1543    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
1544    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
1545}
1546
1547BOUNDARIES = list(_DICT_BOUNDARIES)
1548"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):
1549
1550| name                   | a.k.a. / comments |
1551|------------------------|-------------------|
1552| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
1553| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
1554| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
1555| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
1556| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
1557| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
1558| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
1559| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
1560| `'linear'`             | extrapolate from 2 last samples |
1561| `'quadratic'`          | extrapolate from 3 last samples |
1562| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
1563| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |
1564
1565These boundary rules may be specified per dimension.  See the source code for extensibility
1566using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.
1567
1568**Boundary rules illustrated in 1D:**
1569
1570<center>
1571<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_1D.png" width="100%"/>
1572</center>
1573
1574**Boundary rules illustrated in 2D:**
1575
1576<center>
1577<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_2D.png" width="100%"/>
1578</center>
1579"""
1580
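To make three of the rules in the table concrete, here is a simplified sketch of how an out-of-range sample index might be mapped back to an interior index. It assumes a `'dual'` grid and plain integer indices, and it is not the library's `ExtendSamples` code, which also updates weights and handles `'primal'` grids.

```python
def reflect_index(i: int, size: int) -> int:
  """Mirror about the domain edges (half-sample symmetric)."""
  i %= 2 * size
  return i if i < size else 2 * size - 1 - i


def wrap_index(i: int, size: int) -> int:
  """Periodic repetition."""
  return i % size


def clamp_index(i: int, size: int) -> int:
  """Repeat the first or last sample."""
  return min(max(i, 0), size - 1)


size = 4
indices = range(-3, 7)
for name, rule in [('reflect', reflect_index), ('wrap', wrap_index), ('clamp', clamp_index)]:
  print(f'{name:8}', [rule(i, size) for i in indices])
```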
1581_OFTUSED_BOUNDARIES = (
1582    'reflect wrap tile clamp border natural linear_constant quadratic_constant'.split()
1583)
1584"""A useful subset of `BOUNDARIES` for visualization in figures."""
1585
1586
1587def _get_boundary(boundary: str | Boundary, /) -> Boundary:
1588  """Return a `Boundary`, which can be specified as a name in `BOUNDARIES`."""
1589  return boundary if isinstance(boundary, Boundary) else _DICT_BOUNDARIES[boundary]
1590
1591
1592@dataclasses.dataclass(frozen=True)
1593class Filter(abc.ABC):
1594  """Abstract base class for filter kernel functions.
1595
1596  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
1597  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
1598  where N = 2 * radius.)
1599
1600  Portions of this code are adapted from the C++ library in
1601  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp
1602
1603  See also https://hhoppe.com/proj/filtering/.
1604  """
1605
1606  name: str
1607  """Filter kernel name."""
1608
1609  radius: float
1610  """Max absolute value of x for which self(x) is nonzero."""
1611
1612  interpolating: bool = True
1613  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""
1614
1615  continuous: bool = True
1616  """True if the kernel function has $C^0$ continuity."""
1617
1618  partition_of_unity: bool = True
1619  """True if the convolution of the kernel with a Dirac comb reproduces the
1620  unity function."""
1621
1622  unit_integral: bool = True
1623  """True if the integral of the kernel function is 1."""
1624
1625  requires_digital_filter: bool = False
1626  """True if the filter needs a pre/post digital filter for interpolation."""
1627
1628  @abc.abstractmethod
1629  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1630    """Return evaluation of filter kernel at locations x."""
1631
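The `interpolating` and `partition_of_unity` flags describe properties that can be checked numerically. The sketch below does so for a standalone triangle kernel; the helper names are hypothetical, and the checks only sample a few offsets rather than proving the property.

```python
from collections.abc import Callable


def triangle(x: float) -> float:
  return max(1.0 - abs(x), 0.0)


def is_interpolating(kernel: Callable[[float], float], radius: float) -> bool:
  """True if kernel(0) == 1 and kernel(i) == 0 at the nearby nonzero integers."""
  n = int(radius) + 1
  return kernel(0.0) == 1.0 and all(kernel(float(i)) == 0.0 for i in range(-n, n + 1) if i)


def shifted_sum(kernel: Callable[[float], float], x: float, radius: float) -> float:
  """Sum of integer-shifted kernel copies, i.e. its convolution with a Dirac comb."""
  n = int(radius) + 1
  return sum(kernel(x - k) for k in range(-n, n + 1))


offsets = [i / 8 for i in range(8)]
sums = [shifted_sum(triangle, x, 1.0) for x in offsets]
print(is_interpolating(triangle, 1.0), sums)
```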
1632
1633class ImpulseFilter(Filter):
1634  """See https://en.wikipedia.org/wiki/Dirac_delta_function."""
1635
1636  def __init__(self) -> None:
1637    super().__init__(name='impulse', radius=1e-20, continuous=False, partition_of_unity=False)
1638
1639  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1640    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')
1641
1642
1643class BoxFilter(Filter):
1644  """See https://en.wikipedia.org/wiki/Box_function.
1645
1646  The kernel function has value 1.0 over the half-open interval [-.5, .5).
1647  """
1648
1649  def __init__(self) -> None:
1650    super().__init__(name='box', radius=0.5, continuous=False)
1651
1652  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1653    use_asymmetric = True
1654    if use_asymmetric:
1655      x = np.asarray(x)
1656      return np.where((-0.5 <= x) & (x < 0.5), 1.0, 0.0)
1657    x = np.abs(x)
1658    return np.where(x < 0.5, 1.0, np.where(x == 0.5, 0.5, 0.0))
1659
1660
1661class TrapezoidFilter(Filter):
1662  """Filter for antialiased "area-based" filtering.
1663
1664  Args:
1665    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
1666      The special case `radius = None` is a placeholder that indicates that the filter will be
1667      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
1668      antialiasing in both minification and magnification.
1669
1670  This filter is similar to the BoxFilter but with linearly sloped sides.  It has value 1.0
1671  in the interval abs(x) <= 1.0 - radius and decreases linearly to value 0.0 in the interval
1672  1.0 - radius <= abs(x) <= radius, always with value 0.5 at x = 0.5.
1673  """
1674
1675  def __init__(self, *, radius: float | None = None) -> None:
1676    if radius is None:
1677      super().__init__(name='trapezoid', radius=0.0)
1678      return
1679    if not 0.5 < radius <= 1.0:
1680      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
1681    super().__init__(name=f'trapezoid_{radius}', radius=radius)
1682
1683  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1684    x = np.abs(x)
1685    assert 0.5 < self.radius <= 1.0
1686    return ((0.5 + 0.25 / (self.radius - 0.5)) - (0.5 / (self.radius - 0.5)) * x).clip(0.0, 1.0)
1687
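The trapezoid shape described in the docstring can be verified with a scalar re-implementation of the same formula. This is a dependency-free sketch; `trapezoid` is a hypothetical helper, and the integral is approximated with a midpoint rule.

```python
def trapezoid(x: float, radius: float) -> float:
  """Trapezoid kernel for 0.5 < radius <= 1.0, using the formula in `__call__`."""
  slope = 0.5 / (radius - 0.5)
  return min(max((0.5 + 0.5 * slope) - slope * abs(x), 0.0), 1.0)


radius = 0.75
steps = 128
dx = 2.0 / steps
integral = sum(trapezoid(-1.0 + (i + 0.5) * dx, radius) for i in range(steps)) * dx
plateau_edge = trapezoid(1.0 - radius, radius)
print(plateau_edge, trapezoid(0.5, radius), trapezoid(radius, radius), integral)
```

The kernel is 1.0 up to `abs(x) == 1.0 - radius`, passes through 0.5 at `x = 0.5`, reaches 0.0 at `x = radius`, and integrates to one.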
1688
1689class TriangleFilter(Filter):
1690  """See https://en.wikipedia.org/wiki/Triangle_function.
1691
1692  Also known as the hat or tent function.  It is used for piecewise-linear
1693  (or bilinear, or trilinear, ...) interpolation.
1694  """
1695
1696  def __init__(self) -> None:
1697    super().__init__(name='triangle', radius=1.0)
1698
1699  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1700    return (1.0 - np.abs(x)).clip(0.0, 1.0)
1701
1702
1703class CubicFilter(Filter):
1704  """Family of cubic filters parameterized by two scalar parameters.
1705
1706  Args:
1707    b: Parameter B of the Mitchell-Netravali cubic family.
1708    c: Parameter C of the Mitchell-Netravali cubic family.
1709
1710  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
1711  https://doi.org/10.1145/378456.378514.
1712
1713  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
1714  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
1715
1716  - The filter has quadratic precision iff b + 2 * c == 1.
1717  - The filter is interpolating iff b == 0.
1718  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
1719  - (b=1/3, c=1/3) is the Mitchell filter;
1720  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
1721  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
1722  """
1723
1724  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
1725    name = f'cubic_b{b}_c{c}' if name is None else name
1726    interpolating = b == 0
1727    super().__init__(name=name, radius=2.0, interpolating=interpolating)
1728    self.b, self.c = b, c
1729
1730  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1731    x = np.abs(x)
1732    b, c = self.b, self.c
1733    f3, f2, f0 = 2 - 9 / 6 * b - c, -3 + 2 * b + c, 1 - 1 / 3 * b
1734    g3, g2, g1, g0 = -b / 6 - c, b + 5 * c, -2 * b - 8 * c, 8 / 6 * b + 4 * c
1735    # (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
1736    # twice as slow; see also https://stackoverflow.com/questions/24065904)
1737    v01 = ((f3 * x + f2) * x) * x + f0
1738    v12 = ((g3 * x + g2) * x + g1) * x + g0
1739    return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1740
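The bullet points in the `CubicFilter` docstring can be checked with a scalar version of the same polynomial coefficients. The sketch below is dependency-free; `cubic` and `presets` are hypothetical names used only for illustration.

```python
def cubic(x: float, b: float, c: float) -> float:
  """Evaluate the (b, c) cubic using the same coefficients as `CubicFilter.__call__`."""
  x = abs(x)
  f3, f2, f0 = 2 - 9 / 6 * b - c, -3 + 2 * b + c, 1 - 1 / 3 * b
  g3, g2, g1, g0 = -b / 6 - c, b + 5 * c, -2 * b - 8 * c, 8 / 6 * b + 4 * c
  if x < 1.0:
    return ((f3 * x + f2) * x) * x + f0
  if x < 2.0:
    return ((g3 * x + g2) * x + g1) * x + g0
  return 0.0


presets = {'catmull-rom': (0.0, 0.5), 'mitchell': (1 / 3, 1 / 3), 'b-spline': (1.0, 0.0)}
for name, (b, c) in presets.items():
  center, neighbor = cubic(0.0, b, c), cubic(1.0, b, c)
  unity = sum(cubic(0.3 - k, b, c) for k in range(-2, 4))  # Sum of integer-shifted copies.
  print(f'{name:12} f(0)={center:.4f} f(1)={neighbor:.4f} shifted sum={unity:.4f}')
```

Only the `b == 0` preset has `f(0) == 1` and `f(1) == 0`, while all three presets sum to one over integer shifts.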
1741
1742class CatmullRomFilter(CubicFilter):
1743  """Cubic filter with cubic precision.  Also known as the Keys filter.
1744
1745  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
1746  design, 1974]
1747  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)
1748
1749  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
1750  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
1751  https://ieeexplore.ieee.org/document/1163711/.
1752  """
1753
1754  def __init__(self) -> None:
1755    super().__init__(b=0, c=0.5, name='cubic')
1756
1757
1758class MitchellFilter(CubicFilter):
1759  """See https://doi.org/10.1145/378456.378514.
1760
1761  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
1762  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
1763  """
1764
1765  def __init__(self) -> None:
1766    super().__init__(b=1 / 3, c=1 / 3, name='mitchell')
1767
1768
1769class SharpCubicFilter(CubicFilter):
1770  """Cubic filter that is sharper than Catmull-Rom filter.
1771
1772  Used by some tools including OpenCV and Photoshop.
1773
1774  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
1775  https://entropymine.com/resamplescope/notes/photoshop/.
1776  """
1777
1778  def __init__(self) -> None:
1779    super().__init__(b=0, c=0.75, name='sharpcubic')
1780
1781
1782class LanczosFilter(Filter):
1783  """High-quality filter: sinc function modulated by a sinc window.
1784
1785  Args:
1786    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
1787    sampled: If True, use a discretized approximation for improved speed.
1788
1789  See https://en.wikipedia.org/wiki/Lanczos_kernel.
1790  """
1791
1792  def __init__(self, *, radius: int, sampled: bool = True) -> None:
1793    super().__init__(
1794        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
1795    )
1796
1797    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1798    def _eval(x: _ArrayLike) -> _NDArray:
1799      x = np.abs(x)
1800      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
1801      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
1802      window = _sinc(x / radius)  # Zero-phase function w_0(x).
1803      return np.where(x < radius, _sinc(x) * window, 0.0)
1804
1805    self._function = _eval
1806
1807  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1808    return self._function(x)
1809
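The following sketch evaluates a Lanczos kernel directly, without the sampled cache, to show why `partition_of_unity=False` is set. It assumes that `_sinc` is the normalized sinc, sin(pi x) / (pi x); the `sinc` and `lanczos` helpers here are hypothetical stand-ins.

```python
import math


def sinc(x: float) -> float:
  """Normalized sinc, sin(pi x) / (pi x), with sinc(0) = 1."""
  return 1.0 if x == 0.0 else math.sin(math.pi * x) / (math.pi * x)


def lanczos(x: float, radius: int = 3) -> float:
  x = abs(x)
  return sinc(x) * sinc(x / radius) if x < radius else 0.0


at_integers = [lanczos(float(i)) for i in range(4)]
shifted_sums = [sum(lanczos(x - k) for k in range(-3, 5)) for x in (0.0, 0.25, 0.5)]
print(at_integers, shifted_sums)
```

The kernel is interpolating up to rounding error, but the sum of its integer shifts deviates slightly from one between sample positions.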
1810
1811class GeneralizedHammingFilter(Filter):
1812  """Sinc function modulated by a Hamming window.
1813
1814  Args:
1815    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
1816    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.
1817
1818  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
1819  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.
1820
1821  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from `a0=0.54`.
1822
1823  See also np.hamming() and np.hanning().
1824  """
1825
1826  def __init__(self, *, radius: int, a0: float) -> None:
1827    super().__init__(
1828        name=f'hamming_{radius}',
1829        radius=radius,
1830        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
1831        unit_integral=False,  # 1.00188
1832    )
1833    assert 0.0 < a0 < 1.0
1834    self.a0 = a0
1835
1836  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1837    x = np.abs(x)
1838    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
1839    # With n = x + N/2, we get the zero-phase function w_0(x):
1840    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
1841    return np.where(x < self.radius, _sinc(x) * window, 0.0)
1842
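The comment in `__call__` states that substituting `n = x + N/2` into the indexed window gives the zero-phase form. A quick numerical check of that substitution, as a dependency-free sketch with hypothetical function names:

```python
import math


def window_indexed(n: float, count: float, a0: float) -> float:
  """Generalized Hamming window over 0 <= n <= count (N in the comment)."""
  return a0 - (1.0 - a0) * math.cos(2.0 * math.pi * n / count)


def window_zero_phase(x: float, radius: float, a0: float) -> float:
  """The same window re-centered at x = 0, for -radius <= x <= radius."""
  return a0 + (1.0 - a0) * math.cos(math.pi / radius * x)


radius, a0 = 3.0, 25 / 46
max_difference = max(
    abs(window_indexed(x + radius, 2 * radius, a0) - window_zero_phase(x, radius, a0))
    for x in [i / 10 - radius for i in range(61)]
)
print(max_difference, window_zero_phase(0.0, radius, a0), window_zero_phase(radius, radius, a0))
```

The window peaks at 1 at the center and falls to `2 * a0 - 1` at the edge of the support.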
1843
1844class KaiserFilter(Filter):
1845  """Sinc function modulated by a Kaiser-Bessel window.
1846
1847  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
1848  [Karras et al. 2021.  Alias-free generative adversarial networks.
1849  https://arxiv.org/pdf/2106.12423.pdf].
1850
1851  See also np.kaiser().
1852
1853  Args:
1854    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
1855      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
1856      in [Karras et al. 2021] --- this affects the precise shape of the window function.
1857    beta: Determines the trade-off between main-lobe width and side-lobe level.
1858    sampled: If True, use a discretized approximation for improved speed.
1859  """
1860
1861  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
1862    assert beta >= 0.0
1863    super().__init__(
1864        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
1865    )
1866
1867    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
1868    def _eval(x: _ArrayLike) -> _NDArray:
1869      x = np.abs(x)
1870      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
1871      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)
1872
1873    self._function = _eval
1874
1875  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1876    return self._function(x)
1877
1878
1879class BsplineFilter(Filter):
1880  """B-spline of a non-negative degree.
1881
1882  Args:
1883    degree: The polynomial degree of the B-spline segments.
1884      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
1885      With `degree=1`, it is identical to `TriangleFilter`.
1886      With `degree >= 2`, it is no longer interpolating.
1887
1888  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
1889  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
1890  """
1891
1892  def __init__(self, *, degree: int) -> None:
1893    if degree < 0:
1894      raise ValueError(f'Bspline of degree {degree} is invalid.')
1895    radius = (degree + 1) / 2
1896    interpolating = degree <= 1
1897    super().__init__(name=f'bspline{degree}', radius=radius, interpolating=interpolating)
1898    t = list(range(degree + 2))
1899    self._bspline = scipy.interpolate.BSpline.basis_element(t)
1900
1901  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1902    x = np.abs(x)
1903    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)
1904
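`BsplineFilter` obtains its basis element from `scipy.interpolate.BSpline.basis_element` on the knots `0, 1, ..., degree + 1` and then centers it. The sketch below reproduces that construction with a plain Cox-de Boor recursion, as an illustration only; it is not how the library evaluates the kernel, and the function names are hypothetical.

```python
def bspline_basis(t: float, i: int, degree: int) -> float:
  """Cox-de Boor recursion for the uniform knot vector 0, 1, 2, ..."""
  if degree == 0:
    return 1.0 if i <= t < i + 1 else 0.0
  left = (t - i) / degree * bspline_basis(t, i, degree - 1)
  right = (i + degree + 1 - t) / degree * bspline_basis(t, i + 1, degree - 1)
  return left + right


def centered_bspline(x: float, degree: int) -> float:
  """Shift the basis element on knots 0..degree+1 so that it is centered at x = 0."""
  radius = (degree + 1) / 2
  x = abs(x)
  return bspline_basis(x + radius, 0, degree) if x < radius else 0.0


print([centered_bspline(x, 1) for x in (0.0, 0.25, 1.0)])
print([round(centered_bspline(x, 3), 6) for x in (0.0, 1.0, 2.0)])
```

With `degree=1` the values match the triangle function, and with `degree=3` the value at zero is 2/3 rather than 1, which is why higher degrees are not interpolating.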
1905
1906class CardinalBsplineFilter(Filter):
1907  """Interpolating B-spline, achieved with the aid of a digital pre- or post-filter.
1908
1909  Args:
1910    degree: The polynomial degree of the B-spline segments.
1911    sampled: If True, use a discretized approximation for improved speed.
1912
1913  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
1914  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
1915  1991].
1916  """
1917
1918  def __init__(self, *, degree: int, sampled: bool = True) -> None:
1919    self.degree = degree
1920    if degree < 0:
1921      raise ValueError(f'Bspline of degree {degree} is invalid.')
1922    radius = (degree + 1) / 2
1923    super().__init__(
1924        name=f'cardinal{degree}',
1925        radius=radius,
1926        requires_digital_filter=degree >= 2,
1927        continuous=degree >= 1,
1928    )
1929    t = list(range(degree + 2))
1930    bspline = scipy.interpolate.BSpline.basis_element(t)
1931
1932    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1933    def _eval(x: _ArrayLike) -> _NDArray:
1934      x = np.abs(x)
1935      return np.where(x < radius, bspline(x + radius), 0.0)
1936
1937    self._function = _eval
1938
1939  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1940    return self._function(x)
1941
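The role of the digital filter flagged by `requires_digital_filter` can be illustrated on a small periodic 1D signal. Because the cubic B-spline has values 1/6, 4/6, 1/6 at x = -1, 0, 1, using the samples directly as coefficients blurs them; solving for coefficients first restores interpolation. This sketch uses a simple Gauss-Seidel iteration with periodic wraparound as an assumption for brevity, which is not the solver used by the library.

```python
def prefilter_periodic(samples: list[float], iterations: int = 60) -> list[float]:
  """Solve (c[i-1] + 4 c[i] + c[i+1]) / 6 = samples[i] by Gauss-Seidel iteration."""
  n = len(samples)
  coeffs = list(samples)
  for _ in range(iterations):
    for i in range(n):
      coeffs[i] = (6 * samples[i] - coeffs[i - 1] - coeffs[(i + 1) % n]) / 4
  return coeffs


def reconstruct_at_samples(coeffs: list[float]) -> list[float]:
  """Evaluate the cubic B-spline expansion at the integer sample positions."""
  n = len(coeffs)
  return [(coeffs[i - 1] + 4 * coeffs[i] + coeffs[(i + 1) % n]) / 6 for i in range(n)]


samples = [0.0, 1.0, 0.0, 0.0, 2.0, 1.0]
coeffs = prefilter_periodic(samples)
blurred = reconstruct_at_samples(samples)  # Without the prefilter.
error = max(abs(a - b) for a, b in zip(reconstruct_at_samples(coeffs), samples))
print(blurred, error)
```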
1942
1943class OmomsFilter(Filter):
1944  """OMOMS interpolating filter, achieved with the aid of a digital pre- or post-filter.
1945
1946  Args:
1947    degree: The polynomial degree of the filter segments.
1948
1949  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
1950  interpolation of minimal support, 2001].
1951  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
1952  """
1953
1954  def __init__(self, *, degree: int) -> None:
1955    if degree not in (3, 5):
1956      raise ValueError(f'Degree {degree} not supported.')
1957    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
1958    self.degree = degree
1959
1960  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1961    x = np.abs(x)
1962    match self.degree:
1963      case 3:
1964        v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
1965        v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
1966        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1967      case 5:
1968        v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
1969        v12 = (
1970            (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
1971        ) * x + 839 / 2640
1972        v23 = (
1973            (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
1974        ) * x + 5707 / 2640
1975        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
1976      case _:
1977        raise ValueError(self.degree)
1978
1979
1980class GaussianFilter(Filter):
1981  r"""See https://en.wikipedia.org/wiki/Gaussian_function.
1982
1983  Args:
1984    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
1985      creates a kernel that is as-close-as-possible to a partition of unity.
1986  """
1987
1988  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
1989  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
1990  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
1991  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
1992  resampling tasks, 1990] for kernels with a support of 3 pixels.
1993  https://cadxfem.org/inf/ResamplingFilters.pdf
1994  """
1995
1996  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
1997    super().__init__(
1998        name=f'gaussian_{standard_deviation:.3f}',
1999        radius=np.ceil(8.0 * standard_deviation),  # Sufficiently large.
2000        interpolating=False,
2001        partition_of_unity=False,
2002    )
2003    self.standard_deviation = standard_deviation
2004
2005  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2006    x = np.abs(x)
2007    sdv = self.standard_deviation
2008    v0r = np.exp(np.square(x / sdv) / -2.0) / (np.sqrt(math.tau) * sdv)
2009    return np.where(x < self.radius, v0r, 0.0)
2010
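The docstring of `DEFAULT_STANDARD_DEVIATION` quotes a range of `0.93503:1.06497` for how closely the default Gaussian approximates a partition of unity. That range can be reproduced approximately by summing integer-shifted copies of the kernel, as in this dependency-free sketch with hypothetical helper names.

```python
import math

SIGMA = 1.25 / 3.0


def gaussian(x: float, sigma: float = SIGMA) -> float:
  return math.exp(-0.5 * (x / sigma) ** 2) / (math.sqrt(math.tau) * sigma)


def shifted_sum(x: float) -> float:
  """Sum of integer-shifted kernel copies; distant terms are negligible."""
  return sum(gaussian(x - k) for k in range(-4, 6))


sums = [shifted_sum(i / 100) for i in range(100)]
print(f'min={min(sums):.5f} max={max(sums):.5f}')
```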
2011
2012class NarrowBoxFilter(Filter):
2013  """Compact footprint, used for visualization of grid sample location.
2014
2015  Args:
2016    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
2017      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
2018  """
2019
2020  def __init__(self, *, radius: float = 0.199) -> None:
2021    super().__init__(
2022        name='narrowbox',
2023        radius=radius,
2024        continuous=False,
2025        unit_integral=False,
2026        partition_of_unity=False,
2027    )
2028
2029  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2030    radius = self.radius
2031    magnitude = 1.0
2032    x = np.asarray(x)
2033    return np.where((-radius <= x) & (x < radius), magnitude, 0.0)
2034
2035
2036_DEFAULT_FILTER = 'lanczos3'
2037
2038_DICT_FILTERS = {
2039    'impulse': ImpulseFilter(),
2040    'box': BoxFilter(),
2041    'trapezoid': TrapezoidFilter(),
2042    'triangle': TriangleFilter(),
2043    'cubic': CatmullRomFilter(),
2044    'sharpcubic': SharpCubicFilter(),
2045    'lanczos3': LanczosFilter(radius=3),
2046    'lanczos5': LanczosFilter(radius=5),
2047    'lanczos10': LanczosFilter(radius=10),
2048    'cardinal3': CardinalBsplineFilter(degree=3),
2049    'cardinal5': CardinalBsplineFilter(degree=5),
2050    'omoms3': OmomsFilter(degree=3),
2051    'omoms5': OmomsFilter(degree=5),
2052    'hamming3': GeneralizedHammingFilter(radius=3, a0=25 / 46),
2053    'kaiser3': KaiserFilter(radius=3.0, beta=7.12),
2054    'gaussian': GaussianFilter(),
2055    'bspline3': BsplineFilter(degree=3),
2056    'mitchell': MitchellFilter(),
2057    'narrowbox': NarrowBoxFilter(),
2058    # Not in FILTERS:
2059    'hamming1': GeneralizedHammingFilter(radius=1, a0=0.54),
2060    'hann3': GeneralizedHammingFilter(radius=3, a0=0.5),
2061    'lanczos4': LanczosFilter(radius=4),
2062}
2063
2064FILTERS = list(itertools.takewhile(lambda x: x != 'hamming1', _DICT_FILTERS))
2065r"""Shortcut names for some predefined filter kernels (specified per dimension).
2066The names expand to:
2067
2068| name           | `Filter`                      | a.k.a. / comments |
2069|----------------|-------------------------------|-------------------|
2070| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
2071| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
2072| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
2073| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
2074| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
2075| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
2076| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
2077| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
2078| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
2079| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
2080| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
2081| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-2, 2], *GF* |
2082| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-3, 3], *GF* |
2083| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
2084| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
2085| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
2086| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
2087| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
2088| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |
2089
2090The comment label *GF* denotes a [generalized filter](https://hhoppe.com/proj/filtering/), formed
2091as the composition of a finitely supported kernel and a discrete inverse convolution.
2092
2093**Some example filter kernels:**
2094
2095<center>
2096<img src="https://github.com/hhoppe/resampler/raw/main/media/filter_summary.png" width="100%"/>
2097</center>
2098
2099<br/>A more extensive set of filters is presented [here](#plots_of_filters) in the
2100[notebook](https://colab.research.google.com/github/hhoppe/resampler/blob/main/resampler_notebook.ipynb),
2101together with visualizations and analyses of the filter properties.
2102See the source code for extensibility.
2103"""
2104
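`FILTERS` relies on dictionary insertion order: `itertools.takewhile` keeps the keys that appear before `'hamming1'` and drops that key and everything after it. A toy illustration, where `registry` is a hypothetical stand-in for `_DICT_FILTERS`:

```python
import itertools

# Toy stand-in; the values are irrelevant because iterating a dict yields its keys.
registry = {'box': 1, 'triangle': 2, 'narrowbox': 3, 'hamming1': 4, 'hann3': 5}
public_names = list(itertools.takewhile(lambda name: name != 'hamming1', registry))
print(public_names)
```

The dropped names remain valid keys of the dictionary, so they can still be looked up by name even though they are not listed.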
2105
2106def _get_filter(filter: str | Filter, /) -> Filter:
2107  """Return a `Filter`, which can be specified as a name key in `_DICT_FILTERS` (a superset of `FILTERS`)."""
2108  return filter if isinstance(filter, Filter) else _DICT_FILTERS[filter]
2109
2110
2111def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
2112  """Scale uint to the range [0.0, 1.0], and clip float to [0.0, 1.0]."""
2113  array_dtype = _arr_dtype(array)
2114  dtype = np.dtype(dtype)
2115  assert np.issubdtype(dtype, np.floating)
2116  match array_dtype.type:
2117    case np.uint8 | np.uint16 | np.uint32:
2118      if _arr_arraylib(array) == 'numpy':
2119        assert isinstance(array, np.ndarray)  # Help mypy.
2120        return np.multiply(array, 1 / np.iinfo(array_dtype).max, dtype=dtype)
2121      return _arr_astype(array, dtype) / np.iinfo(array_dtype).max
2122    case _:
2123      assert np.issubdtype(array_dtype, np.floating)
2124      return _arr_clip(array, 0.0, 1.0, dtype)
2125
2126
2127def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
2128  """Convert a float in range [0.0, 1.0] to uint or float type."""
2129  assert np.issubdtype(_arr_dtype(array), np.floating)
2130  dtype = np.dtype(dtype)
2131  match dtype.type:
2132    case np.uint8 | np.uint16:
2133      return cast(_Array, _arr_astype(array * np.float32(np.iinfo(dtype).max) + 0.5, dtype))
2134    case np.uint32:
2135      return cast(_Array, _arr_astype(array * np.float64(np.iinfo(dtype).max) + 0.5, dtype))
2136    case _:
2137      assert np.issubdtype(dtype, np.floating)
2138      return _arr_astype(array, dtype)
2139
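The scaling used by `_to_float_01` and `_from_float` can be illustrated on scalar `uint8` values. This is a simplified sketch with hypothetical helpers: it folds in the clipping that `IdentityGamma.encode` applies before calling `_from_float`, and it uses Python's `int()` truncation in place of the array cast.

```python
def to_float_01(value: int, max_value: int = 255) -> float:
  """Map an unsigned integer in [0, max_value] to a float in [0.0, 1.0]."""
  return value / max_value


def from_float(value: float, max_value: int = 255) -> int:
  """Clip to [0.0, 1.0], scale, and round by adding 0.5 before truncation."""
  clipped = min(max(value, 0.0), 1.0)
  return int(clipped * max_value + 0.5)


round_trip_ok = all(from_float(to_float_01(v)) == v for v in range(256))
print(round_trip_ok, from_float(0.5), from_float(1.2))
```

Adding 0.5 before the truncating cast rounds to the nearest integer for non-negative values, so every `uint8` value survives a decode and encode round trip.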
2140
2141@dataclasses.dataclass(frozen=True)
2142class Gamma(abc.ABC):
2143  """Abstract base class for transfer functions on sample values.
2144
2145  Image/video content is often stored using a color component transfer function.
2146  See https://en.wikipedia.org/wiki/Gamma_correction.
2147
2148  Converts between integer types and [0.0, 1.0] internal value range.
2149  """
2150
2151  name: str
2152  """Name of component transfer function."""
2153
2154  @abc.abstractmethod
2155  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2156    """Decode source sample values into floating-point, possibly nonlinearly.
2157
2158    Uint source values are mapped to the range [0.0, 1.0].
2159    """
2160
2161  @abc.abstractmethod
2162  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2163    """Encode float signal into destination samples, possibly nonlinearly.
2164
2165    Uint destination values are mapped from the range [0.0, 1.0].
2166
2167    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
2168    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
2169    """
2170
2171
2172class IdentityGamma(Gamma):
2173  """Identity component transfer function."""
2174
2175  def __init__(self) -> None:
2176    super().__init__('identity')
2177
2178  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2179    dtype = np.dtype(dtype)
2180    assert np.issubdtype(dtype, np.inexact)
2181    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
2182      return _to_float_01(array, dtype)
2183    return _arr_astype(array, dtype)
2184
2185  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2186    dtype = np.dtype(dtype)
2187    assert np.issubdtype(dtype, np.number)
2188    if np.issubdtype(dtype, np.unsignedinteger):
2189      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
2190    if np.issubdtype(dtype, np.integer):
2191      return _arr_astype(array + 0.5, dtype)
2192    return _arr_astype(array, dtype)


class PowerGamma(Gamma):
  """Gamma correction using a power function."""

  def __init__(self, power: float) -> None:
    super().__init__(name=f'power_{power}')
    self.power = power

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8 and self.power != 2:
      arraylib = _arr_arraylib(array)
      decode_table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
      return _arr_getitem(decode_table, array)

    array = _to_float_01(array, dtype)
    return _arr_square(array) if self.power == 2 else array**self.power

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    array = _arr_clip(array, 0.0, 1.0)
    array = _arr_sqrt(array) if self.power == 2 else array ** (1.0 / self.power)
    return _from_float(array, dtype)


class SrgbGamma(Gamma):
  """Gamma correction using sRGB; see https://en.wikipedia.org/wiki/SRGB."""

  def __init__(self) -> None:
    super().__init__(name='srgb')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8:
      arraylib = _arr_arraylib(array)
      decode_table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
      return _arr_getitem(decode_table, array)

    x = _to_float_01(array, dtype)
    return _arr_where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    x = _arr_clip(array, 0.0, 1.0)
    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
    # pytype: disable=wrong-arg-types
    x = _arr_where(x > 0.0031308, x ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17), x * 12.92)
    # pytype: enable=wrong-arg-types
    return _from_float(x, dtype)
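

# Illustrative sketch, not part of the library API: the piecewise sRGB transfer function on
# plain float64 numpy arrays, using the same thresholds and constants as `SrgbGamma` but without
# the uint8 lookup table or the small 1e-17 bias.  The names `_sketch_srgb_decode` and
# `_sketch_srgb_encode` are hypothetical.
import numpy as np


def _sketch_srgb_decode(encoded: np.ndarray) -> np.ndarray:
  """Return linear-space values from sRGB-encoded values in [0.0, 1.0]."""
  e = np.asarray(encoded, dtype=np.float64)
  # Linear segment near black; power-law segment elsewhere.
  return np.where(e > 0.04045, ((e + 0.055) / 1.055) ** 2.4, e / 12.92)


def _sketch_srgb_encode(linear: np.ndarray) -> np.ndarray:
  """Return sRGB-encoded values from linear-space values, clipped to [0.0, 1.0]."""
  l = np.clip(np.asarray(linear, dtype=np.float64), 0.0, 1.0)
  return np.where(l > 0.0031308, l ** (1.0 / 2.4) * 1.055 - 0.055, l * 12.92)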


_DICT_GAMMAS = {
    'identity': IdentityGamma(),
    'power2': PowerGamma(2.0),
    'power22': PowerGamma(2.2),
    'srgb': SrgbGamma(),
}

GAMMAS = list(_DICT_GAMMAS)
r"""Shortcut names for some predefined gamma-correction schemes:

| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
|---|---|:---:|:---:|
| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |

See the source code for extensibility.
"""


def _get_gamma(gamma: str | Gamma, /) -> Gamma:
  """Return a `Gamma`, which can be specified as a name in `GAMMAS`."""
  return gamma if isinstance(gamma, Gamma) else _DICT_GAMMAS[gamma]


def _get_src_dst_gamma(
    gamma: str | Gamma | None,
    src_gamma: str | Gamma | None,
    dst_gamma: str | Gamma | None,
    src_dtype: _DType,
    dst_dtype: _DType,
) -> tuple[Gamma, Gamma]:
  if gamma is None and src_gamma is None and dst_gamma is None:
    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
    if src_uint and dst_uint:
      # The default might ideally be 'srgb' but that conversion is costlier.
      gamma = 'power2'
    elif not src_uint and not dst_uint:
      gamma = 'identity'
    else:
      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
  if gamma is not None:
    if src_gamma is not None:
      raise ValueError('Cannot specify both gamma and src_gamma.')
    if dst_gamma is not None:
      raise ValueError('Cannot specify both gamma and dst_gamma.')
    src_gamma = dst_gamma = gamma
  assert src_gamma and dst_gamma
  src_gamma = _get_gamma(src_gamma)
  dst_gamma = _get_gamma(dst_gamma)
  return src_gamma, dst_gamma


def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Compute affine weights for 1D resampling from `src_size` to `dst_size`.

  Compute a sparse matrix in which each row expresses a destination sample value as a combination
  of source sample values depending on the boundary rule.  If the combination is non-affine,
  the remainder (returned as `cval_weight`) is the contribution of the special constant value
  (cval) defined outside the domain.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix whose rows express output sample values as affine combinations of the
      source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  prefilter = filter if prefilter is None else prefilter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  scaling = dst_gridtype.size_in_samples(dst_size) / src_gridtype.size_in_samples(src_size) * scale
  is_minification = scaling < 1.0
  filter = prefilter if is_minification else filter
  if filter.name == 'trapezoid':
    radius = 0.5 + 0.5 * min(scaling, 1.0 / scaling)
    filter = TrapezoidFilter(radius=radius)
  radius = filter.radius
  num_samples = int(np.ceil(radius * 2 / scaling) if is_minification else np.ceil(radius * 2))

  dst_index = np.arange(dst_size, dtype=dtype)
  # Destination sample locations in unit domain [0, 1].
  dst_position = dst_gridtype.point_from_index(dst_index, dst_size)

  src_position = (dst_position - translate) / scale
  src_position = boundary.preprocess_coordinates(src_position)

  # Sample positions mapped back to source unit domain [0, 1].
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  src_first_index = (
      np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
      - (num_samples - 1) // 2
  )

  sample_index = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + sample_index  # (dst_size, num_samples)

  def get_weight_matrix() -> _NDArray:
    if filter.name == 'impulse':
      return np.ones(src_index.shape, dtype)
    if is_minification:
      x = (src_float_index[:, None] - src_index.astype(dtype)) * scaling
      return filter(x) * scaling
    # Either same size or magnification.
    x = src_float_index[:, None] - src_index.astype(dtype)
    return filter(x)

  weight = get_weight_matrix().astype(dtype, copy=False)

  if filter.name != 'narrowbox' and (is_minification or not filter.partition_of_unity):
    weight = weight / weight.sum(axis=-1)[..., None]

  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)
  shape = dst_size, src_size

  def prepare_sparse_resize_matrix() -> tuple[_NDArray, _NDArray, _NDArray]:
    linearized = (src_index + np.indices(src_index.shape)[0] * src_size).ravel()
    values = weight.ravel()
    # Remove the zero weights.
    nonzero = values != 0.0
    linearized, values = linearized[nonzero], values[nonzero]
    # Sort and merge the duplicate indices.
    unique, unique_inverse = np.unique(linearized, return_inverse=True)
    data2 = np.ones(len(linearized), np.float32)
    row_ind2 = unique_inverse
    col_ind2 = np.arange(len(linearized))
    shape2 = len(unique), len(linearized)
    csr = scipy.sparse.csr_matrix((data2, (row_ind2, col_ind2)), shape=shape2)
    data = csr * values  # Merged values.
    row_ind, col_ind = unique // src_size, unique % src_size  # Merged indices.
    return data, row_ind, col_ind

  data, row_ind, col_ind = prepare_sparse_resize_matrix()
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, shape, arraylib)

  uses_cval = boundary.uses_cval or filter.name == 'narrowbox'
  cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib) if uses_cval else None

  return resize_matrix, cval_weight
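

# Illustrative sketch, not part of the library API: a much-reduced version of the idea in
# `_create_resize_matrix`, restricted to 1D linear upsampling between 'dual' grids.  It assumes a
# triangle kernel of radius 1 and clamps out-of-range taps onto the edge samples as a stand-in
# for the `Boundary` machinery.  The name `_sketch_linear_resize_matrix` is hypothetical.
import numpy as np
import scipy.sparse


def _sketch_linear_resize_matrix(src_size: int, dst_size: int) -> scipy.sparse.csr_matrix:
  """Return a sparse (dst_size, src_size) matrix for 1D linear upsampling."""
  # Destination sample centers in the unit domain [0, 1] (half-integer placement).
  dst_position = (np.arange(dst_size) + 0.5) / dst_size
  # The same positions expressed as fractional source sample indices.
  src_float_index = dst_position * src_size - 0.5
  src_first_index = np.floor(src_float_index).astype(np.int64)
  src_index = src_first_index[:, None] + np.arange(2)  # Two taps per row: (dst_size, 2).
  # Triangle kernel evaluated at the offset from each destination sample to each tap.
  weight = np.maximum(0.0, 1.0 - np.abs(src_float_index[:, None] - src_index))
  # Clamp taps that fall outside the source grid onto the nearest edge sample.
  src_index = np.clip(src_index, 0, src_size - 1)
  row_ind = np.repeat(np.arange(dst_size), 2)
  # Entries that share the same (row, column) are summed when the CSR matrix is built.
  return scipy.sparse.csr_matrix(
      (weight.ravel(), (row_ind, src_index.ravel())), shape=(dst_size, src_size)
  )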


def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution to the specified dimension of the array.

  Find the array coefficients such that convolution with the (continuous) filter (given
  gridtype and boundary) interpolates the original array values.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  if arraylib == 'tensorflow':
    import tensorflow as tf

    def forward(x: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(x, gridtype, boundary, cval, filter, axis, False)

    def backward(grad_output: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(
          grad_output, gridtype, boundary, cval, filter, axis, True
      )

    @tf.custom_gradient  # type: ignore[untyped-decorator]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # Although `forward` accesses parameters gridtype, boundary, etc., it is not stateful
      # because the function is redefined on each invocation of _apply_digital_filter_1d.
      y = tf.numpy_function(forward, [x], x.dtype, stateful=False)

      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        return tf.numpy_function(backward, [grad_output], x.dtype, stateful=False)

      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Differentiable wrapper for _apply_digital_filter_1d."""

      @staticmethod
      # pylint: disable-next=arguments-differ
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        a = _apply_digital_filter_1d_numpy(
            x.detach().numpy(), gridtype, boundary, cval, filter, axis, False
        )
        return torch.as_tensor(a)

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        a = _apply_digital_filter_1d_numpy(
            grad_output.detach().numpy(), gridtype, boundary, cval, filter, axis, True
        )
        return torch.as_tensor(a)

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp
    # It seems rather difficult to implement this digital filter (inverse convolution) in Jax.
    # https://jax.readthedocs.io/en/latest/jax.scipy.html sadly omits scipy.signal.filtfilt().
    # To include a (non-traceable) numpy function in Jax requires jax.experimental.host_callback
    # and/or defining a new jax.core.Primitive (which allows differentiability).  See
    # https://github.com/google/jax/issues/1142#issuecomment-544286585
    # https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb  :-(
    # https://github.com/google/jax/issues/5934

    @jax.custom_gradient  # type: ignore[untyped-decorator]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # This function is not jax-traceable due to the conversion to numpy via np.asarray()
      # (formerly to_py()), so jit and grad fail.
      x_py = np.asarray(x)  # to_py() deprecated.
      a = _apply_digital_filter_1d_numpy(x_py, gridtype, boundary, cval, filter, axis, False)
      y = jnp.asarray(a)

      def grad(grad_output: _JaxArray) -> _JaxArray:
        grad_output_py = np.asarray(grad_output)  # to_py() deprecated.
        a = _apply_digital_filter_1d_numpy(
            grad_output_py, gridtype, boundary, cval, filter, axis, True
        )
        return jnp.asarray(a)

      return y, grad

    return jax_inverse_convolution(array)

  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return _apply_digital_filter_1d_numpy(array, gridtype, boundary, cval, filter, axis, False)


def _apply_digital_filter_1d_numpy(
    array: _NDArray,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    axis: int,
    compute_backward: bool,
    /,
) -> _NDArray:
  """Version of `_apply_digital_filter_1d` specialized to numpy array."""
  assert np.issubdtype(array.dtype, np.inexact)
  cval = np.asarray(cval).astype(array.dtype, copy=False)

  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter:
  mode = {
      ('reflect', 'dual'): 'reflect',
      ('reflect', 'primal'): 'mirror',
      ('wrap', 'dual'): 'grid-wrap',
      ('wrap', 'primal'): 'wrap',
  }.get((boundary.name, gridtype.name))
  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
  use_spline_filter1d = filter_is_compatible and mode
  if use_spline_filter1d:
    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
    assert filter.degree >= 2
    # compute_backward=True is same: matrix is symmetric and cval is unused.
    return scipy.ndimage.spline_filter1d(
        array, axis=axis, order=filter.degree, mode=mode, output=array.dtype
    )

  array_dim = np.moveaxis(array, axis, 0)
  l = original_l = math.ceil(filter.radius) - 1
  x = np.arange(-l, l + 1, dtype=array.real.dtype)
  values = filter(x)
  size = array_dim.shape[0]
  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
  weight: _NDArray = np.full((size, len(values)), values)
  src_position = np.broadcast_to(0.5, len(values))
  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
  if gridtype.name == 'primal' and boundary.name == 'wrap':
    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
    # matrix non-singular.
    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
  # Currently, matrix is always banded unless boundary.name == 'wrap'.

  data = weight.reshape(-1).astype(array.dtype, copy=False)
  row_ind = np.arange(size).repeat(src_index.shape[1])
  col_ind = src_index.reshape(-1)
  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
  if compute_backward:
    matrix = matrix.T

  if boundary.uses_cval and not compute_backward:
    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval

  array_flat = array_dim.reshape(array_dim.shape[0], -1)

  if is_banded:
    matrix = matrix.todia()
    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
    l, u = -matrix.offsets[0], matrix.offsets[-1]
    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
    if _is_symmetric(matrix):
      array_flat = scipy.linalg.solveh_banded(matrix.data[-1 : l - 1 : -1], array_flat, **options)
    else:
      array_flat = scipy.linalg.solve_banded((l, u), matrix.data[::-1], array_flat, **options)

  else:
    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
    array_flat = lu.solve(array_flat)

  array_dim = array_flat.reshape(array_dim.shape)
  return np.moveaxis(array_dim, 0, axis)
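

# Illustrative sketch, not part of the library API: the "inverse convolution" idea behind
# `_apply_digital_filter_1d_numpy`, specialized by assumption to a cubic B-spline (kernel values
# 1/6, 4/6, 1/6 at offsets -1, 0, +1) with whole-sample mirroring at both ends.  It solves the
# tridiagonal system with `scipy.linalg.solve_banded`.  The `_sketch_*` names are hypothetical.
import numpy as np
import scipy.linalg


def _sketch_cubic_bspline_coefficients(samples: np.ndarray) -> np.ndarray:
  """Solve for c such that (c[i - 1] + 4 * c[i] + c[i + 1]) / 6 == samples[i]."""
  samples = np.asarray(samples, dtype=np.float64)
  size = len(samples)
  # Banded storage for solve_banded((1, 1), ...): rows hold super-, main, and sub-diagonal.
  ab = np.zeros((3, size))
  ab[0, 1:] = 1.0 / 6.0  # Superdiagonal entries A[i, i + 1].
  ab[1, :] = 4.0 / 6.0  # Main diagonal entries A[i, i].
  ab[2, :-1] = 1.0 / 6.0  # Subdiagonal entries A[i + 1, i].
  # Mirroring (c[-1] == c[1] and c[size] == c[size - 2]) doubles the weight of the edge neighbor.
  ab[0, 1] = 2.0 / 6.0  # A[0, 1].
  ab[2, -2] = 2.0 / 6.0  # A[size - 1, size - 2].
  return scipy.linalg.solve_banded((1, 1), ab, samples)


def _sketch_evaluate_at_samples(coefficients: np.ndarray) -> np.ndarray:
  """Convolve coefficients with the kernel values (1/6, 4/6, 1/6) using mirrored ends."""
  padded = np.pad(coefficients, 1, mode='reflect')  # Whole-sample mirror: c1 | c0 c1 ...
  return (padded[:-2] + 4.0 * padded[1:-1] + padded[2:]) / 6.0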


def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike | None = None,
    dtype: _DTypeLike | None = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.  Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)
  cval = _arr_numpy(src_gamma2.decode(cval, precision))

  can_use_fast_box_downsampling = (
      _USING_NUMBA
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array
2859
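
# Illustrative sketch (not part of the library API): how the per-dimension loop in `resize`
# classifies a dimension as minification and resolves `boundary='auto'`.  The helper name is
# hypothetical, and it assumes a 'dual' grid, for which the size in samples is taken to equal
# the number of samples.
def _sketch_auto_boundary(src_size: int, dst_size: int, scale: float = 1.0) -> tuple[bool, str]:
  # Minification means that the scaled output has fewer samples than the source.
  is_minification = dst_size / src_size * scale < 1.0
  return is_minification, 'clamp' if is_minification else 'reflect'


# Usage: `_sketch_auto_boundary(128, 32)` reports minification and selects 'clamp', whereas
# `_sketch_auto_boundary(4, 128)` reports magnification and selects 'reflect'.
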

_original_resize = resize


def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the specified array library from `ARRAYLIBS`."""
  _check_eq(_arr_arraylib(array), 'numpy')
  return _arr_numpy(_original_resize(_make_array(array, arraylib), *args, **kwargs))


def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `numpy` library."""
  return resize_in_arraylib(array, *args, arraylib='numpy', **kwargs)


def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `tensorflow` library."""
  return resize_in_arraylib(array, *args, arraylib='tensorflow', **kwargs)


def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `torch` library."""
  return resize_in_arraylib(array, *args, arraylib='torch', **kwargs)


def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `jax` library."""
  return resize_in_arraylib(array, *args, arraylib='jax', **kwargs)


def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """If `array` is from numpy, evaluate `resize()` using the specified `arraylib`."""
  if _arr_arraylib(array) == 'numpy':
    return _arr_numpy(
        _original_resize(_make_array(cast(_ArrayLike, array), arraylib), *args, **kwargs)
    )
  return _original_resize(array, *args, **kwargs)


@functools.cache
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Lazily invoke `jax.jit` on `resize`."""
  import jax

  jitted: Any = jax.jit(
      _original_resize,
      static_argnums=(1,),
      static_argnames=list(_original_resize.__kwdefaults__ or []),
  )
  return jitted


def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` using a version of the function that is jitted by JAX."""
  return _create_jaxjit_resize()(array, *args, **kwargs)  # pylint: disable=not-callable


def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit`, using either `contain` (the default)
  or `cover`.
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Must be left at its default value; the scale is derived from `object_fit`.
    translate: Must be left at its default value; the translation is derived from `object_fit`.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  scale2 = scale0 / np.array(raw_scales)
  translate = (1.0 - scale2) / 2
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)

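
# Illustrative sketch (not part of the library API): the `scale` and `translate` values that
# `uniform_resize` derives from `object_fit`, written in plain Python.  The helper name is
# hypothetical, and it assumes 'dual' grids, for which the size in samples is taken to equal the
# number of samples.
def _sketch_object_fit(src_shape, dst_shape, object_fit='contain'):
  # Per-dimension scale that would stretch the source to exactly fill the output.
  raw_scales = [dst / src for src, dst in zip(src_shape, dst_shape)]
  # A single uniform scale: the smallest for 'contain', the largest for 'cover'.
  scale0 = min(raw_scales) if object_fit == 'contain' else max(raw_scales)
  # Express the uniform scale relative to the stretch-to-fill scale, then center the result.
  scale = [scale0 / raw for raw in raw_scales]
  translate = [(1.0 - s) / 2 for s in scale]
  return scale, translate


# Usage: `_sketch_object_fit((2, 2), (2, 4))` gives scale [1.0, 0.5] and translate [0.0, 0.25],
# i.e., the source occupies the middle half of the output width, consistent with the first
# doctest in `uniform_resize`.
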

_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.


def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike | None = None,
    dtype: _DTypeLike | None = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width, 1`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.  Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=jacobian,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both: [shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # Starting in numpy 2.0, np.einsum() outputs np.float64 even with all np.float32 inputs;
  # GPT: "aligns np.einsum with other functions where intermediate calculations use higher
  # precision (np.float64) regardless of input type when floating-point arithmetic is involved."
  # We could explicitly add the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # There is one weight array per grid dimension, so iterate over `grid_ndim`
    # (which may differ from `resampled_ndim`).
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array

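
# Illustrative sketch (not part of the library API): two bookkeeping steps of `resample`,
# restated in plain Python.  Both helper names are hypothetical.  The first helper assumes a
# 'dual' grid in which sample `i` is centered at point `(i + 0.5) / src_size`, consistent with
# the identity example in the `resample` docstring.
import math


def _sketch_source_window(point: float, src_size: int, radius: float) -> list[int]:
  # Continuous source index of the query point, then the run of integer indices under the kernel.
  src_float_index = point * src_size - 0.5
  num_samples = math.ceil(radius * 2)
  offset = 0.5 if num_samples % 2 == 1 else 0.0
  first = math.floor(src_float_index + offset) - (num_samples - 1) // 2
  return [first + i for i in range(num_samples)]  # May extend outside [0, src_size - 1].


def _sketch_einsum_subscripts(resampled_ndim: int, grid_ndim: int, sample_ndim: int) -> str:
  def label(dims) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  samples_ndim = resampled_ndim + grid_ndim + sample_ndim
  outer = list(range(resampled_ndim))
  labels = [label(range(samples_ndim))]  # Gathered samples.
  labels += [label(outer + [resampled_ndim + dim]) for dim in range(grid_ndim)]  # Weights.
  output = label(outer + list(range(resampled_ndim + grid_ndim, samples_ndim)))
  return ','.join(labels) + '->' + output


# Usage: `_sketch_einsum_subscripts(2, 2, 1)` reproduces the subscripts 'abcde,abc,abd->abe' of
# the upsampling example in the comments of `resample`.  Indices returned by
# `_sketch_source_window` that fall outside the grid are the ones that a boundary rule must remap.
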

def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike | None = None,
    dtype: _DTypeLike | None = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  The `matrix` transformation can be linear,
    `source_point = matrix @ destination_point`,
  or it can be affine where the last matrix column is an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not a 2D matrix.')
  src_ndim = matrix.shape[0]
  # grid_shape = array.shape[:src_ndim]
  is_affine = matrix.shape[1] == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if matrix.shape[1] != dst_ndim and not is_affine:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  dst_position_list = []  # per dimension
  for dim in range(dst_ndim):
    dst_size = shape[dim]
    dst_index = np.arange(dst_size, dtype=weight_precision)
    dst_position_list.append(dst_gridtype2[dim].point_from_index(dst_index, dst_size))
  dst_position = np.meshgrid(*dst_position_list, indexing='ij')

  linear_matrix = matrix[:, :-1] if is_affine else matrix
  src_position = np.tensordot(linear_matrix, dst_position, 1)
  coords = np.moveaxis(src_position, 0, -1)
  if is_affine:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

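
# Illustrative sketch (not part of the library API): how `resample_affine` turns a destination
# grid and a linear or affine `matrix` into source coordinates, in plain Python.  The helper name
# is hypothetical, and it assumes a 'dual' destination grid whose sample `i` lies at unit-domain
# point `(i + 0.5) / size`.
import itertools


def _sketch_affine_coords(shape, matrix):
  dst_ndim = len(shape)
  is_affine = len(matrix[0]) == dst_ndim + 1  # The last column is then the offset vector.
  coords = []
  for index in itertools.product(*(range(n) for n in shape)):  # Row-major traversal.
    dst_point = [(i + 0.5) / n for i, n in zip(index, shape)]
    if is_affine:
      dst_point.append(1.0)  # Homogeneous coordinate.
    coords.append([sum(m * p for m, p in zip(row, dst_point)) for row in matrix])
  return coords  # Flat list of source points, one per destination sample.


# Usage: `_sketch_affine_coords((2,), [[0.5, 0.25]])` maps the destination points 0.25 and 0.75
# to the source points 0.375 and 0.625, i.e., it samples the middle half of a 1D source domain.
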

def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Use the more general `resample` operation for `resize`, as a debug tool."""
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  scale = np.broadcast_to(scale, len(shape))
  translate = np.broadcast_to(translate, len(shape))
  # TODO: let resample() do prefiltering for proper downsampling.
  has_minification = np.any(np.array(shape) < array.shape[: len(shape)]) or np.any(scale < 1.0)
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape))]
  has_auto_trapezoid = any(f.name == 'trapezoid' for f in filter2)
  if fallback and (has_minification or has_auto_trapezoid):
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)
  offset = -translate / scale
  matrix = np.concatenate([np.diag(1.0 / scale), offset[:, None]], axis=1)
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)

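
# Illustrative sketch (not part of the library API): the affine matrix that
# `_resize_using_resample` builds from `scale` and `translate`, in plain Python.  The helper name
# is hypothetical.  Because `resize` places the source at `dst = src * scale + translate`, the
# destination-to-source map is `src = dst / scale - translate / scale`.
def _sketch_resize_matrix(scale, translate):
  ndim = len(scale)
  return [
      [(1.0 / scale[row] if col == row else 0.0) for col in range(ndim)]
      + [-translate[row] / scale[row]]  # Offset column.
      for row in range(ndim)
  ]


# Usage: `_sketch_resize_matrix((2.0, 0.5), (0.5, 0.25))` returns
# [[0.5, 0.0, -0.25], [0.0, 2.0, -0.5]].
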

def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping the destination unit domain into the source unit domain.

  The returned matrix accounts for the possibly non-square domain shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
  """

  def translation_matrix(vector: _NDArray) -> _NDArray:
    matrix = np.eye(len(vector) + 1)
    matrix[:-1, -1] = vector
    return matrix

  def scaling_matrix(scale: _NDArray) -> _NDArray:
    return np.diag(tuple(scale) + (1.0,))

  def rotation_matrix_2d(angle: float) -> _NDArray:
    cos, sin = np.cos(angle), np.sin(angle)
    return np.array([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])

  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))
  half = np.array([0.5, 0.5])
  matrix = (
      translation_matrix(half)
      @ scaling_matrix(min(src_shape) / src_shape)
      @ rotation_matrix_2d(angle)
      @ scaling_matrix(scale * new_shape / min(new_shape))
      @ translation_matrix(-half)
  )
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

3499
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
  for _ in range(num_rotations):
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image


def _pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`."""
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return _pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image  # type: ignore
  filters = {
      'impulse': PIL.Image.Resampling.NEAREST,
      'box': PIL.Image.Resampling.BOX,
      'triangle': PIL.Image.Resampling.BILINEAR,
      'hamming1': PIL.Image.Resampling.HAMMING,
      'cubic': PIL.Image.Resampling.BICUBIC,
      'lanczos3': PIL.Image.Resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  ny, nx = shape
  if array.ndim == 2:
    return np.array(PIL.Image.fromarray(array).resize((nx, ny), resample=pil_resample), array.dtype)
  stack = []
  for channel in np.moveaxis(array, -1, 0):
    pil_image = PIL.Image.fromarray(channel).resize((nx, ny), resample=pil_resample)
    stack.append(np.array(pil_image, array.dtype))
  return np.dstack(stack)


def _cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`."""
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return _cv_resize(array[None], (1, *shape), filter=filter)[0]
  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  interpolation = filters[filter]
  result = cv.resize(array, shape[::-1], interpolation=interpolation)
  if array.ndim == 3 and result.ndim == 2:
    assert array.shape[2] == 1
    return result[..., None]  # Add back the last dimension dropped by cv.resize().
  return result


def _scipy_ndimage_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _NDArray:
  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]
  shape_all = shape + array.shape[len(shape) :]
  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
  coords[..., : len(shape)] = (
      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
  coords = np.moveaxis(coords, -1, 0)
  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)


def _skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `skimage.transform.resize` using the same parameters as `resize`."""
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]
  shape_all = shape + array.shape[len(shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      array, shape_all, order=order, mode=mode, cval=cval, clip=False
  )  # type: ignore[no-untyped-call]


_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER = {
    'impulse': 'nearest',
    'trapezoid': 'area',
    'triangle': 'bilinear',
    'mitchell': 'mitchellcubic',
    'cubic': 'bicubic',
    'lanczos3': 'lanczos3',
    'lanczos5': 'lanczos5',
    # GaussianFilter(0.5): 'gaussian',  # radius_4 > desired_radius_3.
}


def _tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Invoke `tf.image.resize` using the same parameters as `resize`."""
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array2 = tf.convert_to_tensor(array)
  ndim = len(array2.shape)
  del array
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)
  match ndim:
    case 1:
      return _tf_image_resize(array2[None], (1, *shape), filter=filter, antialias=antialias)[0]
    case 2:
      return _tf_image_resize(array2[..., None], shape, filter=filter, antialias=antialias)[..., 0]
    case _:
      method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
      return tf.image.resize(array2, shape, method=method, antialias=antialias)


_TORCH_INTERPOLATE_MODE_FROM_FILTER = {
    'impulse': 'nearest-exact',  # ('nearest' matches buggy OpenCV's INTER_NEAREST)
    'trapezoid': 'area',
    'triangle': 'bilinear',
    'sharpcubic': 'bicubic',
}


def _torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`."""
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  a = torch.as_tensor(array)
  del array
  assert 1 <= a.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if a.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def local_resize(a: _TorchTensor) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(a, shape, mode=mode, antialias=antialias)

  match a.ndim:
    case 1:
      shape = (1, *shape)
      return local_resize(a[None, None, None])[0, 0, 0]
    case 2:
      return local_resize(a[None, None])[0, 0]
    case _:
      return local_resize(a.moveaxis(2, 0)[None])[0].moveaxis(0, 2)


def _jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`."""
  import jax.image
  import jax.numpy as jnp

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.
  if scale != 1.0 and cval != 0.0:
    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
  if translate != 0.0 and cval != 0.0:
    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  assert len(shape) <= array2.ndim
  completed_shape = shape + (1,) * (array2.ndim - len(shape))
  spatial_dims = list(range(len(shape)))
  scale2 = np.broadcast_to(np.array(scale), len(shape))
  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
  translate2 = np.broadcast_to(np.array(translate), len(shape))
  translate2 = translate2 * np.array(shape)
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )


_CANDIDATE_RESIZERS = {
    'resampler.resize': resize,
    'PIL.Image.resize': _pil_image_resize,
    'cv.resize': _cv_resize,
    'scipy.ndimage.map_coordinates': _scipy_ndimage_resize,
    'skimage.transform.resize': _skimage_transform_resize,
    'tf.image.resize': _tf_image_resize,
    'torch.nn.functional.interpolate': _torch_nn_resize,
    'jax.image.scale_and_translate': _jax_image_resize,
}


def _resizer_is_available(library_function: str) -> bool:
  """Return whether the resizer is available as an installed package."""
  top_name = library_function.split('.', 1)[0]
  # Map the alias used in the resizer name to the importable module name.  (The Pillow
  # distribution is imported as `PIL`, so `find_spec` must be given 'PIL', not 'Pillow'.)
  module = {'PIL': 'PIL', 'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
  return importlib.util.find_spec(module) is not None  # type: ignore[attr-defined]


_RESIZERS = {
    library_function: resizer
    for library_function, resizer in _CANDIDATE_RESIZERS.items()
    if _resizer_is_available(library_function)
}


def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`."""
  match filter:
    case 'box_like':
      return {
          _cv_resize: 'trapezoid',
          _skimage_transform_resize: 'box',
          _tf_image_resize: 'trapezoid',
          _torch_nn_resize: 'trapezoid',
      }.get(resizer, 'box')
    case 'cubic_like':
      return {
          _cv_resize: 'sharpcubic',
          _scipy_ndimage_resize: 'cardinal3',
          _skimage_transform_resize: 'cardinal3',
          _torch_nn_resize: 'sharpcubic',
      }.get(resizer, 'cubic')
    case 'high_quality':
      return {
          _pil_image_resize: 'lanczos3',
          _cv_resize: 'lanczos4',
          _scipy_ndimage_resize: 'cardinal5',
          _skimage_transform_resize: 'cardinal5',
          _torch_nn_resize: 'sharpcubic',
      }.get(resizer, 'lanczos5')
    case _:
      return filter


# For Emacs:
# Local Variables:
# fill-column: 100
# End:
ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the array function parameter.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.
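
As an illustration of selecting a library from the type of the array argument, here is a minimal sketch. The function `arraylib_name`, the table `_KNOWN_LIBRARIES`, and the module-name heuristic are all hypothetical and are not taken from the library; the real `_Arraylib` classes may identify arrays differently.

```python
# Hypothetical sketch: pick an array library from the runtime type of `array`.
# The module-name heuristic below is an assumption made for illustration only.
_KNOWN_LIBRARIES = {
    'numpy': 'numpy',
    'tensorflow': 'tensorflow',
    'torch': 'torch',
    'jax': 'jax',
    'jaxlib': 'jax',
}


def arraylib_name(array: object) -> str:
    """Return the name of the array library whose module defines type(array)."""
    top_module = type(array).__module__.split('.', 1)[0]
    if top_module not in _KNOWN_LIBRARIES:
        raise ValueError(f'Unrecognized array type {type(array).__name__!r}.')
    return _KNOWN_LIBRARIES[top_module]


class FakeTorchTensor:
    """Stand-in type so that the sketch runs without torch installed."""


FakeTorchTensor.__module__ = 'torch'  # Pretend that the class is defined in torch.

print(arraylib_name(FakeTorchTensor()))
```

A dispatcher of this kind lets one public function accept arrays from several libraries while keeping the library-specific code in one place.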

@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  In resampling operations, the grid-type may be specified separately as `src_gridtype` for the
  source domain and `dst_gridtype` for the destination domain.  Moreover, the grid-type may be
  specified per domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the necessary minimum number of grid samples."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the domain size in units of inter-sample spacing."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Return [0.0, 1.0] coordinates given [0, size - 1] indices."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Return location x given coordinates [0.0, 1.0], where x == 0.0 is the first grid sample
    and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using boundary reflection."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using reflect-clamp."""

GRIDTYPES: list[str] = ['dual', 'primal']

Shortcut names for the two predefined grid types (specified per dimension):

| gridtype | 'dual', DualGridtype() (default) | 'primal', PrimalGridtype() |
|---|---|---|
| Sample positions in 2D and in 1D at different resolutions | [figure: Dual] | [figure: Primal] |
| Nesting of samples across resolutions | The sample positions do not nest. | The even samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.
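
The sample-position formulas for the two grid types can be checked with a few lines of plain Python. This is a minimal sketch that uses only the formulas stated for the 'dual' and 'primal' grids; the helper names `dual_point_from_index` and `primal_point_from_index` are hypothetical and are not part of the library.

```python
def dual_point_from_index(i: int, n: int) -> float:
    """Position in [0, 1] of sample i on a 'dual' grid with n samples."""
    return (i + 0.5) / n


def primal_point_from_index(i: int, n: int) -> float:
    """Position in [0, 1] of sample i on a 'primal' grid with n >= 2 samples."""
    return i / (n - 1)


# Two successive dyadic levels for each grid type.
dual_coarse = [dual_point_from_index(i, 2) for i in range(2)]
dual_fine = [dual_point_from_index(i, 4) for i in range(4)]
primal_coarse = [primal_point_from_index(i, 3) for i in range(3)]
primal_fine = [primal_point_from_index(i, 5) for i in range(5)]

print(dual_coarse, dual_fine)      # No dual position is shared between the two levels.
print(primal_coarse, primal_fine)  # Every other fine primal sample is a coarse sample.
```

The dual samples sit at half-integer offsets, so none of them survives a change of level, whereas the even-indexed primal samples coincide with the coarser grid.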

@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules.  These define the reconstruction over the source domain near and beyond
  the domain boundaries.  The rules may be specified separately for each domain dimension."""

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    return self.extend_samples.uses_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Modify coordinates prior to evaluating the filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    point = self.coord_remap(point)
    return point

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples."""
    index, weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(weight, point)
    return index, weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight to zero to assign `cval`."""
    self.override_value(weight, point)

BOUNDARIES: list[str] = ['reflect', 'wrap', 'tile', 'clamp', 'border', 'natural', 'linear_constant', 'quadratic_constant', 'reflect_clamp', 'constant', 'linear', 'quadratic']

Shortcut names for some predefined boundary rules (as defined by _DICT_BOUNDARIES):

| name | a.k.a. / comments |
|---|---|
| 'reflect' | reflected, symm, symmetric, mirror, grid-mirror |
| 'wrap' | periodic, repeat, grid-wrap |
| 'tile' | like 'reflect' within unit domain, then tile discontinuously |
| 'clamp' | clamped, nearest, edge, clamp-to-edge, repeat last sample |
| 'border' | grid-constant, use cval for samples outside unit domain |
| 'natural' | renormalize using only interior samples, use cval outside domain |
| 'reflect_clamp' | mirror-clamp-to-edge |
| 'constant' | like 'reflect' but replace by cval outside unit domain |
| 'linear' | extrapolate from 2 last samples |
| 'quadratic' | extrapolate from 3 last samples |
| 'linear_constant' | like 'linear' but replace by cval outside unit domain |
| 'quadratic_constant' | like 'quadratic' but replace by cval outside unit domain |

These boundary rules may be specified per dimension. See the source code for extensibility using the classes RemapCoordinates, ExtendSamples, and OverrideExteriorValue.
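
To make three of these rules concrete, the sketch below maps out-of-range integer sample indices back into the interior of a 'dual' grid using the conventional reflect, wrap, and clamp mappings. The function names are hypothetical, and the library's own `ExtendSamples` classes may be implemented differently.

```python
def reflect_index(i: int, size: int) -> int:
    """Mirror about the domain edges (the edge sample is repeated, as in 'symmetric')."""
    i %= 2 * size
    return i if i < size else 2 * size - 1 - i


def wrap_index(i: int, size: int) -> int:
    """Periodic repetition of the samples."""
    return i % size


def clamp_index(i: int, size: int) -> int:
    """Repeat the first or last sample."""
    return min(max(i, 0), size - 1)


size = 4
indices = range(-3, 7)  # Three samples on either side of the interior [0, 3].
reflected = [reflect_index(i, size) for i in indices]
wrapped = [wrap_index(i, size) for i in indices]
clamped = [clamp_index(i, size) for i in indices]

print(reflected)  # [2, 1, 0, 0, 1, 2, 3, 3, 2, 1]
print(wrapped)    # [1, 2, 3, 0, 1, 2, 3, 0, 1, 2]
print(clamped)    # [0, 0, 0, 0, 1, 2, 3, 3, 3, 3]
```

Rules such as 'border', 'natural', and the '*_constant' variants additionally involve cval or weight renormalization, so they cannot be expressed as a pure index remapping like this.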

Boundary rules illustrated in 1D:

Boundary rules illustrated in 2D:

@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Return evaluation of filter kernel at locations x."""

FILTERS: list[str] = ['impulse', 'box', 'trapezoid', 'triangle', 'cubic', 'sharpcubic', 'lanczos3', 'lanczos5', 'lanczos10', 'cardinal3', 'cardinal5', 'omoms3', 'omoms5', 'hamming3', 'kaiser3', 'gaussian', 'bspline3', 'mitchell', 'narrowbox']

Shortcut names for some predefined filter kernels (specified per dimension). The names expand to:

| name | Filter | a.k.a. / comments |
|---|---|---|
| 'impulse' | ImpulseFilter() | nearest |
| 'box' | BoxFilter() | non-antialiased box, e.g. ImageMagick |
| 'trapezoid' | TrapezoidFilter() | area antialiasing, e.g. cv.INTER_AREA |
| 'triangle' | TriangleFilter() | linear (bilinear in 2D), spline order=1 |
| 'cubic' | CatmullRomFilter() | catmullrom, keys, bicubic |
| 'sharpcubic' | SharpCubicFilter() | cv.INTER_CUBIC, torch 'bicubic' |
| 'lanczos3' | LanczosFilter(radius=3) | support window [-3, 3] |
| 'lanczos5' | LanczosFilter(radius=5) | [-5, 5] |
| 'lanczos10' | LanczosFilter(radius=10) | [-10, 10] |
| 'cardinal3' | CardinalBsplineFilter(degree=3) | spline interpolation, order=3, GF |
| 'cardinal5' | CardinalBsplineFilter(degree=5) | spline interpolation, order=5, GF |
| 'omoms3' | OmomsFilter(degree=3) | non-$C^1$, [-3, 3], GF |
| 'omoms5' | OmomsFilter(degree=5) | non-$C^1$, [-5, 5], GF |
| 'hamming3' | GeneralizedHammingFilter(...) | (radius=3, a0=25/46) |
| 'kaiser3' | KaiserFilter(radius=3.0, beta=7.12) | |
| 'gaussian' | GaussianFilter() | non-interpolating, default $\sigma=1.25/3$ |
| 'bspline3' | BsplineFilter(degree=3) | non-interpolating |
| 'mitchell' | MitchellFilter() | mitchellcubic |
| 'narrowbox' | NarrowBoxFilter() | for visualization of sample positions |

The comment label GF denotes a generalized filter, formed as the composition of a finitely supported kernel and a discrete inverse convolution.
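
The `interpolating` and `partition_of_unity` properties of a kernel can be tested numerically. The sketch below does so for the triangle and Lanczos (radius 3) kernels, written here from their standard textbook definitions; the functions `triangle`, `lanczos`, `is_interpolating`, and `comb_sum` are hypothetical helpers and are not the library's `Filter` classes.

```python
import math


def sinc(x: float) -> float:
    """Normalized sinc function sin(pi x) / (pi x)."""
    return 1.0 if x == 0.0 else math.sin(math.pi * x) / (math.pi * x)


def triangle(x: float) -> float:
    """Linear (tent) kernel with support [-1, 1]."""
    return max(1.0 - abs(x), 0.0)


def lanczos(x: float, radius: int = 3) -> float:
    """Sinc windowed by a wider sinc, with support [-radius, radius]."""
    return sinc(x) * sinc(x / radius) if abs(x) < radius else 0.0


def is_interpolating(kernel, radius: int) -> bool:
    """Check kernel(0) == 1 and kernel(i) == 0 at the nonzero integers in the support."""
    if abs(kernel(0.0) - 1.0) > 1e-12:
        return False
    return all(abs(kernel(float(i))) < 1e-12 for i in range(1, radius + 1))


def comb_sum(kernel, x: float, radius: int) -> float:
    """Sum of integer-shifted kernel copies at x; equals 1 for a partition of unity."""
    return sum(kernel(x - i) for i in range(-radius - 1, radius + 2))


print(is_interpolating(triangle, 1), is_interpolating(lanczos, 3))
print(comb_sum(triangle, 0.3, 1))  # Equals 1 up to rounding.
print(comb_sum(lanczos, 0.5, 3))   # Close to 1 but not exactly 1.
```

Both kernels are interpolating, but only the triangle kernel sums exactly to one under integer shifts; the Lanczos sum deviates slightly from one between samples.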

Some example filter kernels:


A more extensive set of filters is presented in the notebook, together with visualizations and analyses of the filter properties. See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract base class for transfer functions on sample values.

  Image/video content is often stored using a color component transfer function.
  See https://en.wikipedia.org/wiki/Gamma_correction.

  Converts between integer types and [0.0, 1.0] internal value range.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Decode source sample values into floating-point, possibly nonlinearly.

    Uint source values are mapped to the range [0.0, 1.0].
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Encode float signal into destination samples, possibly nonlinearly.

    Uint destination values are mapped from the range [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
    """

GAMMAS: list[str] = ['identity', 'power2', 'power22', 'srgb']

Shortcut names for some predefined gamma-correction schemes:

| name | Gamma | Decoding function (linear space from stored value) | Encoding function (stored value from linear space) |
|---|---|---|---|
| 'identity' | IdentityGamma() | $l = e$ | $e = l$ |
| 'power2' | PowerGamma(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
| 'power22' | PowerGamma(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
| 'srgb' | SrgbGamma() | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |

See the source code for extensibility.
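
The effect of filtering in linear space can be seen by averaging a black and a white sample under the 'power2' scheme. This is a minimal sketch built only from the decoding and encoding formulas listed for 'power2' and 'srgb'; the function names are hypothetical, and the sRGB functions use only the power-law branch shown here (the sRGB standard also defines a linear segment for small values, which is omitted).

```python
def decode_power(stored: float, gamma: float = 2.0) -> float:
    """Linear-space value from a stored value in [0, 1]."""
    return stored ** gamma


def encode_power(linear: float, gamma: float = 2.0) -> float:
    """Stored value from a linear-space value in [0, 1]."""
    return linear ** (1.0 / gamma)


def decode_srgb(stored: float) -> float:
    """Power-law branch of the sRGB decoding function."""
    return ((stored + 0.055) / 1.055) ** 2.4


def encode_srgb(linear: float) -> float:
    """Power-law branch of the sRGB encoding function."""
    return linear ** (1.0 / 2.4) * 1.055 - 0.055


black, white = 0.0, 1.0
# Averaging the stored values directly ignores the transfer function.
naive_average = (black + white) / 2.0
# Decode, average in linear space, then encode the result.
linear_average = encode_power((decode_power(black) + decode_power(white)) / 2.0)
# Encoding undoes decoding.
roundtrip = encode_srgb(decode_srgb(0.5))

print(naive_average, linear_average, roundtrip)
```

The linear-space average is brighter than the naive one (about 0.707 versus 0.5 as stored values), which is why downsampled gamma-encoded images look too dark when the transfer function is ignored.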

def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike | None = None,
    dtype: _DTypeLike | None = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.  Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
2681      Parameters `gamma` and `src_gamma` cannot both be set.
2682    dst_gamma: Component transfer function used to "encode" the output samples.
2683      Parameters `gamma` and `dst_gamma` cannot both be set.
2684    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
2685      the destination domain.
2686    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
2687      the destination domain.
2688    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
2689      on `array.dtype` and `dtype`.
2690    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
2691      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
2692      to the uint range.
2693    dim_order: Override the automatically selected order in which the grid dimensions are resized.
2694      Must contain a permutation of `range(len(shape))`.
2695    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
2696      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.
2697
2698  Returns:
2699    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
2700      and data type `dtype`.
2701
2702  **Example of image upsampling:**
2703
2704  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
2705  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.
2706
2707  <center>
2708  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
2709  </center>
2710
2711  **Example of image downsampling:**
2712
2713  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
2714  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
2715  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
2716  >>> downsampled = resize(array, (24, 48))
2717
2718  <center>
2719  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
2720  </center>
2721
2722  **Unit test:**
2723
2724  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
2725  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
2726  """
2727  if isinstance(array, (tuple, list)):
2728    array = np.asarray(array)
2729  arraylib = _arr_arraylib(array)
2730  array_dtype = _arr_dtype(array)
2731  if not np.issubdtype(array_dtype, np.number):
2732    raise ValueError(f'Type {array.dtype} is not numeric.')
2733  shape2 = tuple(shape)
2734  array_ndim = len(array.shape)
2735  if not 0 < len(shape2) <= array_ndim:
2736    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
2737  src_shape = array.shape[: len(shape2)]
2738  src_gridtype2, dst_gridtype2 = _get_gridtypes(
2739      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
2740  )
2741  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
2742  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
2743  prefilter = filter if prefilter is None else prefilter
2744  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
2745  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
2746  dtype = array_dtype if dtype is None else np.dtype(dtype)
2747  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
2748  scale2 = np.broadcast_to(np.array(scale), len(shape2))
2749  translate2 = np.broadcast_to(np.array(translate), len(shape2))
2750  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
2751  del (src_gamma, dst_gamma, scale, translate)
2752  precision = _get_precision(precision, [array_dtype, dtype], [])
2753  weight_precision = _real_precision(precision)
2754
2755  is_noop = (
2756      all(src == dst for src, dst in zip(src_shape, shape2))
2757      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
2758      and all(f.interpolating for f in prefilter2)
2759      and np.all(scale2 == 1.0)
2760      and np.all(translate2 == 0.0)
2761      and src_gamma2 == dst_gamma2
2762  )
2763  if is_noop:
2764    return array
2765
2766  if dim_order is None:
2767    dim_order = _arr_best_dims_order_for_resize(array, shape2)
2768  else:
2769    dim_order = tuple(dim_order)
2770    if sorted(dim_order) != list(range(len(shape2))):
2771      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')
2772
2773  array = src_gamma2.decode(array, precision)
2774  cval = _arr_numpy(src_gamma2.decode(cval, precision))
2775
2776  can_use_fast_box_downsampling = (
2777      _USING_NUMBA
2778      and arraylib == 'numpy'
2779      and len(shape2) == 2
2780      and array_ndim in (2, 3)
2781      and all(src > dst for src, dst in zip(src_shape, shape2))
2782      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
2783      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
2784      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
2785      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
2786      and np.all(scale2 == 1.0)
2787      and np.all(translate2 == 0.0)
2788  )
2789  if can_use_fast_box_downsampling:
2790    assert isinstance(array, np.ndarray)  # Help mypy.
2791    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
2792    array = dst_gamma2.encode(array, dtype)
2793    return array
2794
2795  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
2796  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
2797  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
2798  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
2799  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
2800  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.
2801
2802  for dim in dim_order:
2803    skip_resize_on_this_dim = (
2804        shape2[dim] == array.shape[dim]
2805        and scale2[dim] == 1.0
2806        and translate2[dim] == 0.0
2807        and filter2[dim].interpolating
2808    )
2809    if skip_resize_on_this_dim:
2810      continue
2811
2812    def get_is_minification() -> bool:
2813      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
2814      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
2815      return dst_in_samples / src_in_samples * scale2[dim] < 1.0
2816
2817    is_minification = get_is_minification()
2818    boundary_dim = boundary2[dim]
2819    if boundary_dim == 'auto':
2820      boundary_dim = 'clamp' if is_minification else 'reflect'
2821    boundary_dim = _get_boundary(boundary_dim)
2822    resize_matrix, cval_weight = _create_resize_matrix(
2823        array.shape[dim],
2824        shape2[dim],
2825        src_gridtype=src_gridtype2[dim],
2826        dst_gridtype=dst_gridtype2[dim],
2827        boundary=boundary_dim,
2828        filter=filter2[dim],
2829        prefilter=prefilter2[dim],
2830        scale=scale2[dim],
2831        translate=translate2[dim],
2832        dtype=weight_precision,
2833        arraylib=arraylib,
2834    )
2835
2836    array_dim: _Array = _arr_moveaxis(array, dim, 0)
2837    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
2838    array_flat = _arr_possibly_make_contiguous(array_flat)
2839    if not is_minification and filter2[dim].requires_digital_filter:
2840      array_flat = _apply_digital_filter_1d(
2841          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
2842      )
2843
2844    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
2845    if cval_weight is not None:
2846      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
2847      if np.issubdtype(array_dtype, np.complexfloating):
2848        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
2849      array_flat += cval_weight[:, None] * cval_flat
2850
2851    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
2852      array_flat = _apply_digital_filter_1d(
2853          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
2854      )
2855    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
2856    array = _arr_moveaxis(array_dim, 0, dim)
2857
2858  array = dst_gamma2.encode(array, dtype)
2859  return array

Resample array (a grid of sample values) onto a grid with resolution shape.

The source array is any object recognized by ARRAYLIBS. It is interpreted as a grid with len(shape) domain coordinate dimensions, where each grid sample value has shape array.shape[len(shape):].

Some examples:

  • A grayscale image has array.shape = height, width and resizing it with len(shape) == 2 produces a new image of scalar values.
  • An RGB image has array.shape = height, width, 3 and resizing it with len(shape) == 2 produces a new image of RGB values.
  • A 3D grid of 3x3 Jacobians has array.shape = Z, Y, X, 3, 3 and resizing it with len(shape) == 3 produces a new 3D grid of Jacobians.

This function also allows scaling and translation from the source domain to the output domain through the parameters scale and translate. For more general transforms, see resample.
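
As a rough illustration of what scale and translate mean, the sketch below computes, for one dimension of a 'dual' grid, the source unit-domain coordinate that each output sample reads from. It follows the relation shown in the resample docstring, `coords = (coords - translate) / scale`. The helper name `source_coords_1d` is hypothetical and is not part of the library.

```python
def source_coords_1d(num_dst, scale=1.0, translate=0.0):
  """Return the source unit-domain coordinate of each output sample (1D, 'dual' grid).

  Hypothetical helper for illustration only; not part of the resampler API.
  """
  # On a 'dual' grid, sample i of n sits at the half-integer position (i + 0.5) / n.
  dst_unit = [(i + 0.5) / num_dst for i in range(num_dst)]
  # The source domain is scaled and then translated onto the destination domain,
  # so the inverse map subtracts the offset and divides by the scale.
  return [(x - translate) / scale for x in dst_unit]


print(source_coords_1d(4))  # [0.125, 0.375, 0.625, 0.875]
print(source_coords_1d(4, scale=0.5, translate=0.25))  # [-0.25, 0.25, 0.75, 1.25]
```

In the second call the source occupies only the middle half of the output, so the first and last output samples map outside [0, 1]; their values are then decided by the boundary rule.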

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. Its first len(shape) dimensions are the domain coordinate dimensions. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto array.shape[len(shape):]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during upsampling (i.e., magnification).
  • prefilter: The prefilter kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter. The default 'lanczos3' is good for natural images. For vector graphics images, 'trapezoid' is better because it avoids ringing artifacts.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • scale: Scaling factor applied to each dimension of the source domain when it is mapped onto the destination domain.
  • translate: Offset applied to each dimension of the scaled source domain when it is mapped onto the destination domain.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • dim_order: Override the automatically selected order in which the grid dimensions are resized. Must contain a permutation of range(len(shape)).
  • num_threads: Used to determine multithread parallelism if array is from numpy. If set to 'auto', it is selected automatically. Otherwise, it must be a positive integer.
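
The 'auto' boundary rule mentioned above can be sketched per dimension as follows. This mirrors the minification test in the source listing (`dst_in_samples / src_in_samples * scale < 1.0`), under the assumption that on a 'dual' grid the domain size in samples equals the number of samples. The helper `auto_boundary` is hypothetical.

```python
def auto_boundary(src_size, dst_size, scale=1.0):
  """Return the boundary name that 'auto' would resolve to for one 'dual' dimension.

  Hypothetical helper; assumes a 'dual' grid, where the size in samples is
  taken to be the sample count itself.
  """
  is_minification = dst_size / src_size * scale < 1.0
  # Downsampling (minification) uses 'clamp'; everything else uses 'reflect'.
  return 'clamp' if is_minification else 'reflect'


print(auto_boundary(4, 128))   # reflect
print(auto_boundary(128, 32))  # clamp
```

Because the rule is evaluated per dimension, a resize that enlarges one axis and shrinks another may end up with a different boundary on each axis.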
Returns:

An array of the same class as the source array, with shape shape + array.shape[len(shape):] and data type dtype.
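
The shape rule stated above can be checked without calling the library. The sketch below is a hypothetical helper (`resized_shape` is not a library function) that reproduces the `shape + array.shape[len(shape):]` rule together with the dimension check from the source listing.

```python
def resized_shape(array_shape, shape):
  """Return the output shape that the rule above predicts.

  Hypothetical helper for illustration only; not part of the resampler API.
  """
  shape = tuple(shape)
  # The source must have at least as many dimensions as the requested grid shape.
  if not 0 < len(shape) <= len(array_shape):
    raise ValueError(f'Shape {array_shape} cannot be resized to {shape}.')
  # Grid dimensions are replaced; the per-sample dimensions are carried over.
  return shape + tuple(array_shape[len(shape):])


print(resized_shape((4, 6, 3), (128, 192)))  # (128, 192, 3)
print(resized_shape((10, 20, 30, 3, 3), (5, 5, 5)))  # (5, 5, 5, 3, 3)
```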

Example of image upsampling:

>>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
>>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

Example of image downsampling:

>>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
>>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
>>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
>>> downsampled = resize(array, (24, 48))

Unit test:

>>> result = resize([1.0, 4.0, 5.0], shape=(4,))
>>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
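
To see why the gamma parameter matters for uint images, the sketch below averages two uint8 values after decoding them to linear space. It assumes that a power-law transfer function such as 'power2' decodes as `(v / 255) ** 2` and encodes with the inverse; this is an assumption made for illustration, the helper `average_uint8` is hypothetical, and any rounding the library applies when encoding is omitted.

```python
def average_uint8(a, b, power=2.0):
  """Average two uint8 values in linear space using a power-law transfer function.

  Hypothetical helper; assumes decode is (v / 255) ** power and encode is its inverse.
  """
  decoded = [(v / 255) ** power for v in (a, b)]  # "Decode" to linear values in [0, 1].
  mean = sum(decoded) / 2  # Filter (here, a plain average) in linear space.
  return mean ** (1 / power) * 255  # "Encode" back to the uint8 range (unrounded).


print(average_uint8(0, 255))             # About 180.3.
print(average_uint8(0, 255, power=1.0))  # 127.5, as with an identity transfer function.
```

Under this assumption, mixing black and white yields a value noticeably brighter than the naive midpoint, which is the effect that linear-space filtering is meant to capture.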
def jaxjit_resize(array: Array, /, *args: Any, **kwargs: Any) -> Array:
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` but with resize function jitted using Jax."""
  return _create_jaxjit_resize()(array, *args, **kwargs)  # pylint: disable=not-callable

Compute resize but with resize function jitted using Jax.

def uniform_resize( array: Array, /, shape: Iterable[int], *, object_fit: Literal['contain', 'cover'] = 'contain', gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'natural', scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, **kwargs: Any) -> Array:
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit: contain`.
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  scale2 = scale0 / np.array(raw_scales)
  translate = (1.0 - scale2) / 2
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)

Resample array onto a grid with resolution shape but with uniform scaling.

Calls function resize with scale and translate set such that the aspect ratio of array is preserved. The effect is similar to CSS object-fit: contain. The parameter boundary (whose default is changed to 'natural') determines the values assigned outside the source domain.
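
The scale and translate values passed on to resize can be sketched in plain Python. The helper below mirrors the arithmetic in the source listing, under the assumption that both grids are 'dual' so that the size in samples equals the sample count. The name `uniform_scale_translate` is hypothetical and is not part of the library.

```python
def uniform_scale_translate(src_shape, dst_shape, object_fit='contain'):
  """Return per-dimension (scale, translate) lists for aspect-preserving resizing.

  Hypothetical helper; assumes 'dual' grids on both source and destination.
  """
  # Scale that a plain (non-uniform) resize would apply along each dimension.
  raw_scales = [dst / src for src, dst in zip(src_shape, dst_shape)]
  # 'contain' keeps the whole source visible; 'cover' fills the whole output.
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  scale = [scale0 / raw for raw in raw_scales]
  # Center the scaled source within the unit output domain.
  translate = [(1.0 - s) / 2 for s in scale]
  return scale, translate


print(uniform_scale_translate((2, 2), (2, 4)))  # ([1.0, 0.5], [0.0, 0.25])
```

For a 2x2 source and a 2x4 output with 'contain', the source covers only the middle half of the output width, so the outer columns fall outside the source domain and receive `cval` under the 'natural' boundary.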

Arguments:
  • array: Regular grid of source sample values.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • object_fit: Like CSS object-fit. If 'contain', array is resized uniformly to fit within shape. If 'cover', array is resized to fully cover shape.
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The default is 'natural', which assigns cval to output points that map outside the source unit domain.
  • scale: Parameter may not be specified.
  • translate: Parameter may not be specified.
  • **kwargs: Additional parameters for resize function (including cval).
Returns:

An array with shape shape + array.shape[len(shape):].

>>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
array([[0., 1., 1., 0.],
       [0., 1., 1., 0.]])
>>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
       [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])
>>> a = np.arange(6.0).reshape(2, 3)
>>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
array([[0.5, 1.5],
       [3.5, 4.5]])
def resample( array: Array, /, coords: ArrayLike, *, gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual', boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0.0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, jacobian: ArrayLike | None = None, precision: DTypeLike | None = None, dtype: DTypeLike | None = None, max_block_size: int = 40000, debug: bool = False) -> Array:
2998def resample(
2999    array: _Array,
3000    /,
3001    coords: _ArrayLike,
3002    *,
3003    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
3004    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
3005    cval: _ArrayLike = 0.0,
3006    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
3007    prefilter: str | Filter | Iterable[str | Filter] | None = None,
3008    gamma: str | Gamma | None = None,
3009    src_gamma: str | Gamma | None = None,
3010    dst_gamma: str | Gamma | None = None,
3011    jacobian: _ArrayLike | None = None,
3012    precision: _DTypeLike | None = None,
3013    dtype: _DTypeLike | None = None,
3014    max_block_size: int = 40_000,
3015    debug: bool = False,
3016) -> _Array:
3017  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.
3018
3019  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
3020  domain grid samples in `array`.
3021
3022  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
3023  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
3024  sample (e.g., scalar, vector, tensor).
3025
3026  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
3027  `array.shape[coords.shape[-1]:]`.
3028
3029  Examples include:
3030
3031  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
3032    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.
3033
3034  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
3035    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.
3036
3037  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.
3038
3039  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.
3040
3041  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
3042    using `coords.shape = height, width, 3`.
3043
3044  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
3045    `coords.shape = height, width`.
3046
3047  Args:
3048    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
3049      The array must have numeric type.  The coordinate dimensions appear first, and
3050      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
3051      a `'dual'` grid or at least 2 for a `'primal'` grid.
3052    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
3053      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
3054      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
3055      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
3056    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
3057      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
3058    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
3059      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
3060      `'reflect'` for upsampling and `'clamp'` for downsampling.
3061    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
3062      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
3063    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
3064      a filter name in `FILTERS` or a `Filter` instance.
3065    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
3066      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
3067      (i.e., minification).  If `None`, it inherits the value of `filter`.
3068    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
3069      from `array` and when creating output grid samples.  It is specified as either a name in
3070      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
3071      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
3072      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
3073      range [0.0, 1.0].
3074    src_gamma: Component transfer function used to "decode" `array` samples.
3075      Parameters `gamma` and `src_gamma` cannot both be set.
3076    dst_gamma: Component transfer function used to "encode" the output samples.
3077      Parameters `gamma` and `dst_gamma` cannot both be set.
3078    jacobian: Optional array, which must be broadcastable onto the shape
3079      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
3080      output grid the Jacobian matrix of the map from the unit output domain to the unit source
3081      domain.  If omitted, it is estimated by computing finite differences on `coords`.
3082    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
3083      on `array.dtype`, `coords.dtype`, and `dtype`.
3084    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
3085      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
3086      to the uint range.
3087    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
3088      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
3089    debug: Show internal information.
3090
3091  Returns:
3092    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
3093    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
3094    the source array.
3095
3096  **Example of resample operation:**
3097
3098  <center>
3099  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
3100  </center>
3101
3102  For reference, the identity resampling for a scalar-valued grid with the default grid-type
3103  `'dual'` is:
3104
3105  >>> array = np.random.default_rng(0).random((5, 7, 3))
3106  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
3107  >>> new_array = resample(array, coords)
3108  >>> assert np.allclose(new_array, array)
3109
3110  It is more efficient to use the function `resize` for the special case where the `coords` are
3111  obtained as simple scaling and translation of a new regular grid over the source domain:
3112
3113  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
3114  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=jacobian,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # Starting in numpy 2.0, np.einsum() outputs np.float64 even with all np.float32 inputs;
  # GPT: "aligns np.einsum with other functions where intermediate calculations use higher
  # precision (np.float64) regardless of input type when floating-point arithmetic is involved."
  # we could explicitly add the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # There is one weight array per grid dimension, so the product runs over grid_ndim.
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array
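
The gather-and-einsum step at the end of the listing can be illustrated in isolation. What follows is a minimal numpy-only sketch, not the library's code. It assumes a two-tap linear (triangle) kernel, a 'dual' grid, and index clamping at the borders. The helper name `bilinear_via_einsum` is hypothetical.

```python
import numpy as np


def bilinear_via_einsum(array, coords):
  """Sample `array` (H, W, C) at unit-domain `coords` (..., 2) with a linear kernel."""
  grid_shape = array.shape[:2]
  weight, src_index = [], []
  for dim in range(2):
    size = grid_shape[dim]
    # Dual grid: sample i sits at (i + 0.5) / size, so invert that mapping.
    src_float_index = coords[..., dim] * size - 0.5
    first = np.floor(src_float_index).astype(np.int32)
    index = first[..., None] + np.arange(2)  # Two taps per dimension.
    x = src_float_index[..., None] - index
    weight.append(np.maximum(0.0, 1.0 - np.abs(x)))  # Triangle kernel.
    src_index.append(np.clip(index, 0, size - 1))  # Clamp at the borders.
  # Broadcast the per-dimension indices to gather a 2x2 footprint per output point.
  samples = array[src_index[0][..., :, None], src_index[1][..., None, :]]
  # samples has shape (..., 2, 2, C); contract both tap axes with their weights.
  return np.einsum('...cde,...c,...d->...e', samples, weight[0], weight[1])


array = np.random.default_rng(0).random((5, 7, 3))
coords = (np.moveaxis(np.indices((5, 7)), 0, -1) + 0.5) / (5, 7)
print(np.allclose(bilinear_via_einsum(array, coords), array))
```

The `samples` array holds every tap for every output point, which is why the listing treats the gather as the memory bottleneck and splits large outputs into blocks.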

Interpolate array (a grid of samples) at specified unit-domain coordinates coords.

The last dimension of coords contains unit-domain coordinates at which to interpolate the domain grid samples in array.

The number of coordinates (coords.shape[-1]) determines how to interpret array: its first coords.shape[-1] dimensions define the grid, and the remaining dimensions describe each grid sample (e.g., scalar, vector, tensor).

Concretely, the grid has shape array.shape[:coords.shape[-1]] and each grid sample has shape array.shape[coords.shape[-1]:].

Examples include:

  • Resample a grayscale image with array.shape = height, width onto a new grayscale image with new.shape = height2, width2 by using coords.shape = height2, width2, 2.

  • Resample an RGB image with array.shape = height, width, 3 onto a new RGB image with new.shape = height2, width2, 3 by using coords.shape = height2, width2, 2.

  • Sample an RGB image at num 2D points along a line segment by using coords.shape = num, 2.

  • Sample an RGB image at a single 2D point by using coords.shape = (2,).

  • Sample a 3D grid of 3x3 Jacobians with array.shape = nz, ny, nx, 3, 3 along a 2D plane by using coords.shape = height, width, 3.

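  • Map a grayscale image through a color map by using array.shape = 256, 3 and coords.shape = height, width, 1 (the trailing axis of size 1 holds the single coordinate, so that coords.shape[-1] == 1 selects a 1D grid).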
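
The shape rule behind these examples can be checked without running any resampling. The following sketch only restates the slicing described above; `interpret_shapes` is a hypothetical helper, not part of the library, and the color-map case is written with an explicit trailing coordinate axis of size 1.

```python
def interpret_shapes(array_shape, coords_shape):
  """Split shapes into grid, sample, and output parts (illustrative helper)."""
  grid_ndim = coords_shape[-1]  # Number of coordinates per point.
  if grid_ndim > len(array_shape):
    raise ValueError('More coordinate dimensions than array dimensions.')
  return {
      'grid_shape': array_shape[:grid_ndim],
      'sample_shape': array_shape[grid_ndim:],
      'output_shape': coords_shape[:-1] + array_shape[grid_ndim:],
  }


# RGB image resampled onto a new 8x9 grid.
print(interpret_shapes((5, 7, 3), (8, 9, 2)))
# RGB image sampled at a single 2D point.
print(interpret_shapes((5, 7, 3), (2,)))
# 3D grid of 3x3 Jacobians sampled along a 2D plane.
print(interpret_shapes((4, 5, 6, 3, 3), (8, 9, 3)))
# Color-map lookup with one coordinate per output pixel.
print(interpret_shapes((256, 3), (8, 9, 1)))
```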

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The coordinate dimensions appear first, and each grid sample may have an arbitrary shape. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • coords: Grid of points at which to resample array. The point coordinates are in the last dimension of coords. The domain associated with the source grid is a unit hypercube, i.e. with a range [0, 1] on each coordinate dimension. The output grid has shape coords.shape[:-1] and each of its grid samples has shape array.shape[coords.shape[-1]:].
  • gridtype: Placement of the samples in the source domain grid for each dimension, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual'.
  • boundary: The reconstruction boundary rule for each dimension in coords.shape[-1], specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto the shape array.shape[coords.shape[-1]:]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in coords.shape[-1], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in coords.shape[:-1], specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • jacobian: Optional array, which must be broadcastable onto the shape coords.shape[:-1] + (coords.shape[-1], coords.shape[-1]), storing for each point in the output grid the Jacobian matrix of the map from the unit output domain to the unit source domain. If omitted, it is estimated by computing finite differences on coords.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype, coords.dtype, and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • max_block_size: If nonzero, maximum number of grid points in coords before the resampling evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
  • debug: Show internal information.
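
To see why the gamma arguments matter, consider averaging a black and a white uint8 pixel. The sketch below assumes that 'power2' decodes by squaring values normalized to [0.0, 1.0] and encodes with a square root; this is inferred from the name and is not confirmed by this section. The rounding used for the uint conversion is also an assumption, and both helper names are hypothetical.

```python
import numpy as np


def decode_power2(values_uint8):
  # Assumed decoding: normalize to [0, 1], then square to reach linear space.
  return (values_uint8.astype(np.float64) / 255.0) ** 2


def encode_power2(linear):
  # Assumed encoding: square root, then rescale and round to the uint8 range.
  return np.round(np.sqrt(np.clip(linear, 0.0, 1.0)) * 255.0).astype(np.uint8)


pixels = np.array([0, 255], dtype=np.uint8)
naive = np.round(pixels.astype(np.float64).mean()).astype(np.uint8)
linear_space = encode_power2(decode_power2(pixels).mean())
print(naive, linear_space)
```

Under these assumptions, averaging in linear space yields a visibly brighter result than averaging the encoded values directly.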
Returns:

A new sample grid of shape coords.shape[:-1], represented as an array of shape coords.shape[:-1] + array.shape[coords.shape[-1]:], of the same array library type as the source array.

Example of resample operation:

For reference, the identity resampling for a scalar-valued grid with the default grid-type 'dual' is:

>>> array = np.random.default_rng(0).random((5, 7, 3))
>>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
>>> new_array = resample(array, coords)
>>> assert np.allclose(new_array, array)
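
The `+ 0.5` in the coordinates above reflects the 'dual' ("half-integer") placement of samples. The sketch below contrasts it with a 'primal' placement, assuming that primal samples span the closed interval [0, 1] so that index i maps to i / (size - 1). These standalone functions are illustrative and are not the library's `Gridtype` methods.

```python
import numpy as np


def point_from_index(index, size, gridtype='dual'):
  index = np.asarray(index, dtype=np.float64)
  if gridtype == 'dual':
    return (index + 0.5) / size  # Samples at cell centers.
  if gridtype == 'primal':
    return index / (size - 1)  # Assumed: samples include both domain endpoints.
  raise ValueError(f'Unknown gridtype {gridtype!r}.')


def index_from_point(point, size, gridtype='dual'):
  point = np.asarray(point, dtype=np.float64)
  if gridtype == 'dual':
    return point * size - 0.5
  if gridtype == 'primal':
    return point * (size - 1)
  raise ValueError(f'Unknown gridtype {gridtype!r}.')


print(point_from_index(np.arange(4), 4, 'dual'))
print(point_from_index(np.arange(5), 5, 'primal'))
```

Under this assumption a primal grid needs at least two samples, since a size of 1 would divide by zero; this matches the size requirement stated for array.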

It is more efficient to use the function resize for the special case where the coords are obtained as simple scaling and translation of a new regular grid over the source domain:

>>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
>>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
>>> coords = (coords - translate) / scale
>>> resampled = resample(array, coords)
>>> resized = resize(array, new_shape, scale=scale, translate=translate)
>>> assert np.allclose(resampled, resized)

def resample_affine( array: Array, /, shape: Iterable[int], matrix: ArrayLike, *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, precision: DTypeLike | None = None, dtype: DTypeLike | None = None, **kwargs: Any) -> Array:
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike | None = None,
    dtype: _DTypeLike | None = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  The `matrix` transformation can be linear,
    `source_point = matrix @ destination_point`,
  or it can be affine where the last matrix column is an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim = matrix.shape[0]
  # grid_shape = array.shape[:src_ndim]
  is_affine = matrix.shape[1] == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if matrix.shape[1] != dst_ndim and not is_affine:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  dst_position_list = []  # per dimension
  for dim in range(dst_ndim):
    dst_size = shape[dim]
    dst_index = np.arange(dst_size, dtype=weight_precision)
    dst_position_list.append(dst_gridtype2[dim].point_from_index(dst_index, dst_size))
  dst_position = np.meshgrid(*dst_position_list, indexing='ij')

  linear_matrix = matrix[:, :-1] if is_affine else matrix
  src_position = np.tensordot(linear_matrix, dst_position, 1)
  coords = np.moveaxis(src_position, 0, -1)
  if is_affine:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

Resample a source array using an affinely transformed grid of given shape.

The matrix transformation can be linear, source_point = matrix @ destination_point, or it can be affine where the last matrix column is an offset vector, source_point = matrix @ (destination_point, 1.0).

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The number of grid dimensions is determined from matrix.shape[0]; the remaining dimensions are for each sample value and are all linearly interpolated.
  • shape: Dimensions of the desired destination grid. The number of destination grid dimensions may be different from that of the source grid.
  • matrix: 2D array for a linear or affine transform from unit-domain destination points (in a space with len(shape) dimensions) into unit-domain source points (in a space with matrix.shape[0] dimensions). If the matrix has len(shape) + 1 columns, the last column is the affine offset (i.e., translation).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • filter: The reconstruction kernel for each dimension in matrix.shape[0], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in len(shape), specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • **kwargs: Additional parameters for resample function.
Returns:

An array of the same class as the source array, representing a grid with specified shape, where each grid value is resampled from array. Thus the shape of the returned array is shape + array.shape[matrix.shape[0]:].
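
The mapping from `shape` and `matrix` to source coordinates can be reproduced with numpy alone. This sketch follows the steps visible in the source listing (meshgrid of destination positions, `tensordot` with the linear part, then the offset), but it hard-codes a 'dual' destination grid and stops before any resampling. The name `affine_coords` is hypothetical.

```python
import numpy as np


def affine_coords(shape, matrix):
  """Unit-domain source coordinates for a 'dual' destination grid (sketch)."""
  matrix = np.asarray(matrix, dtype=np.float64)
  dst_ndim = len(shape)
  is_affine = matrix.shape[1] == dst_ndim + 1
  if matrix.shape[1] != dst_ndim and not is_affine:
    raise ValueError(f'Expected {dst_ndim} or {dst_ndim + 1} columns.')
  # Destination sample positions in the unit domain, one 1D array per dimension.
  positions = [(np.arange(size) + 0.5) / size for size in shape]
  dst_position = np.meshgrid(*positions, indexing='ij')
  linear = matrix[:, :-1] if is_affine else matrix
  # Contract the matrix columns with the stacked destination positions.
  coords = np.moveaxis(np.tensordot(linear, dst_position, 1), 0, -1)
  if is_affine:
    coords = coords + matrix[:, -1]  # The last column is the translation.
  return coords


identity = affine_coords((6, 8), np.eye(2))
shifted = affine_coords((6, 8), [[1.0, 0.0, 0.1], [0.0, 1.0, -0.2]])
print(identity.shape, np.allclose(shifted - identity, [0.1, -0.2]))
```

The resulting array has shape `shape + (matrix.shape[0],)`, which is the form that `resample` expects for `coords`.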

def rotation_about_center_in_2d( src_shape: ArrayLike, /, angle: float, *, new_shape: ArrayLike | None = None, scale: float = 1.0) -> np.ndarray:
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping destination into a source unit domain.

  The returned matrix accounts for the possibly non-square domain shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
  """

  def translation_matrix(vector: _NDArray) -> _NDArray:
    matrix = np.eye(len(vector) + 1)
    matrix[:-1, -1] = vector
    return matrix

  def scaling_matrix(scale: _NDArray) -> _NDArray:
    return np.diag(tuple(scale) + (1.0,))

  def rotation_matrix_2d(angle: float) -> _NDArray:
    cos, sin = np.cos(angle), np.sin(angle)
    return np.array([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])

  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))
  half = np.array([0.5, 0.5])
  matrix = (
      translation_matrix(half)
      @ scaling_matrix(min(src_shape) / src_shape)
      @ rotation_matrix_2d(angle)
      @ scaling_matrix(scale * new_shape / min(new_shape))
      @ translation_matrix(-half)
  )
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

Return the 3x3 matrix mapping destination into a source unit domain.

The returned matrix accounts for the possibly non-square domain shapes.

Arguments:
  • src_shape: Resolution (ny, nx) of the source domain grid.
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the destination domain grid; it defaults to src_shape.
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
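
Two properties of this matrix are easy to verify numerically: the domain center (0.5, 0.5) is a fixed point for any angle, and a zero angle with equal shapes and unit scale gives the identity. The sketch below re-implements the composition from the source listing in standalone numpy for that purpose; `rotation_about_center` is an illustrative name, not the library function.

```python
import numpy as np


def rotation_about_center(src_shape, angle, new_shape=None, scale=1.0):
  src_shape = np.asarray(src_shape, dtype=np.float64)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape, dtype=np.float64)

  def translation(vector):
    matrix = np.eye(3)
    matrix[:2, 2] = vector
    return matrix

  def scaling(factors):
    return np.diag([factors[0], factors[1], 1.0])

  cos, sin = np.cos(angle), np.sin(angle)
  rotation = np.array([[cos, sin, 0.0], [-sin, cos, 0.0], [0.0, 0.0, 1.0]])
  # Read right to left: center at origin, undo destination aspect, rotate,
  # apply source aspect, move the origin back to the center.
  return (
      translation([0.5, 0.5])
      @ scaling(src_shape.min() / src_shape)
      @ rotation
      @ scaling(scale * new_shape / new_shape.min())
      @ translation([-0.5, -0.5])
  )


center = np.array([0.5, 0.5, 1.0])
matrix = rotation_about_center((48, 64), 0.3, new_shape=(100, 80), scale=1.5)
print(np.allclose(matrix @ center, center))
print(np.allclose(rotation_about_center((48, 64), 0.0), np.eye(3)))
```

The two scaling factors are what let the rotation act on pixel-proportional coordinates rather than on the stretched unit square when the grids are not square.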

def rotate_image_about_center( image: np.ndarray, /, angle: float, *, new_shape: ArrayLike | None = None, scale: float = 1.0, num_rotations: int = 1, **kwargs: Any) -> np.ndarray:
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
  for _ in range(num_rotations):
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image

Return a copy of image rotated about its center.

Arguments:
  • image: Source grid samples; the first two dimensions are spatial (ny, nx).
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the output grid; it defaults to image.shape[:2].
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
  • num_rotations: Number of rotations (each by angle). Successive resamplings are useful in analyzing the filtering quality.
  • **kwargs: Additional parameters for resample_affine.
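
The idea of repeated rotation can be tried without the library by swapping in nearest-neighbor lookup for the filtered resampling. The following is a simplified sketch under stated assumptions: a square image, a 'dual' grid, unit scale, and clamped indices. The name `rotate_nearest` is hypothetical.

```python
import numpy as np


def rotate_nearest(image, angle, num_rotations=1):
  """Rotate a square image about its center using nearest-neighbor lookup."""
  ny, nx = image.shape[:2]
  cos, sin = np.cos(angle), np.sin(angle)
  # Destination sample positions on a 'dual' grid, relative to the center.
  y, x = np.meshgrid(
      (np.arange(ny) + 0.5) / ny - 0.5, (np.arange(nx) + 0.5) / nx - 0.5, indexing='ij'
  )
  # Same convention as the matrix [[cos, sin], [-sin, cos]] acting on (y, x).
  src_y = cos * y + sin * x + 0.5
  src_x = -sin * y + cos * x + 0.5
  # Nearest source sample: the cell containing each unit-domain point.
  iy = np.clip(np.floor(src_y * ny).astype(int), 0, ny - 1)
  ix = np.clip(np.floor(src_x * nx).astype(int), 0, nx - 1)
  for _ in range(num_rotations):
    image = image[iy, ix]
  return image


image = np.arange(16).reshape(4, 4)
print(np.array_equal(rotate_nearest(image, np.pi / 2, num_rotations=4), image))
```

Quarter turns map sample positions exactly onto sample positions, so four of them restore the input. For other angles each pass loses information, and comparing the result after several passes against the original is what makes successive resamplings useful for judging filter quality.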