
resampler: fast differentiable resizing and warping of arbitrary grids.

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   3.. include:: ../
   6from __future__ import annotations
   8__docformat__ = 'google'
   9__version__ = '0.8.7'
  10__version_info__ = tuple(int(num) for num in __version__.split('.'))
from collections.abc import Callable, Iterable, Sequence
import abc
import dataclasses
import functools
import importlib
import itertools
import math
import os
import sys
import types
import typing
from typing import Any, Generic, Literal, Union

import numpy as np
import numpy.typing
import scipy.interpolate
import scipy.linalg
import scipy.ndimage
import scipy.sparse
import scipy.sparse.linalg
  34def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  35  """Return function decorated with no-operation; invokable with or without args."""
  36  if len(args) != 1 or not callable(args[0]) or kwargs:
  37    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  38  func: Callable[..., Any] = args[0]
  39  return func
  43  import numba
  44except ModuleNotFoundError:
  45  numba = sys.modules['numba'] = types.ModuleType('numba')
  46  numba.njit = _noop_decorator
  47using_numba = hasattr(numba, 'jit')
  50  import jax.numpy
  51  import tensorflow as tf
  52  import torch
  54  _DType = np.dtype[Any]  # (Requires Python 3.9 or TYPE_CHECKING.)
  55  _NDArray = np.ndarray[Any, Any]
  56  _DTypeLike = numpy.DTypeLike
  57  _ArrayLike = numpy.ArrayLike
  58  _TensorflowTensor: TypeAlias = tf.Tensor
  59  _TorchTensor: TypeAlias = torch.Tensor
  60  _JaxArray: TypeAlias = jax.numpy.ndarray
  63  # Typically, create named types for use in the `pdoc` documentation.
  64  # But here, these are superseded by the declarations in __init__.pyi!
  65  _DType = Any
  66  _NDArray = Any
  67  _DTypeLike = Any
  68  _ArrayLike = Any
  69  _TensorflowTensor = Any
  70  _TorchTensor = Any
  71  _JaxArray = Any
  73_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  74_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  77def _check_eq(a: Any, b: Any, /) -> None:
  78  """If the two values or arrays are not equal, raise an exception with a useful message."""
  79  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  80  if not are_equal:
  81    raise AssertionError(f'{a!r} == {b!r}')
  84def _real_precision(dtype: _DTypeLike, /) -> _DType:
  85  """Return the type of the real part of a complex number."""
  86  return np.array([], dtype).real.dtype
  89def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  90  """Return a complex type to represent a non-complex type."""
  91  return np.result_type(dtype, np.complex64)
  94def _get_precision(
  95    precision: _DTypeLike, dtypes: list[_DType], weight_dtypes: list[_DType], /
  96) -> _DType:
  97  """Return dtype based on desired precision or on data and weight types."""
  98  precision2 = np.dtype(
  99      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
 100  )
 101  if not np.issubdtype(precision2, np.inexact):
 102    raise ValueError(f'Precision {precision2} is not floating or complex.')
 103  check_complex = [precision2, *dtypes]
 104  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 105  if len(set(is_complex)) != 1:
 106    s_types = ','.join(str(dtype) for dtype in check_complex)
 107    raise ValueError(f'Types {s_types} must be all real or all complex.')
 108  return precision2
 111def _sinc(x: _ArrayLike, /) -> _NDArray:
 112  """Return the value `np.sinc(x)` but improved to:
 113  (1) ignore underflow that occurs at 0.0 for np.float32, and
 114  (2) output exact zero for integer input values.
 116  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 117  array([0., 0., 0., 1.], dtype=float32)
 119  >>> _sinc(np.array([-3, -2, -1, 0]))
 120  array([0., 0., 0., 1.])
 122  >>> _sinc(0)
 123  1.0
 124  """
 125  x = np.asarray(x)
 126  x_is_scalar = x.ndim == 0
 127  with np.errstate(under='ignore'):
 128    result = np.sinc(np.atleast_1d(x))
 129    result[x == np.floor(x)] = 0.0
 130    result[x == 0] = 1.0
 131    return result.item() if x_is_scalar else result
 134def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 135  """Return True if the sparse matrix is symmetric."""
 136  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 137  return norm <= tol
 140def _cache_sampled_1d_function(
 141    xmin: float,
 142    xmax: float,
 143    *,
 144    num_samples: int = 3_600,
 145    enable: bool = True,
 146) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 147  """Function decorator to linearly interpolate cached function values."""
 148  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 150  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 151    if not enable:
 152      return func
 154    dx = (xmax - xmin) / num_samples
 155    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 156    samples_func = func(x)
 157    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 159    @functools.wraps(func)
 160    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 161      x = np.asarray(x)
 162      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 163      index = index_float.astype(np.int64)
 164      frac = np.subtract(index_float, index, dtype=np.float32)
 165      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 167    return interpolate_using_cached_samples
 169  return wrap_it
 172class _DownsampleIn2dUsingBoxFilter:
 173  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 175  def __init__(self) -> None:
 176    # Downsampling function for params (dtype, block_height, block_width, ch).
 177    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 179  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 180    assert using_numba
 181    assert array.ndim in (2, 3), array.ndim
 182    _check_eq(len(shape), 2)
 183    dtype = array.dtype
 184    a = array[..., None] if array.ndim == 2 else array
 185    height, width, ch = a.shape
 186    new_height, new_width = shape
 187    if height % new_height != 0 or width % new_width != 0:
 188      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 189    block_height, block_width = height // new_height, width // new_width
 191    def func(array: _NDArray) -> _NDArray:
 192      new_height = array.shape[0] // block_height
 193      new_width = array.shape[1] // block_width
 194      result = np.empty((new_height, new_width, ch), dtype)
 195      totals = np.empty(ch, dtype)
 196      factor = dtype.type(1.0 / (block_height * block_width))
 197      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 198        for x in range(new_width):
 199          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 200          if ch == 1:  # All the branches involve compile-time constants.
 201            total = dtype.type(0.0)
 202            for yy in range(block_height):
 203              for xx in range(block_width):
 204                total += array[y * block_height + yy, x * block_width + xx, 0]
 205            result[y, x, 0] = total * factor
 206          elif ch == 3:
 207            total0 = total1 = total2 = dtype.type(0.0)
 208            for yy in range(block_height):
 209              for xx in range(block_width):
 210                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 211                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 212                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 213            result[y, x, 0] = total0 * factor
 214            result[y, x, 1] = total1 * factor
 215            result[y, x, 2] = total2 * factor
 216          elif block_height * block_width >= 9:
 217            for c in range(ch):
 218              totals[c] = 0.0
 219            for yy in range(block_height):
 220              for xx in range(block_width):
 221                for c in range(ch):
 222                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 223            for c in range(ch):
 224              result[y, x, c] = totals[c] * factor
 225          else:
 226            for c in range(ch):
 227              total = dtype.type(0.0)
 228              for yy in range(block_height):
 229                for xx in range(block_width):
 230                  total += array[y * block_height + yy, x * block_width + xx, c]
 231              result[y, x, c] = total * factor
 232      return result
 234    signature = dtype, block_height, block_width, ch
 235    jitted_function = self._jitted_function.get(signature)
 236    if not jitted_function:
 237      if 0:
 238        print(f'Creating numba jit-wrapper for {signature}.')
 239      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 240      self._jitted_function[signature] = jitted_function
 242    try:
 243      result = jitted_function(a)
 244    except RuntimeError:
 245      message = (
 246          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 247          ' try deleting that directory.'
 248      )
 249      print(message, file=sys.stderr, flush=True)
 250      raise
 252    return result[..., 0] if array.ndim == 2 else result
 255_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 258@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[misc]
 259def _numba_serial_csr_dense_mult(
 260    indptr: _NDArray,
 261    indices: _NDArray,
 262    data: _NDArray,
 263    src: _NDArray,
 264    dst: _NDArray,
 265) -> None:
 266  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 268  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 269  """
 270  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 271  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 272  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 273  for i in range(dst.shape[0]):
 274    acc[:] = 0
 275    for jj in range(indptr[i], indptr[i + 1]):
 276      j = indices[jj]
 277      acc += data[jj] * src[j]
 278    dst[i] = acc
 281# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
 282#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
 283#                setitem=False, stencil=False, fusion=False)
 284@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[misc]
 285def _numba_parallel_csr_dense_mult(
 286    indptr: _NDArray,
 287    indices: _NDArray,
 288    data: _NDArray,
 289    src: _NDArray,
 290    dst: _NDArray,
 291) -> None:
 292  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 294  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 295  The introduction of parallel omp threads provides another 2-4x speedup.
 296  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.
 297  """
 298  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 299  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 300  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 301  # Default is static scheduling, which is fine.
 302  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
 303    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
 304    for jj in range(indptr[i], indptr[i + 1]):
 305      j = indices[jj]
 306      acc += data[jj] * src[j]
 307    dst[i] = acc
 311class _Arraylib(abc.ABC, Generic[_Array]):
 312  """Abstract base class for abstraction of array libraries."""
 314  arraylib: str
 315  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""
 317  array: _Array
 319  @staticmethod
 320  @abc.abstractmethod
 321  def recognize(array: Any) -> bool:
 322    """Return True if `array` is recognized by this _Arraylib."""
 324  @abc.abstractmethod
 325  def numpy(self) -> _NDArray:
 326    """Return a `numpy` version of `self.array`."""
 328  @abc.abstractmethod
 329  def dtype(self) -> _DType:
 330    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""
 332  @abc.abstractmethod
 333  def astype(self, dtype: _DTypeLike) -> _Array:
 334    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""
 336  def reshape(self, shape: tuple[int, ...]) -> _Array:
 337    """Return the equivalent of `self.array.reshape(shape)`."""
 338    return self.array.reshape(shape)
 340  def possibly_make_contiguous(self) -> _Array:
 341    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous."""
 342    return self.array
 344  @abc.abstractmethod
 345  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _Array:
 346    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""
 348  @abc.abstractmethod
 349  def square(self) -> _Array:
 350    """Return the equivalent of `np.square(self.array)`."""
 352  @abc.abstractmethod
 353  def sqrt(self) -> _Array:
 354    """Return the equivalent of `np.sqrt(self.array)`."""
 356  def getitem(self, indices: Any) -> _Array:
 357    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
 358    return self.array[indices]
 360  @abc.abstractmethod
 361  def where(self, if_true: Any, if_false: Any) -> _Array:
 362    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""
 364  @abc.abstractmethod
 365  def transpose(self, axes: Sequence[int]) -> _Array:
 366    """Return the equivalent of `np.transpose(self.array, axes)`."""
 368  @abc.abstractmethod
 369  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 370    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""
 372  @abc.abstractmethod
 373  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
 374    """Return the multiplication of the `sparse` matrix and `self.array`."""
 376  @staticmethod
 377  @abc.abstractmethod
 378  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
 379    """Return the equivalent of `np.concatenate(arrays, axis)`."""
 381  @staticmethod
 382  @abc.abstractmethod
 383  def einsum(subscripts: str, *operands: _Array) -> _Array:
 384    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 386  @staticmethod
 387  @abc.abstractmethod
 388  def make_sparse_matrix(
 389      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 390  ) -> _Array:
 391    """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 392    However, the indices must be ordered and unique."""
 395class _NumpyArraylib(_Arraylib[_NDArray]):
 396  """Numpy implementation of the array abstraction."""
 398  def __init__(self, array: _NDArray) -> None:
 399    super().__init__(arraylib='numpy', array=np.asarray(array))
 401  @staticmethod
 402  def recognize(array: Any) -> bool:
 403    return isinstance(array, np.ndarray)
 405  def numpy(self) -> _NDArray:
 406    return self.array
 408  def dtype(self) -> _DType:
 409    dtype: _DType = self.array.dtype
 410    return dtype
 412  def astype(self, dtype: _DTypeLike) -> _NDArray:
 413    return self.array.astype(dtype, copy=False)
 415  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _NDArray:
 416    return self.array.clip(low, high, dtype=dtype)
 418  def square(self) -> _NDArray:
 419    return np.square(self.array)
 421  def sqrt(self) -> _NDArray:
 422    return np.sqrt(self.array)
 424  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 425    condition = self.array
 426    return np.where(condition, if_true, if_false)
 428  def transpose(self, axes: Sequence[int]) -> _NDArray:
 429    return np.transpose(self.array, tuple(axes))
 431  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 432    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 433    # and (2) timings show preference to resizing dimensions with larger strides first.
 434    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 435    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 436    # np.einsum() does not support the sparse multiplications that we require here.)
 437    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 438    strides: Sequence[int] = self.array.strides
 439    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 441    def priority(dim: int) -> float:
 442      scaling = dst_shape[dim] / src_shape[dim]
 443      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 445    return sorted(range(len(src_shape)), key=priority)
 447  def premult_with_sparse(
 448      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 449  ) -> _NDArray:
 450    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 451    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 452    if using_numba:
 453      num_threads2 = min(6, os.cpu_count()) if num_threads == 'auto' else num_threads
 454      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 455      dtype = np.result_type(sparse.dtype, src.dtype)
 456      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 457      num_scalar_multiplies = len( * src.shape[1]
 458      is_small_size = num_scalar_multiplies < 200_000
 459      if is_small_size or num_threads2 == 1:
 460        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices,, src, dst)
 461      else:
 462        numba.set_num_threads(num_threads2)
 463        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices,, src, dst)
 464      return dst
 466    # Note that sicpy.sparse does not use multithreading.  The "@" operation
 467    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 468    # scipy.sparse._sparsetools.csr_matvecs() in
 469    #
 470    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 471    #
 472    return sparse @ self.array
 474  @staticmethod
 475  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 476    return np.concatenate(arrays, axis)
 478  @staticmethod
 479  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 480    return np.einsum(subscripts, *operands, optimize=True)
 482  @staticmethod
 483  def make_sparse_matrix(
 484      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 485  ) -> _NDArray:
 486    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 489class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
 490  """Tensorflow implementation of the array abstraction."""
 492  def __init__(self, array: _NDArray) -> None:
 493    import tensorflow
 495 = tensorflow
 496    super().__init__(arraylib='tensorflow',
 498  @staticmethod
 499  def recognize(array: Any) -> bool:
 500    # Eager: tensorflow.python.framework.ops.Tensor
 501    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
 502    return type(array).__module__.startswith('tensorflow.')
 504  def numpy(self) -> _NDArray:
 505    return self.array.numpy()
 507  def dtype(self) -> _DType:
 508    return np.dtype(self.array.dtype.as_numpy_dtype)
 510  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
 511    return, dtype)
 513  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
 514    return, shape)
 516  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TensorflowTensor:
 517    array = self.array
 518    if dtype is not None:
 519      array =, dtype)
 520    return, low, high)
 522  def square(self) -> _TensorflowTensor:
 523    return
 525  def sqrt(self) -> _TensorflowTensor:
 526    return
 528  def getitem(self, indices: Any) -> _TensorflowTensor:
 529    if isinstance(indices, tuple):
 530      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
 531      if not basic:
 532        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
 533        assert all(isinstance(a, np.ndarray) for a in indices)
 534        assert all(a.ndim == indices[0].ndim for a in indices)
 535        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
 536        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
 537        return, indices_array)
 538    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
 539      indices =, np.int32)
 540    return, indices)
 542  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
 543    condition = self.array
 544    return, if_true, if_false)
 546  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
 547    return, tuple(axes))
 549  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 550    # Note that a tensorflow.Tensor does not have strides.
 551    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
 552    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 553    dims = list(range(len(src_shape)))
 554    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 555      dims[:2] = [1, 0]
 556    return dims
 558  def premult_with_sparse(
 559      self, sparse: tf.sparse.SparseTensor, num_threads: int | Literal['auto']
 560  ) -> _TensorflowTensor:
 561    import tensorflow as tf
 563    del num_threads
 564    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 565      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
 566    return tf.sparse.sparse_dense_matmul(sparse, self.array)
 568  @staticmethod
 569  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
 570    import tensorflow as tf
 572    return tf.concat(arrays, axis)
 574  @staticmethod
 575  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
 576    import tensorflow as tf
 578    return tf.einsum(subscripts, *operands, optimize='greedy')
 580  @staticmethod
 581  def make_sparse_matrix(
 582      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 583  ) -> _TensorflowTensor:
 584    import tensorflow as tf
 586    indices = np.vstack((row_ind, col_ind)).T
 587    return tf.sparse.SparseTensor(indices, data, shape)
 590# pylint: disable=missing-function-docstring
 593class _TorchArraylib(_Arraylib[_TorchTensor]):
 594  """Torch implementation of the array abstraction."""
 596  def __init__(self, array: _NDArray) -> None:
 597    import torch
 599    self.torch = torch
 600    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))
 602  @staticmethod
 603  def recognize(array: Any) -> bool:
 604    return type(array).__module__ == 'torch'
 606  def numpy(self) -> _NDArray:
 607    return self.array.numpy()
 609  def dtype(self) -> _DType:
 610    numpy_type = {
 611        self.torch.float32: np.float32,
 612        self.torch.float64: np.float64,
 613        self.torch.complex64: np.complex64,
 614        self.torch.complex128: np.complex128,
 615        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
 616        self.torch.int16: np.int16,
 617        self.torch.int32: np.int32,
 618        self.torch.int64: np.int64,
 619    }[self.array.dtype]
 620    return np.dtype(numpy_type)
 622  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
 623    torch_type = {
 624        np.float32: self.torch.float32,
 625        np.float64: self.torch.float64,
 626        np.complex64: self.torch.complex64,
 627        np.complex128: self.torch.complex128,
 628        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
 629        np.int16: self.torch.int16,
 630        np.int32: self.torch.int32,
 631        np.int64: self.torch.int64,
 632    }[np.dtype(dtype).type]
 633    return self.array.type(torch_type)
 635  def possibly_make_contiguous(self) -> _TorchTensor:
 636    return self.array.contiguous()
 638  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TorchTensor:
 639    array = self.array
 640    array = _arr_astype(array, dtype) if dtype is not None else array
 641    return array.clip(low, high)
 643  def square(self) -> _TorchTensor:
 644    return self.array.square()
 646  def sqrt(self) -> _TorchTensor:
 647    return self.array.sqrt()
 649  def getitem(self, indices: Any) -> _TorchTensor:
 650    if not isinstance(indices, tuple):
 651      indices = indices.type(self.torch.int64)
 652    return self.array[indices]  # pylint: disable=unsubscriptable-object
 654  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
 655    condition = self.array
 656    return if_true.where(condition, if_false)
 658  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
 659    return self.torch.permute(self.array, tuple(axes))
 661  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 662    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
 663    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 664    strides: Sequence[int] = self.array.stride()
 665    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 667    def priority(dim: int) -> float:
 668      scaling = dst_shape[dim] / src_shape[dim]
 669      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 671    return sorted(range(len(src_shape)), key=priority)
 673  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
 674    del num_threads
 675    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 676      sparse = _arr_astype(sparse, _arr_dtype(self.array))
 677    return sparse @ self.array  # Calls
 679  @staticmethod
 680  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
 681    import torch
 683    return, axis)
 685  @staticmethod
 686  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
 687    import torch
 689    operands = tuple(torch.as_tensor(operand) for operand in operands)
 690    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
 691      operands = tuple(
 692          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
 693      )
 694    return torch.einsum(subscripts, *operands)
 696  @staticmethod
 697  def make_sparse_matrix(
 698      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 699  ) -> _TorchTensor:
 700    import torch
 702    indices = np.vstack((row_ind, col_ind))
 703    return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
 704    # .coalesce() is unnecessary because indices/data are already merged.
 707# pylint: enable=missing-function-docstring
 710class _JaxArraylib(_Arraylib[_JaxArray]):
 711  """Jax implementation of the array abstraction."""
 713  def __init__(self, array: _NDArray) -> None:
 714    import jax.numpy
 716    self.jnp = jax.numpy
 717    super().__init__(arraylib='jax', array=self.jnp.asarray(array))
 719  @staticmethod
 720  def recognize(array: Any) -> bool:
 721    # e.g., jaxlib.xla_extension.DeviceArray,
 722    return type(array).__module__.startswith(('jaxlib.', 'jax.'))
 724  def numpy(self) -> _NDArray:
 725    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
 726    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
 727    # np.array(array) always returns a writable array but the copy may be more costly.
 728    # return self.array.to_py()
 729    return np.asarray(self.array)
 731  def dtype(self) -> _DType:
 732    return np.dtype(self.array.dtype)
 734  def astype(self, dtype: _DTypeLike) -> _JaxArray:
 735    return self.array.astype(dtype)  # (copy=False is unavailable)
 737  def possibly_make_contiguous(self) -> _JaxArray:
 738    return self.array.copy()
 740  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _JaxArray:
 741    array = self.array
 742    if dtype is not None:
 743      array = array.astype(dtype)  # (copy=False is unavailable)
 744    return self.jnp.clip(array, low, high)
 746  def square(self) -> _JaxArray:
 747    return self.jnp.square(self.array)
 749  def sqrt(self) -> _JaxArray:
 750    return self.jnp.sqrt(self.array)
 752  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
 753    condition = self.array
 754    return self.jnp.where(condition, if_true, if_false)
 756  def transpose(self, axes: Sequence[int]) -> _JaxArray:
 757    return self.jnp.transpose(self.array, tuple(axes))
 759  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 760    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
 761    #
 762    # We use a heuristic similar to `_TensorflowArraylib`.
 763    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 764    dims = list(range(len(src_shape)))
 765    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 766      dims[:2] = [1, 0]
 767    return dims
 769  def premult_with_sparse(
 770      self, sparse: jax.experimental.sparse.BCOO, num_threads: int | Literal['auto']
 771  ) -> _JaxArray:
 772    del num_threads
 773    return sparse @ self.array  # Calls jax.bcoo_multiply_dense().
 775  @staticmethod
 776  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
 777    import jax.numpy as jnp
 779    return jnp.concatenate(arrays, axis)
 781  @staticmethod
 782  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
 783    import jax.numpy as jnp
 785    return jnp.einsum(subscripts, *operands, optimize='greedy')
 787  @staticmethod
 788  def make_sparse_matrix(
 789      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 790  ) -> _JaxArray:
 791    #
 792    import jax.experimental.sparse
 794    indices = np.vstack((row_ind, col_ind)).T
 795    return jax.experimental.sparse.BCOO(
 796        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
 797    )
 801    'numpy': _NumpyArraylib,
 802    'tensorflow': _TensorflowArraylib,
 803    'torch': _TorchArraylib,
 804    'jax': _JaxArraylib,
 808def is_available(arraylib: str) -> bool:
 809  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 810  return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.
 814    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if is_available(arraylib)
 818"""Array libraries supported automatically in the resize and resampling operations.
 820- The library is selected automatically based on the type of the `array` function parameter.
 822- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 824- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 828def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 829  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 830  for cls in _DICT_ARRAYLIBS.values():
 831    if cls.recognize(array):
 832      return cls(array)  # type: ignore[abstract]
 833  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 836def _arr_arraylib(array: _Array, /) -> str:
 837  """Return the name of the `Arraylib` representing `array`."""
 838  return _as_arr(array).arraylib
 841def _arr_numpy(array: _Array, /) -> _NDArray:
 842  """Return a `numpy` version of `array`."""
 843  return _as_arr(array).numpy()
 846def _arr_dtype(array: _Array, /) -> _DType:
 847  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 848  return _as_arr(array).dtype()
 851def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 852  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 853  return _as_arr(array).astype(dtype)
 856def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 857  """Return the equivalent of `array.reshape(shape)."""
 858  return _as_arr(array).reshape(shape)
 861def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 862  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 863  return _as_arr(array).possibly_make_contiguous()
 866def _arr_clip(array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike = None) -> _Array:
 867  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 868  return _as_arr(array).clip(low, high, dtype)
 871def _arr_square(array: _Array, /) -> _Array:
 872  """Return the equivalent of `np.square(array)`."""
 873  return _as_arr(array).square()
 876def _arr_sqrt(array: _Array, /) -> _Array:
 877  """Return the equivalent of `np.sqrt(array)`."""
 878  return _as_arr(array).sqrt()
 881def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 882  """Return the equivalent of `array[indices]`."""
 883  return _as_arr(array).getitem(indices)
 886def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 887  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 888  return _as_arr(condition).where(if_true, if_false)
 891def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 892  """Return the equivalent of `np.transpose(array, axes)`."""
 893  return _as_arr(array).transpose(axes)
 896def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 897  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 898  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 901def _arr_matmul_sparse_dense(
 902    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 903) -> _Array:
 904  """Return the multiplication of the `sparse` and `dense` matrices."""
 905  assert num_threads == 'auto' or num_threads >= 1
 906  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 909def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 910  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 911  arraylib = _arr_arraylib(arrays[0])
 912  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)  # type: ignore
 915def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 916  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 917  arraylib = _arr_arraylib(operands[0])
 918  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)  # type: ignore
 921def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 922  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 923  ndim = len(array.shape)
 924  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 925  axes = list(range(ndim))
 926  axes[axis1] = axis2
 927  axes[axis2] = axis1
 928  return _arr_transpose(array, axes)
 931def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 932  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 933  ndim = len(array.shape)
 934  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 935  axes = [n for n in range(ndim) if n != source]
 936  axes.insert(destination, source)
 937  return _arr_transpose(array, axes)
 940def _make_sparse_matrix(
 941    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 942) -> Any:
 943  """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 944  However, indices must be ordered and unique."""
 945  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 948def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 949  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 950  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
 954# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
 955# e.g. dim=1 in an RGB image.
 957# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 958# tensor.  However, tf.reshape() is always efficient.
 961def _block_shape_with_min_size(
 962    shape: tuple[int, ...], min_size: int, compact: bool = True
 963) -> tuple[int, ...]:
 964  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 965  if < min_size:
 966    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 967  if compact:
 968    root = int(math.ceil(min_size ** (1 / len(shape))))
 969    block_shape = np.minimum(shape, root)
 970    for dim in range(len(shape)):
 971      if block_shape[dim] == 2 and >= min_size * 2:
 972        block_shape[dim] = 1
 973    for dim in range(len(shape) - 1, -1, -1):
 974      if < min_size:
 975        block_shape[dim] = shape[dim]
 976  else:
 977    block_shape = np.ones_like(shape)
 978    for dim in range(len(shape) - 1, -1, -1):
 979      if < min_size:
 980        block_shape[dim] = min(shape[dim], math.ceil(min_size /
 981  return tuple(block_shape)
 984def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 985  """Split `array` into `num_sections` along `axis`."""
 986  assert 0 <= axis < len(array.shape)
 987  assert 1 <= num_sections <= array.shape[axis]
 989  if 0:
 990    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 992  else:
 993    # Adapted from
 994    num_total = array.shape[axis]
 995    num_each, num_extra = divmod(num_total, num_sections)
 996    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
 997    div_points = np.array(section_sizes).cumsum()
 998    split = []
 999    tmp = _arr_swapaxes(array, axis, 0)
1000    for i in range(num_sections):
1001      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
1003  return split
def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
  """Split `array` into nested lists of blocks of size at most `block_shape`.

  Each nesting level splits one axis, from `start_axis` up to `len(block_shape) - 1`.  Axes of
  `array` beyond `len(block_shape)` are left unsplit.  Along each axis, the sections come from
  `_array_split()`, so their sizes differ by at most one.

  Raises:
    ValueError: If `block_shape` has more entries than `array` has dims.
  """
  if len(array.shape) < len(block_shape):
    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
  if start_axis == len(block_shape):
    return array  # Every requested axis is split; this is a leaf block.
  count = math.ceil(array.shape[start_axis] / block_shape[start_axis])
  pieces = _array_split(array, start_axis, count)
  return [_split_array_into_blocks(piece, block_shape, start_axis + 1) for piece in pieces]
def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
  """Return the nested-list structure `blocks` with `func` applied to every non-list leaf."""
  if not isinstance(blocks, list):
    return func(blocks)
  return [_map_function_over_blocks(sub_block, func) for sub_block in blocks]
def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
  """Merge an array from the nested lists of array blocks in `blocks`.

  Nesting level `k` is concatenated along axis `axis + k`.
  """
  # Unlike np.block(), the leaf blocks may carry extra dims beyond the nesting depth.
  if not isinstance(blocks, list):
    return blocks
  merged_parts = [_merge_array_from_blocks(sub_block, axis + 1) for sub_block in blocks]
  return _arr_concatenate(merged_parts, axis)
class Gridtype(abc.ABC):
  """Abstract base class for grid types, such as `'dual'` and `'primal'`.

  A resampling operation may use one grid type for the source domain (`src_gridtype`) and
  another for the destination domain (`dst_gridtype`).  Either may also vary per dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Both src and dst use `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples this grid type supports."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of a domain of `size` samples, measured in inter-sample spacings."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map sample indices in [0, size - 1] to coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Map coordinates in [0.0, 1.0] to continuous sample locations.

    Location 0.0 is the first grid sample, and `size - 1.0` is the last one."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior indices by reflecting at the boundaries."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior indices by periodic wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior indices by reflection followed by clamping."""
class DualGridtype(Gridtype):
  """Samples lie at the centers of the cells of a uniform partition of the domain.

  Along a unit-domain dimension with N samples, sample i (0 <= i < N) lies at (i + 0.5) / N.
  For N = 4, the positions are [0.125, 0.375, 0.625, 0.875].
  """

  def __init__(self) -> None:
    super().__init__(name='dual')

  def min_size(self) -> int:
    return 1

  def size_in_samples(self, size: int, /) -> int:
    return size

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    return (index + 0.5) / size

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    return point * size - 0.5

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # Reflection is periodic with period 2 * size; the second half mirrors the first.
    folded = np.mod(index, 2 * size)
    return np.where(folded >= size, 2 * size - 1 - folded, folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    return np.remainder(index, size)

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    mirrored = np.where(index < 0, -1 - index, index)
    return np.minimum(mirrored, size - 1)
class PrimalGridtype(Gridtype):
  """Samples lie at the vertices of the cells of a uniform partition of the domain.

  Along a unit-domain dimension with N samples, sample i (0 <= i < N) lies at i / (N - 1).
  For N = 4, the positions are [0, 1/3, 2/3, 1].
  """

  def __init__(self) -> None:
    super().__init__(name='primal')

  def min_size(self) -> int:
    return 2

  def size_in_samples(self, size: int, /) -> int:
    return size - 1

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    return index / (size - 1)

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    return point * (size - 1)

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The boundary samples are not duplicated, so the reflection period is 2 * (size - 1).
    period = 2 * (size - 1)
    folded = np.mod(index, period)
    return np.where(folded >= size, period - folded, folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    return np.remainder(index, size - 1)

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    magnitude = np.abs(index)
    return np.minimum(magnitude, size - 1)
1151    'dual': DualGridtype(),
1152    'primal': PrimalGridtype(),
1156r"""Shortcut names for the two predefined grid types (specified per dimension):
1158| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
1159| --- |:---:|:---:|
1160| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual]( | ![Primal]( |
1161| Nesting of samples across resolutions | The samples positions do *not* nest. | The *even* samples remain at coarser scale. |
1162| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
1163| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
1164| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |
1166See the source code for extensibility.
def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
  """Return `gridtype` itself if it is a `Gridtype`, else the entry of that name in `GRIDTYPES`."""
  if isinstance(gridtype, Gridtype):
    return gridtype
  return _DICT_GRIDTYPES[gridtype]
def _get_gridtypes(
    gridtype: str | Gridtype | None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    src_ndim: int,
    dst_ndim: int,
) -> tuple[list[Gridtype], list[Gridtype]]:
  """Return per-dim source and destination grid types given all parameters.

  Args:
    gridtype: Grid type shared by the source and destination; cannot be combined with
      `src_gridtype` or `dst_gridtype`.
    src_gridtype: Source grid type(s), either a single value or one value per source dim.
    dst_gridtype: Destination grid type(s), either a single value or one per destination dim.
    src_ndim: Number of source dims.
    dst_ndim: Number of destination dims.

  Returns:
    Lists of `src_ndim` source and `dst_ndim` destination `Gridtype` instances.  A domain
    whose grid type is left unspecified defaults to `'dual'`.

  Raises:
    ValueError: If `gridtype` is combined with `src_gridtype` or `dst_gridtype`.
  """
  if gridtype is not None:
    if src_gridtype is not None:
      raise ValueError('Cannot have both gridtype and src_gridtype.')
    if dst_gridtype is not None:
      raise ValueError('Cannot have both gridtype and dst_gridtype.')
    src_gridtype = dst_gridtype = gridtype
  # Previously, specifying only one of src/dst left the other as None, which failed with an
  # obscure `KeyError: None`; default the unspecified domain to 'dual' instead.
  if src_gridtype is None:
    src_gridtype = 'dual'
  if dst_gridtype is None:
    dst_gridtype = 'dual'
  src_gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(src_gridtype), src_ndim)]
  dst_gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(dst_gridtype), dst_ndim)]
  return src_gridtype2, dst_gridtype2
class RemapCoordinates(abc.ABC):
  """Abstract base class that transforms coordinates before the reconstruction kernels are
  evaluated."""

  @abc.abstractmethod
  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the remapped version of the coordinates `point`."""
class NoRemapCoordinates(RemapCoordinates):
  """Identity remapping: the coordinates are returned unchanged."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    unchanged = point
    return unchanged
class MirrorRemapCoordinates(RemapCoordinates):
  """Reflect the coordinates across the domain boundaries so that they lie in the unit interval.

  The resulting function is continuous but not smooth across the boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    # The mirrored pattern has period 2.0; the interval [1.0, 2.0) folds back onto (0.0, 1.0].
    folded = np.mod(point, 2.0)
    return np.where(folded < 1.0, folded, 2.0 - folded)
class TileRemapCoordinates(RemapCoordinates):
  """Map the coordinates into the unit interval by taking them modulo 1.0.

  The resulting function is generally discontinuous across the domain boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    return np.remainder(point, 1.0)
class ExtendSamples(abc.ABC):
  """Abstract base class for replacing references to grid samples outside the unit domain.

  Each such reference becomes an affine combination of interior sample(s), possibly together
  with the constant value (`cval`)."""

  uses_cval: bool = False
  """True if some exterior samples are defined in terms of `cval`, i.e., if the computed weight
  is non-affine."""

  @abc.abstractmethod
  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Rewrite references to exterior samples so that they reference only interior ones.

    Exterior samples are entries of `index` outside the interval [0, size).  Their indices, and
    possibly their associated weights, are updated.  Return `new_index, new_weight`."""
class ReflectExtendSamples(ExtendSamples):
  """Obtain each interior sample by reflecting the index across the domain boundaries."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    reflected = gridtype.reflect(index, size)
    return reflected, weight
class WrapExtendSamples(ExtendSamples):
  """Repeat the interior samples periodically.

  On a `'primal'` grid, the last sample is ignored because the first sample replaces it."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    wrapped = gridtype.wrap(index, size)
    return wrapped, weight
class ClampExtendSamples(ExtendSamples):
  """Replace each exterior sample with the nearest interior sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    clamped = np.clip(index, 0, size - 1)
    return clamped, weight
class ReflectClampExtendSamples(ExtendSamples):
  """Reflect the grid samples from [0, 1] into [-1, 0], then clamp to the nearest sample.

  Samples outside [-1, 1] take the value of the nearest sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    remapped = gridtype.reflect_clamp(index, size)
    return remapped, weight
class BorderExtendSamples(ExtendSamples):
  """Give every exterior sample the constant value (`cval`)."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Exterior taps get zero weight (so `cval` fills in).  Their indices are clamped in place
    # to valid positions.
    below = index < 0
    weight[below], index[below] = 0.0, 0
    above = index >= size  # Evaluated after the update above, as in the original order.
    weight[above], index[above] = 0.0, size - 1
    return index, weight
class ValidExtendSamples(ExtendSamples):
  """Weight interior samples by 1 and exterior samples by 0, then renormalize.

  The weighted reconstruction is divided by the reconstructed weight."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    below = index < 0
    weight[below], index[below] = 0.0, 0
    above = index >= size
    weight[above], index[above] = 0.0, size - 1
    # Renormalize the remaining (interior) weights to sum to one wherever their sum is nonzero.
    totals = weight.sum(axis=-1)[..., None]
    np.divide(weight, totals, out=weight, where=totals != 0.0)
    return index, weight
class LinearExtendSamples(ExtendSamples):
  """Linearly extrapolate beyond boundary samples.

  An exterior tap at index `x < 0` takes the value `(1 - x) * f[0] + x * f[1]`.  An exterior
  tap at `x >= size` takes `(1 - y) * f[size - 1] + y * f[size - 2]`, with
  `y = (size - 1) - x`.  With fewer than 2 samples, the rule falls back to reflection.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # NOTE(review): `index` and `weight` are modified in place by the masked assignments below
    # before new (wider) arrays are returned; presumably callers do not reuse them — confirm.
    if size < 2:
      # Linear extrapolation needs at least two samples.
      index = gridtype.reflect(index, size)
      return index, weight
    # For each boundary, define new columns in index and weight arrays to represent the last and
    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
    # duplicate index entries.
    low = index < 0
    high = index >= size
    # Four extra columns: weights on samples 0, 1 (low side) and size - 2, size - 1 (high side).
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
    x = index
    # Low side: accumulate the (1 - x) and x coefficients over all taps with index < 0.
    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
    # High side: x is now the (negative) distance before the last sample.
    x = (size - 1) - index
    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
    # The exterior taps are now represented by the extra columns; neutralize the originals.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-4] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-4] = index
    # Let matrix (including zero values) be banded.
    # (Zero-weight extra columns reuse the first column's index instead of a boundary index.)
    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
    index = new_index
    return index, weight
class QuadraticExtendSamples(ExtendSamples):
  """Quadratically extrapolate beyond boundary samples.

  An exterior tap is replaced by the quadratic (Lagrange) polynomial through the three nearest
  boundary samples: samples 0, 1, 2 on the low side and size - 3, size - 2, size - 1 on the
  high side.  With fewer than 3 samples, the rule falls back to reflection.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
    # but it seems just quadratic.
    # NOTE(review): `index` and `weight` are modified in place by the masked assignments below
    # before new (wider) arrays are returned; presumably callers do not reuse them — confirm.
    if size < 3:
      # Quadratic extrapolation needs at least three samples.
      index = gridtype.reflect(index, size)
      return index, weight
    low = index < 0
    high = index >= size
    # Six extra columns: samples 0, 1, 2 (low side) and size - 3, size - 2, size - 1 (high side).
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
    x = index
    # Lagrange basis on nodes {0, 1, 2}: (x-1)(x-2)/2, -x(x-2), x(x-1)/2.
    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
    # Same basis, mirrored: x is the (negative) distance before the last sample.
    x = (size - 1) - index
    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
    # The exterior taps are now represented by the extra columns; neutralize the originals.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-6] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-6] = index
    # Let matrix (including zero values) be banded.
    # (Zero-weight extra columns reuse the first column's index instead of a boundary index.)
    new_index[..., -6:] = np.where(
        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
    )
    index = new_index
    return index, weight
class OverrideExteriorValue:
  """Abstract base class to set the value outside some domain extent to a
  constant value (`cval`)."""

  boundary_antialiasing: bool = True
  """Antialias the pixel values adjacent to the boundary of the extent."""

  uses_cval: bool = False
  """Modify some weights to introduce references to the `cval` constant value."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For all `point` outside some extent, modify the weight to be zero.

    Subclasses implement this; the base implementation does nothing."""

  def override_using_signed_distance(
      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
  ) -> None:
    """Reduce sample weights for "outside" values based on the signed distance function,
    to effectively assign the constant value `cval`.

    `weight` is modified in place.  `signed_distance` is positive outside the extent and
    non-positive inside it.  With `boundary_antialiasing`, weights are scaled by a coverage
    opacity in [0, 1] near the boundary.  Otherwise, weights of outside points are zeroed.
    """
    all_points_inside_domain = np.all(signed_distance <= 0.0)
    if all_points_inside_domain:
      return
    # NOTE(review): min(point.shape) raises ValueError for a 0-d `point`; presumably `point`
    # always has at least one dim here — confirm.
    if self.boundary_antialiasing and min(point.shape) >= 2:
      # For discontinuous coordinate mappings, we may need to somehow ignore
      # the large finite differences computed across the map discontinuities.
      gradient = np.gradient(point)
      # Local coordinate spacing per output sample, used to express distance in sample units.
      gradient_norm = np.linalg.norm(np.atleast_2d(gradient), axis=0)
      signed_distance_in_samples = signed_distance / (gradient_norm + 1e-20)
      # Opacity is in linear space, which is correct if Gamma is set.
      # Opacity is 1 at least half a sample inside, 0 at least half a sample outside.
      opacity = (0.5 - signed_distance_in_samples).clip(0.0, 1.0)
      weight *= opacity[..., None]
    else:
      is_outside = signed_distance > 0.0
      weight[is_outside, :] = 0.0
class NoOverrideExteriorValue(OverrideExteriorValue):
  """Leave the function value unchanged everywhere."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    del weight, point  # Nothing to override.
class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the unit interval [0, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Distance beyond the nearer of the boundaries 0.0 and 1.0 (negative when inside).
    distance_outside = abs(point - 0.5) - 0.5
    self.override_using_signed_distance(weight, point, distance_outside)
class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the interval [-1, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Distance beyond the nearer of the boundaries -1.0 and 1.0 (negative when inside).
    distance_outside = abs(point) - 1.0
    self.override_using_signed_distance(weight, point, distance_outside)
class Boundary:
  """Domain boundary rules, which define the reconstruction over the source domain near and
  beyond its boundaries.  The rules may be specified separately for each domain dimension."""

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    if self.extend_samples.uses_cval:
      return self.extend_samples.uses_cval
    return self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return `point` transformed by `coord_remap`, ready for evaluating the filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples, then apply overrides."""
    new_index, new_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(new_weight, point)
    return new_index, new_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """Zero (in place) the weights of points outside an extent, so that they take `cval`."""
    self.override_value(weight, point)
1511    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
1512    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
1513    'tile': Boundary(
1514        'title', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
1515    ),
1516    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
1517    'border': Boundary('border', extend_samples=BorderExtendSamples()),
1518    'natural': Boundary(
1519        'natural',
1520        extend_samples=ValidExtendSamples(),
1521        override_value=UnitDomainOverrideExteriorValue(),
1522    ),
1523    'linear_constant': Boundary(
1524        'linear_constant',
1525        extend_samples=LinearExtendSamples(),
1526        override_value=UnitDomainOverrideExteriorValue(),
1527    ),
1528    'quadratic_constant': Boundary(
1529        'quadratic_constant',
1530        extend_samples=QuadraticExtendSamples(),
1531        override_value=UnitDomainOverrideExteriorValue(),
1532    ),
1533    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
1534    'constant': Boundary(
1535        'constant',
1536        extend_samples=ReflectExtendSamples(),
1537        override_value=UnitDomainOverrideExteriorValue(),
1538    ),
1539    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
1540    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
1544"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):
1546| name                   | a.k.a. / comments |
1548| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
1549| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
1550| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
1551| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
1552| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
1553| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
1554| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
1555| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
1556| `'linear'`             | extrapolate from 2 last samples |
1557| `'quadratic'`          | extrapolate from 3 last samples |
1558| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
1559| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |
1561These boundary rules may be specified per dimension.  See the source code for extensibility
1562using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.
1564**Boundary rules illustrated in 1D:**
1567<img src="" width="100%"/>
1570**Boundary rules illustrated in 2D:**
1573<img src="" width="100%"/>
1578    'reflect wrap tile clamp border natural linear_constant quadratic_constant'.split()
1580"""A useful subset of `BOUNDARIES` for visualization in figures."""
def _get_boundary(boundary: str | Boundary, /) -> Boundary:
  """Return `boundary` itself if it is a `Boundary`, else the entry of that name in `BOUNDARIES`."""
  if isinstance(boundary, Boundary):
    return boundary
  return _DICT_BOUNDARIES[boundary]
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Every kernel is taken to be a zero-phase filter, i.e., symmetric with support in the interval
  [-radius, radius].  (Some references instead define kernels over the interval [0, N], where
  N = 2 * radius.)

  Portions of this code are adapted from an external C++ library.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the filter kernel at the locations `x`."""
class ImpulseFilter(Filter):
  """Impulse (Dirac delta) kernel, listed as *nearest* in the `FILTERS` table.

  Its support is treated as vanishingly small (radius 1e-20), so it is never evaluated
  pointwise; calling it raises an `AssertionError`.
  """

  def __init__(self) -> None:
    super().__init__(
        name='impulse',
        radius=1e-20,
        partition_of_unity=False,
        continuous=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    del x  # The kernel has no finite pointwise values.
    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')
class BoxFilter(Filter):
  """Box (rectangle) kernel.

  The kernel function has value 1.0 over the half-open interval [-.5, .5) and 0.0 elsewhere.
  The half-open convention ensures that each unit interval is covered exactly once.
  """

  def __init__(self) -> None:
    super().__init__(continuous=False, radius=0.5, name='box')

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    half_open = True  # The symmetric alternative (value 0.5 at |x| == 0.5) is kept below.
    if half_open:
      values = np.asarray(x)
      inside = (values >= -0.5) & (values < 0.5)
      return np.where(inside, 1.0, 0.0)
    magnitude = np.abs(x)
    return np.where(magnitude < 0.5, 1.0, np.where(magnitude == 0.5, 0.5, 0.0))
class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  This filter resembles `BoxFilter` but has linearly sloped sides: its value is 1.0 for
  abs(x) <= 1.0 - radius, falls linearly to 0.0 over 1.0 - radius <= abs(x) <= radius, and
  always equals 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is not None and not 0.5 < radius <= 1.0:
      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
    if radius is None:
      # Placeholder instance; its radius is not meaningful until replaced.
      super().__init__(name='trapezoid', radius=0.0)
    else:
      super().__init__(name=f'trapezoid_{radius}', radius=radius)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    ramp_width = self.radius - 0.5
    intercept = 0.5 + 0.25 / ramp_width
    slope = 0.5 / ramp_width
    return (intercept - slope * distance).clip(0.0, 1.0)
class TriangleFilter(Filter):
  """Triangle kernel 1 - |x| on [-1, 1], also known as the hat or tent function.

  It is the kernel of piecewise-linear (bilinear, trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(radius=1.0, name='triangle')

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return np.clip(1.0 - np.abs(x), 0.0, 1.0)
class CubicFilter(Filter):
  """Family of piecewise-cubic filters with support [-2, 2], parameterized by (b, c).

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: Optional filter name; defaults to `'cubic_b{b}_c{c}'`.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    ax = np.abs(x)
    b, c = self.b, self.c
    # Inner segment 0 <= |x| < 1 (it has no linear term).
    p3, p2, p0 = 2 - 9 / 6 * b - c, -3 + 2 * b + c, 1 - 1 / 3 * b
    # Outer segment 1 <= |x| < 2.
    q3, q2, q1, q0 = -b / 6 - c, b + 5 * c, -2 * b - 8 * c, 8 / 6 * b + 4 * c
    # Horner evaluation; np.polynomial.polynomial.polyval is noticeably slower here.
    inner = ((p3 * ax + p2) * ax) * ax + p0
    outer = ((q3 * ax + q2) * ax + q1) * ax + q0
    return np.where(ax < 1.0, inner, np.where(ax < 2.0, outer, 0.0))
class CatmullRomFilter(CubicFilter):
  """Cubic filter (b=0, c=0.5) with cubic precision.  Also known as the Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]

  See also the Wikipedia article on cubic Hermite splines (Catmull-Rom spline).
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)
class MitchellFilter(CubicFilter):
  """Mitchell-Netravali cubic filter with b = c = 1/3 (non-interpolating).

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)
class SharpCubicFilter(CubicFilter):
  """Cubic filter (b=0, c=0.75) that is sharper than the Catmull-Rom filter.

  Used by some tools including OpenCV (`cv.INTER_CUBIC`) and Photoshop.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)
class LanczosFilter(Filter):
  """High-quality filter: sinc function modulated by a sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    def lanczos_kernel(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # With window[n] = sinc(2*n/N - 1) for 0 <= n <= N and n = x + N/2, the zero-phase
      # window is w_0(x) = sinc(x / radius).
      window = _sinc(distance / radius)
      return np.where(distance < radius, _sinc(distance) * window, 0.0)

    cache = _cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    self._function = cache(lanczos_kernel)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a (generalized) Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  Raises:
    ValueError: If `a0` is not in the open interval (0.0, 1.0).

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.  The filter `name` depends only on `radius` (e.g. `'hamming_3'`), so instances
  with different `a0` share the same name.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    # Validate with an exception rather than `assert`, which is stripped under `python -O`.
    if not 0.0 < a0 < 1.0:
      raise ValueError(f'Parameter a0={a0} is outside the range (0.0, 1.0).')
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
    return np.where(x < self.radius, _sinc(x) * window, 0.0)
class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  An example use is in [Karras et al. 2021.  Alias-free generative adversarial networks.]
  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level; must be >= 0.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `beta` is negative (or NaN).
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    # Validate with an exception rather than `assert`, which is stripped under `python -O`.
    if not beta >= 0.0:
      raise ValueError(f'Kaiser parameter beta={beta} must be non-negative.')
    super().__init__(
        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    # The cached sampling range is rounded out to integers because `radius` may be fractional.
    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
class BsplineFilter(Filter):
  """B-spline basis function of a non-negative degree, centered at the origin.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  The support is [-(degree + 1) / 2, (degree + 1) / 2].

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    super().__init__(
        name=f'bspline{degree}', radius=(degree + 1) / 2, interpolating=degree <= 1
    )
    knots = list(range(degree + 2))  # Uniform knots 0, 1, ..., degree + 1.
    self._bspline = scipy.interpolate.BSpline.basis_element(knots)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    # The basis element lives on [0, 2 * radius]; shift it to be centered at zero.
    shifted = self._bspline(distance + self.radius)
    return np.where(distance < self.radius, shifted, 0.0)
class CardinalBsplineFilter(Filter):
  """Interpolating B-spline, achieved with aid of digital pre or post filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `degree` is negative.

  A digital filter is required for `degree >= 2`; the kernel is continuous for `degree >= 1`.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    self.degree = degree
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    super().__init__(
        name=f'cardinal{degree}',
        radius=radius,
        requires_digital_filter=degree >= 2,
        continuous=degree >= 1,
    )
    t = list(range(degree + 2))  # Uniform knots 0, 1, ..., degree + 1.
    bspline = scipy.interpolate.BSpline.basis_element(t)

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      # Shift the basis element from [0, 2 * radius] to be centered at zero.
      return np.where(x < radius, bspline(x + radius), 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
1939class OmomsFilter(Filter):
1940  """OMOMS interpolating filter, with aid of digital pre or post filter.
1942  Args:
1943    degree: The polynomial degree of the filter segments.
1945  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
1946  interpolation of minimal support, 2001].
1948  """
1950  def __init__(self, *, degree: int) -> None:
1951    if degree not in (3, 5):
1952      raise ValueError(f'Degree {degree} not supported.')
1953    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
1954 = degree
1956  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1957    x = np.abs(x)
1958    match
1959      case 3:
1960        v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
1961        v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
1962        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1963      case 5:
1964        v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
1965        v12 = (
1966            (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
1967        ) * x + 839 / 2640
1968        v23 = (
1969            (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
1970        ) * x + 5707 / 2640
1971        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
1972      case _:
1973        raise ValueError(
class GaussianFilter(Filter):
  r"""Normalized Gaussian kernel, truncated to [-radius, radius] with radius = ceil(8 $\sigma$).

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    super().__init__(
        name=f'gaussian_{standard_deviation:.3f}',
        radius=np.ceil(8.0 * standard_deviation),  # Sufficiently large.
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    sdv = self.standard_deviation
    v0r = np.exp(np.square(x / sdv) / -2.0) / (np.sqrt(math.tau) * sdv)
    return np.where(x < self.radius, v0r, 0.0)
class NarrowBoxFilter(Filter):
  """Compact box footprint, used for visualization of grid sample locations.

  Args:
    radius: Specifies the support [-radius, radius) of the narrow box function, where it has
      value 1.0.  (The default value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
        continuous=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    half_width = self.radius
    values = np.asarray(x)
    inside = (values >= -half_width) & (values < half_width)
    return np.where(inside, 1.0, 0.0)
_DEFAULT_FILTER = 'lanczos3'

_DICT_FILTERS = {
    'impulse': ImpulseFilter(),
    'box': BoxFilter(),
    'trapezoid': TrapezoidFilter(),
    'triangle': TriangleFilter(),
    'cubic': CatmullRomFilter(),
    'sharpcubic': SharpCubicFilter(),
    'lanczos3': LanczosFilter(radius=3),
    'lanczos5': LanczosFilter(radius=5),
    'lanczos10': LanczosFilter(radius=10),
    'cardinal3': CardinalBsplineFilter(degree=3),
    'cardinal5': CardinalBsplineFilter(degree=5),
    'omoms3': OmomsFilter(degree=3),
    'omoms5': OmomsFilter(degree=5),
    'hamming3': GeneralizedHammingFilter(radius=3, a0=25 / 46),
    'kaiser3': KaiserFilter(radius=3.0, beta=7.12),
    'gaussian': GaussianFilter(),
    'bspline3': BsplineFilter(degree=3),
    'mitchell': MitchellFilter(),
    'narrowbox': NarrowBoxFilter(),
    # Not in FILTERS:
    'hamming1': GeneralizedHammingFilter(radius=1, a0=0.54),
    'hann3': GeneralizedHammingFilter(radius=3, a0=0.5),
    'lanczos4': LanczosFilter(radius=4),
}

# The public names are the keys (in insertion order) that precede the 'hamming1' entry.
FILTERS = list(itertools.takewhile(lambda x: x != 'hamming1', _DICT_FILTERS))
r"""Shortcut names for some predefined filter kernels (specified per dimension).
The names expand to:

| name           | `Filter`                      | a.k.a. / comments |
|----------------|-------------------------------|-------------------|
| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-2, 2], *GF* |
| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-3, 3], *GF* |
| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |

The comment label *GF* denotes a generalized filter, formed
as the composition of a finitely supported kernel and a discrete inverse convolution.

**Some example filter kernels:**

<img src="" width="100%"/>

<br/>A more extensive set of filters is presented [here](#plots_of_filters),
together with visualizations and analyses of the filter properties.
See the source code for extensibility.
"""
def _get_filter(filter: str | Filter, /) -> Filter:
  """Resolve `filter` to a `Filter` object.

  A `Filter` instance is returned unchanged; a string is looked up by name (see `FILTERS`).
  """
  if isinstance(filter, Filter):
    return filter
  return _DICT_FILTERS[filter]
def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Convert `array` to the floating-point `dtype`, with values in [0.0, 1.0].

  uint8/uint16/uint32 values are divided by the maximum value of their type; floating-point
  values are clipped to [0.0, 1.0].  Any other source dtype fails an assertion.
  """
  src_dtype = _arr_dtype(array)
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)
  if src_dtype.type not in (np.uint8, np.uint16, np.uint32):
    assert np.issubdtype(src_dtype, np.floating)
    return _arr_clip(array, 0.0, 1.0, dtype)
  uint_max = np.iinfo(src_dtype).max
  if _arr_arraylib(array) != 'numpy':
    return _arr_astype(array, dtype) / uint_max
  assert isinstance(array, np.ndarray)  # Help mypy.
  # A single ufunc call both scales and produces the requested dtype.
  return np.multiply(array, 1 / uint_max, dtype=dtype)
def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Convert floating-point values in [0.0, 1.0] to `dtype`.

  uint8/uint16/uint32 outputs are scaled by the type's maximum value and rounded (values are
  not clipped here); other (floating-point) dtypes are a plain cast.
  """
  assert np.issubdtype(_arr_dtype(array), np.floating)
  dst_dtype = np.dtype(dtype)
  kind = dst_dtype.type
  if kind in (np.uint8, np.uint16):
    scaled = array * np.float32(np.iinfo(dst_dtype).max) + 0.5
    return typing.cast(_Array, _arr_astype(scaled, dst_dtype))
  if kind == np.uint32:
    # float32 cannot represent 2**32 - 1 exactly, so scale in float64.
    scaled = array * np.float64(np.iinfo(dst_dtype).max) + 0.5
    return typing.cast(_Array, _arr_astype(scaled, dst_dtype))
  assert np.issubdtype(dst_dtype, np.floating)
  return _arr_astype(array, dst_dtype)
@dataclasses.dataclass
class Gamma(abc.ABC):
  """Abstract base class for transfer functions on sample values.

  Image/video content is often stored using a color component transfer function (gamma
  encoding).  Subclasses convert between stored sample values and an internal floating-point
  representation, in which unsigned-integer sample values correspond to the range [0.0, 1.0].

  Subclasses set `name` via `super().__init__(name)` (the constructor generated by
  `dataclasses.dataclass`) and implement `decode` and `encode`.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Decode source sample values into floating-point, possibly nonlinearly.

    Uint source values are mapped to the range [0.0, 1.0].
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Encode float signal into destination samples, possibly nonlinearly.

    Uint destination values are mapped from the range [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
    """
class IdentityGamma(Gamma):
  """Identity component transfer function: sample values are only rescaled and cast."""

  def __init__(self) -> None:
    super().__init__(name='identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    target = np.dtype(dtype)
    assert np.issubdtype(target, np.inexact)
    # Unsigned integers are rescaled to [0.0, 1.0]; anything else is simply cast.
    if not np.issubdtype(_arr_dtype(array), np.unsignedinteger):
      return _arr_astype(array, target)
    return _to_float_01(array, target)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    target = np.dtype(dtype)
    assert np.issubdtype(target, np.number)
    if np.issubdtype(target, np.unsignedinteger):
      return _from_float(_arr_clip(array, 0.0, 1.0), target)
    if not np.issubdtype(target, np.integer):
      return _arr_astype(array, target)
    # Signed integers: add 0.5 before casting, to round.
    # NOTE(review): if the cast truncates toward zero (as NumPy's float->int cast does),
    # negative values are not rounded to nearest; confirm whether this matters.
    return _arr_astype(array + 0.5, target)
class PowerGamma(Gamma):
  """Gamma correction using a power function.

  Args:
    power: Exponent applied when decoding; encoding applies its reciprocal `1 / power`.
  """

  def __init__(self, power: float) -> None:
    super().__init__(name=f'power_{power}')
    self.power = power

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8 and self.power != 2:
      # Decode uint8 through a 256-entry lookup table.  The table must be computed in the
      # requested `dtype`: omitting it in the recursive call would fall back to the float32
      # default and silently downcast float64 results.
      arraylib = _arr_arraylib(array)
      table = self.decode(np.arange(256, dtype=dtype) / 255, dtype)
      decode_table = _make_array(table, arraylib)
      return _arr_getitem(decode_table, array)

    array = _to_float_01(array, dtype)
    return _arr_square(array) if self.power == 2 else array**self.power

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    array = _arr_clip(array, 0.0, 1.0)
    array = _arr_sqrt(array) if self.power == 2 else array ** (1.0 / self.power)
    return _from_float(array, dtype)
class SrgbGamma(Gamma):
  """Gamma correction using the piecewise sRGB transfer function.

  It is linear near zero (factor 12.92) and a 2.4-power curve elsewhere.
  """

  def __init__(self) -> None:
    super().__init__(name='srgb')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8:
      # Decode uint8 through a 256-entry lookup table computed in the requested `dtype`
      # (passing `dtype` avoids the float32 default downcasting float64 results).
      arraylib = _arr_arraylib(array)
      table = self.decode(np.arange(256, dtype=dtype) / 255, dtype)
      decode_table = _make_array(table, arraylib)
      return _arr_getitem(decode_table, array)

    x = _to_float_01(array, dtype)
    return _arr_where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    x = _arr_clip(array, 0.0, 1.0)
    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
    # pytype: disable=wrong-arg-types
    x = _arr_where(x > 0.0031308, x ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17), x * 12.92)
    # pytype: enable=wrong-arg-types
    return _from_float(x, dtype)
_DICT_GAMMAS = {
    'identity': IdentityGamma(),
    'power2': PowerGamma(2.0),
    'power22': PowerGamma(2.2),
    'srgb': SrgbGamma(),
}

GAMMAS = list(_DICT_GAMMAS)
r"""Shortcut names for some predefined gamma-correction schemes:

| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
|---|---|---|---|
| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |

For `'srgb'`, the table shows only the power-law segments; near zero the transfer function is
linear, with $l = e / 12.92$ for $e \le 0.04045$ and $e = 12.92\,l$ for $l \le 0.0031308$.

See the source code for extensibility.
"""
def _get_gamma(gamma: str | Gamma, /) -> Gamma:
  """Resolve `gamma` to a `Gamma` object.

  A `Gamma` instance is returned unchanged; a string is looked up by name (see `GAMMAS`).
  """
  if isinstance(gamma, Gamma):
    return gamma
  return _DICT_GAMMAS[gamma]
def _get_src_dst_gamma(
    gamma: str | Gamma | None,
    src_gamma: str | Gamma | None,
    dst_gamma: str | Gamma | None,
    src_dtype: _DType,
    dst_dtype: _DType,
) -> tuple[Gamma, Gamma]:
  """Resolve the source and destination transfer functions.

  Args:
    gamma: Transfer function for both source and destination; cannot be combined with
      `src_gamma` or `dst_gamma`.
    src_gamma: Transfer function used to decode the source samples.
    dst_gamma: Transfer function used to encode the destination samples.
    src_dtype: Dtype of the source samples; used only to pick a default when no gamma is given.
    dst_dtype: Dtype of the destination samples; used only to pick a default when no gamma is
      given.

  Returns:
    The pair `(src_gamma, dst_gamma)` as `Gamma` objects.  With no gamma specified, the default
    is `'power2'` if both dtypes are unsigned integers and `'identity'` if neither is.

  Raises:
    ValueError: If no gamma is given and exactly one of the dtypes is an unsigned integer, if
      `gamma` is combined with `src_gamma` or `dst_gamma`, or if only one of `src_gamma` and
      `dst_gamma` is given.
  """
  if gamma is None and src_gamma is None and dst_gamma is None:
    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
    if src_uint and dst_uint:
      # The default might ideally be 'srgb' but that conversion is costlier.
      gamma = 'power2'
    elif not src_uint and not dst_uint:
      gamma = 'identity'
    else:
      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
  if gamma is not None:
    if src_gamma is not None:
      raise ValueError('Cannot specify both gamma and src_gamma.')
    if dst_gamma is not None:
      raise ValueError('Cannot specify both gamma and dst_gamma.')
    src_gamma = dst_gamma = gamma
  # Previously an `assert`; a user specifying only one side deserves a clear error.
  if src_gamma is None or dst_gamma is None:
    raise ValueError('Must specify both src_gamma and dst_gamma (or neither).')
  return _get_gamma(src_gamma), _get_gamma(dst_gamma)
def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Compute affine weights for 1D resampling from `src_size` to `dst_size`.

  Compute a sparse matrix in which each row expresses a destination sample value as a combination
  of source sample values depending on the boundary rule.  If the combination is non-affine,
  the remainder (returned as `cval_weight`) is the contribution of the special constant value
  (cval) defined outside the domain.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix whose rows express output sample values as affine combinations of the
      source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.

  Raises:
    ValueError: If `src_size` or `dst_size` is smaller than the minimum size of its gridtype.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  prefilter = filter if prefilter is None else prefilter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  scaling = dst_gridtype.size_in_samples(dst_size) / src_gridtype.size_in_samples(src_size) * scale
  is_minification = scaling < 1.0
  filter = prefilter if is_minification else filter
  if filter.name == 'trapezoid':
    # The generic 'trapezoid' is a placeholder: substitute a trapezoid whose radius gives
    # area-based antialiasing at this particular scaling.
    radius = 0.5 + 0.5 * min(scaling, 1.0 / scaling)
    filter = TrapezoidFilter(radius=radius)
  radius = filter.radius
  # In minification the kernel is stretched by 1 / scaling, so more source samples contribute.
  num_samples = int(np.ceil(radius * 2 / scaling) if is_minification else np.ceil(radius * 2))

  dst_index = np.arange(dst_size, dtype=dtype)
  # Destination sample locations in unit domain [0, 1].
  dst_position = dst_gridtype.point_from_index(dst_index, dst_size)

  src_position = (dst_position - translate) / scale
  src_position = boundary.preprocess_coordinates(src_position)

  # Sample positions mapped back to source unit domain [0, 1].
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  src_first_index = (
      np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
      - (num_samples - 1) // 2
  )

  sample_index = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + sample_index  # (dst_size, num_samples)

  def get_weight_matrix() -> _NDArray:
    if filter.name == 'impulse':
      return np.ones(src_index.shape, dtype)
    if is_minification:
      x = (src_float_index[:, None] - src_index.astype(dtype)) * scaling
      return filter(x) * scaling
    # Either same size or magnification.
    x = src_float_index[:, None] - src_index.astype(dtype)
    return filter(x)

  weight = get_weight_matrix().astype(dtype, copy=False)

  # Normalize rows to sum to one, except for 'narrowbox' (visualization only).
  if filter.name != 'narrowbox' and (is_minification or not filter.partition_of_unity):
    weight = weight / weight.sum(axis=-1)[..., None]

  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)
  shape = dst_size, src_size

  def prepare_sparse_resize_matrix() -> tuple[_NDArray, _NDArray, _NDArray]:
    linearized = (src_index + np.indices(src_index.shape)[0] * src_size).ravel()
    values = weight.ravel()
    # Remove the zero weights.
    nonzero = values != 0.0
    linearized, values = linearized[nonzero], values[nonzero]
    # Sort and merge the duplicate indices: a 0/1 matrix maps each entry to its unique index,
    # so multiplying it by `values` sums the weights of duplicate entries.
    unique, unique_inverse = np.unique(linearized, return_inverse=True)
    data2 = np.ones(len(linearized), np.float32)
    row_ind2 = unique_inverse
    col_ind2 = np.arange(len(linearized))
    shape2 = len(unique), len(linearized)
    csr = scipy.sparse.csr_matrix((data2, (row_ind2, col_ind2)), shape=shape2)
    data = csr @ values  # Merged values.
    row_ind, col_ind = unique // src_size, unique % src_size  # Merged indices.
    return data, row_ind, col_ind

  data, row_ind, col_ind = prepare_sparse_resize_matrix()
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, shape, arraylib)

  uses_cval = boundary.uses_cval or filter.name == 'narrowbox'
  cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib) if uses_cval else None

  return resize_matrix, cval_weight
def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution to the specified dimension of the array.

  Find the array coefficients such that convolution with the (continuous) filter (given
  gridtype and boundary) interpolates the original array values.

  The computation itself is always done in numpy by `_apply_digital_filter_1d_numpy`.  For
  `tensorflow`, `torch`, and `jax` arrays it is wrapped in a custom-gradient function whose
  backward pass applies the transposed filter (`compute_backward=True`).

  Args:
    array: Array (from any library in `ARRAYLIBS`) whose dimension `axis` is filtered.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule along `axis`.
    cval: Constant value used by boundary rules that reference it.
    filter: Reconstruction filter; it must satisfy `filter.requires_digital_filter`.
    axis: Dimension of `array` along which the inverse convolution is applied.

  Returns:
    An array of the same library and shape as `array`.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  def filter_numpy(x: _NDArray, compute_backward: bool) -> _NDArray:
    """Apply the digital filter (or its transpose if `compute_backward`) to a numpy array."""
    return _apply_digital_filter_1d_numpy(
        x, gridtype, boundary, cval, filter, axis, compute_backward
    )

  if arraylib == 'tensorflow':
    import tensorflow as tf

    def forward(x: _NDArray) -> _NDArray:
      return filter_numpy(x, False)

    def backward(grad_output: _NDArray) -> _NDArray:
      return filter_numpy(grad_output, True)

    @tf.custom_gradient  # type: ignore[misc]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # Although `forward` accesses parameters gridtype, boundary, etc., it is not stateful
      # because the function is redefined on each invocation of _apply_digital_filter_1d.
      y = tf.numpy_function(forward, [x], x.dtype, stateful=False)
      # In graph mode, tf.numpy_function() returns a tensor of unknown static shape; the filter
      # preserves the shape of its input, so restore it for downstream reshapes.
      y.set_shape(x.shape)

      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        grad_input = tf.numpy_function(backward, [grad_output], x.dtype, stateful=False)
        grad_input.set_shape(x.shape)  # Same reason as above.
        return grad_input

      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Differentiable wrapper for _apply_digital_filter_1d."""

      @staticmethod
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        return torch.as_tensor(filter_numpy(x.detach().numpy(), False))

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        return torch.as_tensor(filter_numpy(grad_output.detach().numpy(), True))

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp

    # It seems rather difficult to implement this digital filter (inverse convolution) in Jax;
    # jax.scipy.signal sadly omits scipy.signal.filtfilt().
    # Including a (non-traceable) numpy function in Jax requires a host callback and/or defining
    # a new jax.core.Primitive (which allows differentiability).  :-(

    @jax.custom_gradient  # type: ignore[misc]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # This function is not jax-traceable because np.asarray() needs concrete values, so jit
      # and grad fail.
      y = jnp.asarray(filter_numpy(np.asarray(x), False))

      def grad(grad_output: _JaxArray) -> _JaxArray:
        return jnp.asarray(filter_numpy(np.asarray(grad_output), True))

      return y, grad

    return jax_inverse_convolution(array)

  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return filter_numpy(array, False)
2518def _apply_digital_filter_1d_numpy(
2519    array: _NDArray,
2520    gridtype: Gridtype,
2521    boundary: Boundary,
2522    cval: _ArrayLike,
2523    filter: Filter,
2524    axis: int,
2525    compute_backward: bool,
2526    /,
2527) -> _NDArray:
2528  """Version of _apply_digital_filter_1d` specialized to numpy array."""
2529  assert np.issubdtype(array.dtype, np.inexact)
2530  cval = np.asarray(cval).astype(array.dtype, copy=False)
2532  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter:
2533  mode = {
2534      ('reflect', 'dual'): 'reflect',
2535      ('reflect', 'primal'): 'mirror',
2536      ('wrap', 'dual'): 'grid-wrap',
2537      ('wrap', 'primal'): 'wrap',
2538  }.get((,
2539  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
2540  use_split_filter1d = filter_is_compatible and mode
2541  if use_split_filter1d:
2542    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
2543    assert >= 2
2544    # compute_backward=True is same: matrix is symmetric and cval is unused.
2545    return scipy.ndimage.spline_filter1d(
2546        array, axis=axis,, mode=mode, output=array.dtype
2547    )
2549  array_dim = np.moveaxis(array, axis, 0)
2550  l = original_l = math.ceil(filter.radius) - 1
2551  x = np.arange(-l, l + 1, dtype=array.real.dtype)
2552  values = filter(x)
2553  size = array_dim.shape[0]
2554  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
2555  weight = np.full((size, len(values)), values)
2556  src_position = np.broadcast_to(0.5, len(values))
2557  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
2558  if == 'primal' and == 'wrap':
2559    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
2560    # matrix non-singular.
2561    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
2562    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
2563  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
2564  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
2565  # Currently, matrix is always banded unless == 'wrap'.
2567  data = weight.reshape(-1).astype(array.dtype, copy=False)
2568  row_ind = np.arange(size).repeat(src_index.shape[1])
2569  col_ind = src_index.reshape(-1)
2570  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
2571  if compute_backward:
2572    matrix = matrix.T
2574  if boundary.uses_cval and not compute_backward:
2575    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
2576    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
2577      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
2578    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval
2580  if is_banded:
2581    matrix = matrix.todia()
2582    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
2583    l, u = -matrix.offsets[0], matrix.offsets[-1]
2584    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
2585    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
2586    if _is_symmetric(matrix):
2587      array_dim = scipy.linalg.solveh_banded([-1 : l - 1 : -1], array_dim, **options)
2588    else:
2589      array_dim = scipy.linalg.solve_banded((l, u),[::-1], array_dim, **options)
2591  else:
2592    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
2593    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
2594    array_dim = lu.solve(array_dim.reshape(array_dim.shape[0], -1)).reshape(array_dim.shape)
2596  return np.moveaxis(array_dim, 0, axis)
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.  If the operation is an exact identity (same shape, gridtypes, and
      dtype; interpolating filters; unit scale; zero translation; matching gammas), the input
      `array` itself is returned.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src=""/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src=""/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # A same-size resize is not a minification, so it uses `filter2` (not `prefilter2`); the
  # no-op test must agree with the per-dimension skip test below.  The output must also already
  # have the requested `dtype`.
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in filter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors have been requested as a feature for tf.einsum.
  # Dedicated C++ libraries compute tensor algebra expressions on sparse and dense tensors;
  # however they do not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    # A dimension is an identity only if its sample placement is also unchanged; e.g., a 'dual'
    # to 'primal' conversion at the same size moves the sample positions.
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and src_gridtype2[dim] == dst_gridtype2[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Flatten to 2D (this dimension first, all other dims merged) for the sparse matmul.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array
# Keep a module-level reference to the `resize` function object; the wrappers that follow
# (`resize_in_arraylib`, `_resize_possibly_in_arraylib`, `_create_jaxjit_resize`) call it through
# this name.  NOTE(review): presumably this insulates them from any later rebinding of `resize`;
# confirm.
_original_resize = resize
def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the specified array library from `ARRAYLIBS`.

  The numpy input is converted to `arraylib`, resized there, and converted back to numpy.
  Remaining positional and keyword arguments are forwarded unchanged to `resize()`.
  """
  _check_eq(_arr_arraylib(array), 'numpy')
  converted = _make_array(array, arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `numpy` library.

  Shorthand for `resize_in_arraylib(..., arraylib='numpy')`.
  """
  return resize_in_arraylib(array, *args, **kwargs, arraylib='numpy')
def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `tensorflow` library.

  Shorthand for `resize_in_arraylib(..., arraylib='tensorflow')`.
  """
  return resize_in_arraylib(array, *args, **kwargs, arraylib='tensorflow')
def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `torch` library.

  Shorthand for `resize_in_arraylib(..., arraylib='torch')`.
  """
  return resize_in_arraylib(array, *args, **kwargs, arraylib='torch')
def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `jax` library.

  Shorthand for `resize_in_arraylib(..., arraylib='jax')`.
  """
  return resize_in_arraylib(array, *args, **kwargs, arraylib='jax')
def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """If `array` is from numpy, evaluate `resize()` using the array library from `ARRAYLIBS`.

  Non-numpy inputs are resized directly in their own library and `arraylib` is ignored.
  Numpy inputs are converted to `arraylib`, resized, and converted back to numpy.
  """
  if _arr_arraylib(array) != 'numpy':
    return _original_resize(array, *args, **kwargs)
  converted = _make_array(cast(_ArrayLike, array), arraylib)
  return _arr_numpy(_original_resize(converted, *args, **kwargs))
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Lazily invoke `jax.jit` on `resize`; `jax` is imported only when this is called.

  The positional `shape` argument (index 1) and all keyword-only parameters of `resize` that have
  defaults (the keys of its `__kwdefaults__`) are marked static, so their values must be hashable.
  """
  import jax

  static_names = list(_original_resize.__kwdefaults__)
  jitted: Any = jax.jit(_original_resize, static_argnums=(1,), static_argnames=static_names)
  return jitted
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` but with resize function jitted using Jax.

  All arguments are forwarded to the function returned by `_create_jaxjit_resize()`.
  """
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit: contain` (or `object-fit: cover`).
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, or if `shape` is incompatible with `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension ratio of output size to source size, measured in sample spacings.
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # 'contain' takes the smallest ratio (whole source visible); 'cover' the largest (output filled).
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  scale2 = scale0 / np.array(raw_scales)
  translate2 = (1.0 - scale2) / 2  # Center the scaled source domain in the output domain.
  # The gridtypes must be forwarded so that `resize` places samples consistently with the
  # `raw_scales` computed above (otherwise `resize` would fall back to its 'dual' default).
  return resize(
      array,
      shape,
      gridtype=gridtype,
      src_gridtype=src_gridtype,
      dst_gridtype=dst_gridtype,
      boundary=boundary,
      scale=scale2,
      translate=translate2,
      **kwargs,
  )
# Negative sentinel, distinct from any real block size.  NOTE(review): presumably passed as the
# `max_block_size` argument when `resample` re-invokes itself on each partitioned block, so that
# the recursive calls do not partition again; confirm against the recursion site.
_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.
2986def resample(
2987    array: _Array,
2988    /,
2989    coords: _ArrayLike,
2990    *,
2991    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
2992    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
2993    cval: _ArrayLike = 0,
2994    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
2995    prefilter: str | Filter | Iterable[str | Filter] | None = None,
2996    gamma: str | Gamma | None = None,
2997    src_gamma: str | Gamma | None = None,
2998    dst_gamma: str | Gamma | None = None,
2999    jacobian: _ArrayLike | None = None,
3000    precision: _DTypeLike = None,
3001    dtype: _DTypeLike = None,
3002    max_block_size: int = 40_000,
3003    debug: bool = False,
3004) -> _Array:
3005  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.
3007  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
3008  domain grid samples in `array`.
3010  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
3011  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
3012  sample (e.g., scalar, vector, tensor).
3014  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
3015  `array.shape[coords.shape[-1]:]`.
3017  Examples include:
3019  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
3020    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.
3022  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
3023    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.
3025  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.
3027  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.
3029  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
3030    using `coords.shape = height, width, 3`.
3032  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
3033    `coords.shape = height, width`.
3035  Args:
3036    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
3037      The array must have numeric type.  The coordinate dimensions appear first, and
3038      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
3039      a `'dual'` grid or at least 2 for a `'primal'` grid.
3040    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
3041      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
3042      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
3043      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
3044    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
3045      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
3046    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
3047      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
3048      `'reflect'` for upsampling and `'clamp'` for downsampling.
3049    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
3050      onto the shape `array.shape[coords.shape[-1]:]`.
3051    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
3052      a filter name in `FILTERS` or a `Filter` instance.
3053    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
3054      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
3055      (i.e., minification).  If `None`, it inherits the value of `filter`.
3056    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
3057      from `array` and when creating output grid samples.  It is specified as either a name in
3058      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
3059      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
3060      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
3061      range [0.0, 1.0].
3062    src_gamma: Component transfer function used to "decode" `array` samples.
3063      Parameters `gamma` and `src_gamma` cannot both be set.
3064    dst_gamma: Component transfer function used to "encode" the output samples.
3065      Parameters `gamma` and `dst_gamma` cannot both be set.
3066    jacobian: Optional array, which must be broadcastable onto the shape
3067      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
3068      output grid the Jacobian matrix of the map from the unit output domain to the unit source
3069      domain.  If omitted, it is estimated by computing finite differences on `coords`.
3070    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
3071      on `array.dtype`, `coords.dtype`, and `dtype`.
3072    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
3073      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
3074      to the uint range.
3075    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
3076      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
3077    debug: Show internal information.
3079  Returns:
3080    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
3081    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
3082    the source array.
3084  **Example of resample operation:**
3086  <center>
3087  <img src=""/>
3088  </center>
3090  For reference, the identity resampling for a scalar-valued grid with the default grid-type
3091  `'dual'` is:
3093  >>> array = np.random.default_rng(0).random((5, 7, 3))
3094  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
3095  >>> new_array = resample(array, coords)
3096  >>> assert np.allclose(new_array, array)
3098  It is more efficient to use the function `resize` for the special case where the `coords` are
3099  obtained as simple scaling and translation of a new regular grid over the source domain:
3101  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
3102  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
3103  >>> coords = (coords - translate) / scale
3104  >>> resampled = resample(array, coords)
3105  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
3106  >>> assert np.allclose(resampled, resized)
3107  """
3108  if isinstance(array, (tuple, list)):
3109    array = np.asarray(array)
3110  arraylib = _arr_arraylib(array)
3111  if len(array.shape) == 0:
3112    array = array[None]
3113  coords = np.atleast_1d(coords)
3114  if not np.issubdtype(_arr_dtype(array), np.number):
3115    raise ValueError(f'Type {array.dtype} is not numeric.')
3116  if not np.issubdtype(coords.dtype, np.floating):
3117    raise ValueError(f'Type {coords.dtype} is not floating.')
3118  array_ndim = len(array.shape)
3119  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
3120    coords = coords[:, None]
3121  grid_ndim = coords.shape[-1]
3122  grid_shape = array.shape[:grid_ndim]
3123  sample_shape = array.shape[grid_ndim:]
3124  resampled_ndim = coords.ndim - 1
3125  resampled_shape = coords.shape[:-1]
3126  if grid_ndim > array_ndim:
3127    raise ValueError(
3128        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
3129    )
3130  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
3131  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
3132  cval = np.broadcast_to(cval, sample_shape)
3133  prefilter = filter if prefilter is None else prefilter
3134  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
3135  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
3136  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
3137  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
3138  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
3139  if jacobian is not None:
3140    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
3141  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
3142  weight_precision = _real_precision(precision)
3143  coords = coords.astype(weight_precision, copy=False)
3144  is_minification = False  # Current limitation; no prefiltering!
3145  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
3146  for dim in range(grid_ndim):
3147    if boundary2[dim] == 'auto':
3148      boundary2[dim] = 'clamp' if is_minification else 'reflect'
3149    boundary2[dim] = _get_boundary(boundary2[dim])
3151  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
3152    array = src_gamma2.decode(array, precision)
3153    for dim in range(grid_ndim):
3154      assert not is_minification
3155      if filter2[dim].requires_digital_filter:
3156        array = _apply_digital_filter_1d(
3157            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
3158        )
3160  if > max_block_size > 0:
3161    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
3162    if debug:
3163      print(f'(resample: splitting coords into blocks {block_shape}).')
3164    coord_blocks = _split_array_into_blocks(coords, block_shape)
3166    def process_block(coord_block: _NDArray) -> _Array:
3167      return resample(
3168          array,
3169          coord_block,
3170          gridtype=gridtype2,
3171          boundary=boundary2,
3172          cval=cval,
3173          filter=filter2,
3174          prefilter=prefilter2,
3175          src_gamma='identity',
3176          dst_gamma=dst_gamma2,
3177          jacobian=jacobian,
3178          precision=precision,
3179          dtype=dtype,
3180          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
3181      )
3183    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
3184    array = _merge_array_from_blocks(result_blocks)
3185    return array
3187  # A concrete example of upsampling:
3188  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
3189  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
3190  #   resample(array, coords, filter=('cubic', 'lanczos3'))
3191  #   grid_shape = 5, 7  grid_ndim = 2
3192  #   resampled_shape = 8, 9  resampled_ndim = 2
3193  #   sample_shape = (3,)
3194  #   src_float_index.shape = 8, 9
3195  #   src_first_index.shape = 8, 9
3196  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
3197  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
3198  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]
3200  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
3201  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
3202  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
3203  uses_cval = False
3204  all_num_samples = []  # will be [4, 6]
3206  for dim in range(grid_ndim):
3207    src_size = grid_shape[dim]  # scalar
3208    coords_dim = coords[..., dim]  # (8, 9)
3209    radius = filter2[dim].radius  # scalar
3210    num_samples = int(np.ceil(radius * 2))  # scalar
3211    all_num_samples.append(num_samples)
3213    boundary_dim = boundary2[dim]
3214    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)
3216    # Sample positions mapped back to source unit domain [0, 1].
3217    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
3218    src_first_index = (
3219        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
3220        - (num_samples - 1) // 2
3221    )  # (8, 9)
3223    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
3224    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
3225    if filter2[dim].name == 'trapezoid':
3226      # (It might require changing the filter radius at every sample.)
3227      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
3228    if filter2[dim].name == 'impulse':
3229      weight[dim] = np.ones_like(src_index[dim], weight_precision)
3230    else:
3231      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
3232      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
3233      if filter2[dim].name != 'narrowbox' and (
3234          is_minification or not filter2[dim].partition_of_unity
3235      ):
3236        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]
3238    src_index[dim], weight[dim] = boundary_dim.apply(
3239        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
3240    )
3241    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
3242      uses_cval = True
3244  # Gather the samples.
3246  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
3247  src_index_expanded = []
3248  for dim in range(grid_ndim):
3249    src_index_dim = np.moveaxis(
3250        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
3251        resampled_ndim,
3252        resampled_ndim + dim,
3253    )
3254    src_index_expanded.append(src_index_dim)
3255  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
3256  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)
3258  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
3259  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)
3261  # Compute an Einstein summation over the samples and each of the per-dimension weights.
3263  def label(dims: Iterable[int]) -> str:
3264    return ''.join(chr(ord('a') + i) for i in dims)
3266  operands = [samples]  # (8, 9, 4, 6, 3)
3267  assert samples_ndim < 26  # Letters 'a' through 'z'.
3268  labels = [label(range(samples_ndim))]  # ['abcde']
3269  for dim in range(grid_ndim):
3270    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
3271    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
3272  output_label = label(
3273      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
3274  )  # 'abe'
3275  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
3276  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)
3278  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
3279  # computations could be fused.  In Jax, suggests
3280  # that this may become possible.  In any case, for large outputs it helps to partition the
3281  # evaluation over output tiles (using max_block_size).
3283  if uses_cval:
3284    cval_weight = 1.0 - np.multiply.reduce(
3285        [weight[dim].sum(axis=-1) for dim in range(resampled_ndim)]
3286    )  # (8, 9)
3287    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
3288    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)
3290  array = dst_gamma2.encode(array, dtype)
3291  return array
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  Each destination sample position (in the unit domain of a grid with resolution `shape`) is
  mapped into the source unit domain, either linearly,
    `source_point = matrix @ destination_point`,
  or affinely when the last matrix column holds an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.
  The resulting coordinates are then handed to `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its leading `matrix.shape[0]` dimensions form the source
      grid; any remaining dimensions hold per-sample values and are all linearly interpolated.
    shape: Resolution of the destination grid, whose dimensionality may differ from that of the
      source grid.
    matrix: 2D linear or affine transform from unit-domain destination points (with `len(shape)`
      coordinates) to unit-domain source points (with `matrix.shape[0]` coordinates).  When it has
      `len(shape) + 1` columns, the final column is the affine offset (translation).
    gridtype: Sample placement for every dimension of both the source and destination grids,
      given as a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'` when
      `gridtype`, `src_gridtype`, and `dst_gridtype` are all left `None`.
    src_gridtype: Per-dimension sample placement in the source grid.
      Cannot be combined with `gridtype`.
    dst_gridtype: Per-dimension sample placement in the destination grid.
      Cannot be combined with `gridtype`.
    filter: Reconstruction kernel for each of the `matrix.shape[0]` source dimensions, as a name in
      `FILTERS` or a `Filter` instance.
    prefilter: Prefilter kernel for each of the `len(shape)` destination dimensions, as a name in
      `FILTERS` or a `Filter` instance.  It applies during downsampling (minification).  When
      `None`, the value of `filter` is used.
    precision: Inexact precision of intermediate computations.  When `None`, it is derived from
      `array.dtype` and `dtype`.
    dtype: Data type of the output array; `array.dtype` when `None`.  For a uint type, the
      intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
    **kwargs: Further parameters forwarded to `resample`.

  Returns:
    An array of the same class as `array`, representing a grid of resolution `shape` whose values
    are resampled from `array`; its shape is therefore `shape + array.shape[matrix.shape[0]:]`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  dst_ndim = len(shape)
  matrix = np.asarray(matrix)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  # An extra trailing column (beyond one per destination dimension) is the affine offset.
  is_affine = num_columns == dst_ndim + 1
  if not is_affine and num_columns != dst_ndim:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples: one 1D array per destination axis,
  # then expanded into a full grid with 'ij' (row-major) indexing.
  axis_positions = [
      dst_gridtype2[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part (the first `dst_ndim` columns), then move the source-coordinate axis last.
  linear_part = matrix[:, :dst_ndim]
  coords = np.moveaxis(np.tensordot(linear_part, dst_position, 1), 0, -1)
  if is_affine:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )
def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Use the more general `resample` operation for `resize`, as a debug tool.

  The resize mapping between unit domains, `src_point = (dst_point - translate) / scale`, is
  expressed as the affine matrix `[diag(1 / scale) | -translate / scale]` and evaluated with
  `resample_affine`.

  Args:
    array: Source grid samples; a tuple or list is converted with `np.asarray`.
    shape: Resolution of the leading (resized) dimensions of the output grid.
    scale: Per-dimension scale factor(s), broadcast onto `len(shape)` dimensions.
    translate: Per-dimension translation(s), broadcast onto `len(shape)` dimensions.
    filter: Reconstruction filter(s), as names in `FILTERS` or `Filter` instances.
    fallback: If True, delegate to `_original_resize` whenever the request involves minification
      (which `resample` does not yet prefilter) or a `'trapezoid'` filter (which `resample`
      rejects).
    **kwargs: Additional parameters forwarded to `resample_affine` (or `_original_resize`).

  Returns:
    The resized grid.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  scale = np.broadcast_to(scale, len(shape))
  translate = np.broadcast_to(translate, len(shape))
  # TODO: let resample() do prefiltering for proper downsampling.
  has_minification = np.any(np.array(shape) < array.shape[: len(shape)]) or np.any(scale < 1.0)
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape))]
  has_auto_trapezoid = any(f.name == 'trapezoid' for f in filter2)
  if fallback and (has_minification or has_auto_trapezoid):
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)
  offset = -translate / scale
  matrix = np.concatenate([np.diag(1.0 / scale), offset[:, None]], axis=1)
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 homogeneous matrix mapping destination points into the source unit domain.

  The matrix is the composition (applied right to left) of: centering the destination domain at
  the origin, correcting for the destination aspect ratio and applying `scale`, rotating, undoing
  the source aspect ratio, and re-centering at (0.5, 0.5).  Non-square domains are therefore
  handled correctly.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 array whose last row is `[0, 0, 1]`.
  """
  src_shape = np.asarray(src_shape)
  if new_shape is None:
    new_shape = src_shape
  else:
    new_shape = np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag([*factors, 1.0])

  cos_a, sin_a = np.cos(angle), np.sin(angle)
  rotation = np.array([[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]])

  center = np.array([0.5, 0.5])
  to_source_center = homogeneous_translation(center)
  undo_src_aspect = homogeneous_scaling(min(src_shape) / src_shape)
  dst_aspect_and_scale = homogeneous_scaling(scale * new_shape / min(new_shape))
  from_dst_center = homogeneous_translation(-center)

  matrix = to_source_center @ undo_src_aspect @ rotation @ dst_aspect_and_scale @ from_dst_center
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.  After the first rotation the source grid already has
      resolution `new_shape`, so later rotations use a matrix computed for that resolution.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The resampled grid, whose first two dimensions have resolution `new_shape`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  # The first rotation maps the original `image.shape[:2]` grid onto `new_shape`.
  matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
  # Later rotations resample a grid whose resolution is already `new_shape`; its aspect-ratio
  # correction differs from the first matrix whenever `new_shape != image.shape[:2]`.
  later_matrix = rotation_about_center_in_2d(new_shape, angle, new_shape=new_shape, scale=scale)
  for _ in range(num_rotations):
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
    matrix = later_matrix
  return image
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Float-valued grid with 1, 2, or 3 dimensions; a 3D array is resized one channel (last
      dimension) at a time.
    shape: Output resolution; it must have 2 entries for 2D/3D input and 1 entry for 1D input.
    filter: One of `'impulse'`, `'box'`, `'triangle'`, `'hamming1'`, `'cubic'`, `'lanczos3'`.
    boundary: Must be `'natural'`.
    cval: Unused, since only the `'natural'` boundary is supported.

  Returns:
    A numpy array with the same dtype as `array`.

  Raises:
    ValueError: If `boundary` is not `'natural'` or `filter` has no Pillow equivalent.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  # Pillow>=9.0 groups the filter constants in `PIL.Image.Resampling`; older versions define them
  # directly on `PIL.Image`.  Look them up without mutating the third-party module.
  resampling = getattr(PIL.Image, 'Resampling', PIL.Image)
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  # PIL sizes are given as (width, height), i.e. reversed relative to numpy's (ny, nx).
  if array.ndim == 2:
    return np.array(
        PIL.Image.fromarray(array).resize(shape[::-1], resample=pil_resample), array.dtype
    )
  stack = []
  for channel in np.moveaxis(array, -1, 0):
    pil_image = PIL.Image.fromarray(channel).resize(shape[::-1], resample=pil_resample)
    stack.append(np.array(pil_image, array.dtype))
  return np.dstack(stack)
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Resize with OpenCV's `cv.resize`, accepting the same parameters as `resize`.

  Args:
    array: Grid with 1, 2, or 3 dimensions.
    shape: Output resolution; it must have 2 entries for 2D/3D input and 1 entry for 1D input.
    filter: One of `'impulse'`, `'triangle'`, `'trapezoid'`, `'sharpcubic'`, `'lanczos4'`.
    boundary: Must be `'clamp'`.
    cval: Unused, since only the `'clamp'` boundary is supported.

  Raises:
    ValueError: If `boundary` is not `'clamp'` or `filter` has no OpenCV equivalent.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat a 1D signal as a single-row image, then drop that row from the result.
    return cv_resize(array[None], (1, *shape), filter=filter)[0]

  filters = dict(
      impulse=cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      triangle=cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      trapezoid=cv.INTER_AREA,
      sharpcubic=cv.INTER_CUBIC,
      lanczos4=cv.INTER_LANCZOS4,
  )
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')

  width_height = shape[::-1]  # OpenCV expects dsize as (width, height).
  resized = cv.resize(array, width_height, interpolation=filters[filter])
  dropped_channel_dim = array.ndim == 3 and resized.ndim == 2
  if not dropped_channel_dim:
    return resized
  # cv.resize() drops a trailing singleton channel dimension; restore it.
  assert array.shape[2] == 1
  return resized[..., None]
def scipy_ndimage_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _NDArray:
  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`.

  The leading `len(shape)` dimensions of `array` are resampled; any remaining dimensions are
  sampled at their integer positions and so keep their sizes.

  Args:
    array: Source grid; converted with `np.asarray`.  Integer dtypes are supported.
    shape: Output resolution of the leading (spatial) dimensions.
    filter: One of `'box'`, `'triangle'`, or `'cardinal2'` through `'cardinal5'`, which select
      spline orders 0 through 5.
    boundary: One of `'reflect'`, `'wrap'`, `'clamp'`, or `'border'`, mapped to the
      `map_coordinates` modes `'reflect'`, `'grid-wrap'`, `'nearest'`, and `'constant'`.
    cval: Passed to `map_coordinates` (relevant for `boundary='border'`).
    scale: Scale factor(s) of the map from the source to the destination unit domain.
    translate: Offset(s) of the map from the source to the destination unit domain.

  Returns:
    An array of shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]
  shape_all = shape + array.shape[len(shape) :]
  # The sampling coordinates are fractional (and may be negative), so they need a floating dtype
  # even when `array` holds integers; an integer dtype would truncate or wrap them.
  coord_dtype = array.dtype if np.issubdtype(array.dtype, np.floating) else np.dtype(np.float64)
  coords = np.moveaxis(np.indices(shape_all, coord_dtype), 0, -1)
  # Destination sample centers -> destination unit domain -> source unit domain (undoing
  # translate/scale) -> continuous source index (sample k located at index k).
  coords[..., : len(shape)] = (
      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
  coords = np.moveaxis(coords, -1, 0)
  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Resize with `skimage.transform.resize`, accepting the same parameters as `resize`.

  Only the leading `len(shape)` dimensions are resized; trailing dimensions keep their sizes.
  `filter` selects the spline order (`'box'` -> 0, `'triangle'` -> 1, `'cardinalK'` -> K for K
  in 2..5), and `boundary` selects the scikit-image padding mode.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim

  filters = {'box': 0, 'triangle': 1}
  filters.update({f'cardinal{k}': k for k in range(2, 6)})
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')

  boundaries = dict(reflect='symmetric', wrap='wrap', clamp='edge', border='constant')
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')

  output_shape = shape + array.shape[len(shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      array,
      output_shape,
      order=filters[filter],
      mode=boundaries[boundary],
      cval=cval,
      clip=False,
  )  # type: ignore[no-untyped-call]
3661    'impulse': 'nearest',
3662    'trapezoid': 'area',
3663    'triangle': 'bilinear',
3664    'mitchell': 'mitchellcubic',
3665    'cubic': 'bicubic',
3666    'lanczos3': 'lanczos3',
3667    'lanczos5': 'lanczos5',
3668    # GaussianFilter(0.5): 'gaussian',  # radius_4 > desired_radius_3.
3672def tf_image_resize(
3673    array: _ArrayLike,
3674    /,
3675    shape: Iterable[int],
3676    *,
3677    filter: str,
3678    boundary: str = 'natural',
3679    cval: float = 0.0,
3680    antialias: bool = True,
3681) -> _TensorflowTensor:
3682  """Invoke `tf.image.resize` using the same parameters as `resize`."""
3683  import tensorflow as tf
3686    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
3687  if boundary != 'natural':
3688    raise ValueError(f"{boundary=} must equal 'natural'.")
3689  del cval
3690  array2 = tf.convert_to_tensor(array)
3691  ndim = len(array2.shape)
3692  del array
3693  assert 1 <= ndim <= 3
3694  shape = tuple(shape)
3695  _check_eq(len(shape), 2 if ndim >= 2 else 1)
3696  match ndim:
3697    case 1:
3698      return tf_image_resize(array2[None], (1, *shape), filter=filter, antialias=antialias)[0]
3699    case 2:
3700      return tf_image_resize(array2[..., None], shape, filter=filter, antialias=antialias)[..., 0]
3701    case _:
3703      return tf.image.resize(array2, shape, method=method, antialias=antialias)
3707    'impulse': 'nearest-exact',  # ('nearest' matches buggy OpenCV's INTER_NEAREST)
3708    'trapezoid': 'area',
3709    'triangle': 'bilinear',
3710    'sharpcubic': 'bicubic',
3714def torch_nn_resize(
3715    array: _ArrayLike,
3716    /,
3717    shape: Iterable[int],
3718    *,
3719    filter: str,
3720    boundary: str = 'clamp',
3721    cval: float = 0.0,
3722    antialias: bool = False,
3723) -> _TorchTensor:
3724  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`."""
3725  import torch
3728    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
3729  if boundary != 'clamp':
3730    raise ValueError(f"{boundary=} must equal 'clamp'.")
3731  del cval
3732  a = torch.as_tensor(array)
3733  del array
3734  assert 1 <= a.ndim <= 3
3735  shape = tuple(shape)
3736  _check_eq(len(shape), 2 if a.ndim >= 2 else 1)
3739  def local_resize(a: _TorchTensor) -> _TorchTensor:
3740    # For upsampling, BILINEAR antialias is same PSNR and slower,
3741    #  and BICUBIC antialias is worse PSNR and faster.
3742    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
3743    # Default align_corners=None corresponds to False which is what we desire.
3744    return torch.nn.functional.interpolate(a, shape, mode=mode, antialias=antialias)
3746  match a.ndim:
3747    case 1:
3748      shape = (1, *shape)
3749      return local_resize(a[None, None, None])[0, 0, 0]
3750    case 2:
3751      return local_resize(a[None, None])[0, 0]
3752    case _:
3753      return local_resize(a.moveaxis(2, 0)[None])[0].moveaxis(0, 2)
def jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`.

  Args:
    array: Source grid; converted with `jnp.asarray`.  Its leading `len(shape)` dimensions are
      resized and the remaining dimensions keep their sizes.
    shape: Output resolution of the leading (spatial) dimensions.
    filter: One of `'triangle'`, `'cubic'`, `'lanczos3'`, `'lanczos5'`.
    boundary: Must be `'natural'`.
    cval: Must be zero whenever `scale` or `translate` is non-trivial.
    scale: Scale factor(s) of the map from the source to the destination unit domain; a scalar or
      one value per spatial dimension.
    translate: Offset(s) of the map from the source to the destination unit domain; a scalar or
      one value per spatial dimension.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported, or if `cval` is nonzero while some
      `scale` entry differs from 1 or some `translate` entry differs from 0.
  """
  import jax.image
  import jax.numpy as jnp

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.
  # The comparisons are element-wise so that per-dimension sequences/arrays are handled correctly.
  if np.any(np.asarray(scale) != 1.0) and cval != 0.0:
    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
  if np.any(np.asarray(translate) != 0.0) and cval != 0.0:
    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  assert len(shape) <= array2.ndim
  # The output shape covers all dimensions; non-spatial dimensions keep their input sizes.
  completed_shape = shape + tuple(array2.shape[len(shape) :])
  spatial_dims = list(range(len(shape)))
  # Convert unit-domain scale/translate into the pixel units used by `scale_and_translate`.
  scale2 = np.broadcast_to(np.array(scale), len(shape))
  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
  translate2 = np.broadcast_to(np.array(translate), len(shape))
  translate2 = translate2 * np.array(shape)
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )
3798    'resampler.resize': resize,
3799    'PIL.Image.resize': pil_image_resize,
3800    'cv.resize': cv_resize,
3801    'scipy.ndimage.map_coordinates': scipy_ndimage_resize,
3802    'skimage.transform.resize': skimage_transform_resize,
3803    'tf.image.resize': tf_image_resize,
3804    'torch.nn.functional.interpolate': torch_nn_resize,
3805    'jax.image.scale_and_translate': jax_image_resize,
3809def _resizer_is_available(library_function: str) -> bool:
3810  """Return whether the resizer is available as an installed package."""
3811  top_name = library_function.split('.', 1)[0]
3812  module = {'PIL': 'Pillow', 'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
3813  return importlib.util.find_spec(module) is not None
3816_RESIZERS = {
3817    library_function: resizer
3818    for library_function, resizer in _CANDIDATE_RESIZERS.items()
3819    if _resizer_is_available(library_function)
3823def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
3824  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`."""
3825  match filter:
3826    case 'box_like':
3827      return {
3828          cv_resize: 'trapezoid',
3829          skimage_transform_resize: 'box',
3830          tf_image_resize: 'trapezoid',
3831          torch_nn_resize: 'trapezoid',
3832      }.get(resizer, 'box')
3833    case 'cubic_like':
3834      return {
3835          cv_resize: 'sharpcubic',
3836          scipy_ndimage_resize: 'cardinal3',
3837          skimage_transform_resize: 'cardinal3',
3838          torch_nn_resize: 'sharpcubic',
3839      }.get(resizer, 'cubic')
3840    case 'high_quality':
3841      return {
3842          pil_image_resize: 'lanczos3',
3843          cv_resize: 'lanczos4',
3844          scipy_ndimage_resize: 'cardinal5',
3845          skimage_transform_resize: 'cardinal5',
3846          torch_nn_resize: 'sharpcubic',
3847      }.get(resizer, 'lanczos5')
3848    case _:
3849      return filter
3852# For Emacs:
3853# Local Variables:
3854# fill-column: 100
3855# End:
ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the `array` parameter passed to the function.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.

