resampler

resampler: fast differentiable resizing and warping of arbitrary grids.

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   2
   3.. include:: ../README.md
   4"""
   5
   6__docformat__ = 'google'
   7__version__ = '1.0.1'
   8__version_info__ = tuple(int(num) for num in __version__.split('.'))
   9
  10from collections.abc import Callable, Iterable, Sequence
  11import abc
  12import dataclasses
  13import functools
  14import importlib
  15import itertools
  16import math
  17import os
  18import sys
  19import types
  20import typing
  21from typing import Any, Generic, Literal, TypeAlias, Union
  22
  23import numpy as np
  24import numpy.typing
  25import scipy.interpolate
  26import scipy.linalg
  27import scipy.ndimage
  28import scipy.sparse
  29import scipy.sparse.linalg
  30
  31
  32def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  33  """Return function decorated with no-operation; invocable with or without args."""
  34  if len(args) != 1 or not callable(args[0]) or kwargs:
  35    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  36  func: Callable[..., Any] = args[0]
  37  return func
  38
  39
  40try:
  41  import numba
  42except ModuleNotFoundError:
  43  numba = sys.modules['numba'] = types.ModuleType('numba')
  44  numba.njit = _noop_decorator  # type: ignore[attr-defined]
  45using_numba = hasattr(numba, 'jit')
  46
  47if TYPE_CHECKING:
  48  import jax.numpy
  49  import tensorflow as tf
  50  import torch
  51
  52  _DType: TypeAlias = np.dtype[Any]  # (Requires Python 3.9 or TYPE_CHECKING.)
  53  _NDArray: TypeAlias = numpy.NDArray[Any]
  54  _DTypeLike: TypeAlias = numpy.DTypeLike
  55  _ArrayLike: TypeAlias = numpy.ArrayLike
  56  _TensorflowTensor: TypeAlias = tf.Tensor
  57  _TorchTensor: TypeAlias = torch.Tensor
  58  _JaxArray: TypeAlias = jax.numpy.ndarray
  59
  60else:
  61  # Typically, create named types for use in the `pdoc` documentation.
  62  # But here, these are superseded by the declarations in __init__.pyi!
  63  _DType: TypeAlias = Any
  64  _NDArray: TypeAlias = Any
  65  _DTypeLike: TypeAlias = Any
  66  _ArrayLike: TypeAlias = Any
  67  _TensorflowTensor: TypeAlias = Any
  68  _TorchTensor: TypeAlias = Any
  69  _JaxArray: TypeAlias = Any
  70
  71_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  72_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  73
  74
  75def _check_eq(a: Any, b: Any, /) -> None:
  76  """If the two values or arrays are not equal, raise an exception with a useful message."""
  77  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  78  if not are_equal:
  79    raise AssertionError(f'{a!r} == {b!r}')
  80
  81
  82def _real_precision(dtype: _DTypeLike, /) -> _DType:
  83  """Return the type of the real part of a complex number."""
  84  return np.array([], dtype).real.dtype
  85
  86
  87def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  88  """Return a complex type to represent a non-complex type."""
  89  return np.result_type(dtype, np.complex64)
  90
  91
  92def _get_precision(
  93    precision: _DTypeLike, dtypes: list[_DType], weight_dtypes: list[_DType], /
  94) -> _DType:
  95  """Return dtype based on desired precision or on data and weight types."""
  96  precision2 = np.dtype(
  97      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
  98  )
  99  if not np.issubdtype(precision2, np.inexact):
 100    raise ValueError(f'Precision {precision2} is not floating or complex.')
 101  check_complex = [precision2, *dtypes]
 102  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 103  if len(set(is_complex)) != 1:
 104    s_types = ','.join(str(dtype) for dtype in check_complex)
 105    raise ValueError(f'Types {s_types} must be all real or all complex.')
 106  return precision2
 107
 108
 109def _sinc(x: _ArrayLike, /) -> _NDArray:
 110  """Return the value `np.sinc(x)` but improved to:
 111  (1) ignore underflow that occurs at 0.0 for np.float32, and
 112  (2) output exact zero for integer input values.
 113
 114  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 115  array([0., 0., 0., 1.], dtype=float32)
 116
 117  >>> _sinc(np.array([-3, -2, -1, 0]))
 118  array([0., 0., 0., 1.])
 119
 120  >>> _sinc(0)
 121  1.0
 122  """
 123  x = np.asarray(x)
 124  x_is_scalar = x.ndim == 0
 125  with np.errstate(under='ignore'):
 126    result = np.sinc(np.atleast_1d(x))
 127    result[x == np.floor(x)] = 0.0
 128    result[x == 0] = 1.0
 129    return result.item() if x_is_scalar else result
 130
 131
 132def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 133  """Return True if the sparse matrix is symmetric."""
 134  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 135  return norm <= tol
 136
 137
 138def _cache_sampled_1d_function(
 139    xmin: float,
 140    xmax: float,
 141    *,
 142    num_samples: int = 3_600,
 143    enable: bool = True,
 144) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 145  """Function decorator to linearly interpolate cached function values."""
 146  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 147
 148  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 149    if not enable:
 150      return func
 151
 152    dx = (xmax - xmin) / num_samples
 153    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 154    samples_func = func(x)
 155    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 156
 157    @functools.wraps(func)
 158    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 159      x = np.asarray(x)
 160      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 161      index = index_float.astype(np.int64)
 162      frac = np.subtract(index_float, index, dtype=np.float32)
 163      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 164
 165    return interpolate_using_cached_samples
 166
 167  return wrap_it
 168
 169
 170class _DownsampleIn2dUsingBoxFilter:
 171  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 172
 173  def __init__(self) -> None:
 174    # Downsampling function for params (dtype, block_height, block_width, ch).
 175    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 176
 177  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 178    assert using_numba
 179    assert array.ndim in (2, 3), array.ndim
 180    _check_eq(len(shape), 2)
 181    dtype = array.dtype
 182    a = array[..., None] if array.ndim == 2 else array
 183    height, width, ch = a.shape
 184    new_height, new_width = shape
 185    if height % new_height != 0 or width % new_width != 0:
 186      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 187    block_height, block_width = height // new_height, width // new_width
 188
 189    def func(array: _NDArray) -> _NDArray:
 190      new_height = array.shape[0] // block_height
 191      new_width = array.shape[1] // block_width
 192      result = np.empty((new_height, new_width, ch), dtype)
 193      totals = np.empty(ch, dtype)
 194      factor = dtype.type(1.0 / (block_height * block_width))
 195      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 196        for x in range(new_width):
 197          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 198          if ch == 1:  # All the branches involve compile-time constants.
 199            total = dtype.type(0.0)
 200            for yy in range(block_height):
 201              for xx in range(block_width):
 202                total += array[y * block_height + yy, x * block_width + xx, 0]
 203            result[y, x, 0] = total * factor
 204          elif ch == 3:
 205            total0 = total1 = total2 = dtype.type(0.0)
 206            for yy in range(block_height):
 207              for xx in range(block_width):
 208                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 209                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 210                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 211            result[y, x, 0] = total0 * factor
 212            result[y, x, 1] = total1 * factor
 213            result[y, x, 2] = total2 * factor
 214          elif block_height * block_width >= 9:
 215            for c in range(ch):
 216              totals[c] = 0.0
 217            for yy in range(block_height):
 218              for xx in range(block_width):
 219                for c in range(ch):
 220                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 221            for c in range(ch):
 222              result[y, x, c] = totals[c] * factor
 223          else:
 224            for c in range(ch):
 225              total = dtype.type(0.0)
 226              for yy in range(block_height):
 227                for xx in range(block_width):
 228                  total += array[y * block_height + yy, x * block_width + xx, c]
 229              result[y, x, c] = total * factor
 230      return result
 231
 232    signature = dtype, block_height, block_width, ch
 233    jitted_function = self._jitted_function.get(signature)
 234    if not jitted_function:
 235      if 0:
 236        print(f'Creating numba jit-wrapper for {signature}.')
 237      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 238      self._jitted_function[signature] = jitted_function
 239
 240    try:
 241      result = jitted_function(a)
 242    except RuntimeError:
 243      message = (
 244          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 245          ' try deleting that directory.'
 246      )
 247      print(message, file=sys.stdout, flush=True)
 248      print(message, file=sys.stderr, flush=True)
 249      raise
 250
 251    return result[..., 0] if array.ndim == 2 else result
 252
 253
 254_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 255
 256
 257@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[untyped-decorator]
 258def _numba_serial_csr_dense_mult(
 259    indptr: _NDArray,
 260    indices: _NDArray,
 261    data: _NDArray,
 262    src: _NDArray,
 263    dst: _NDArray,
 264) -> None:
 265  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 266
 267  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 268  """
 269  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 270  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 271  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 272  for i in range(dst.shape[0]):
 273    acc[:] = 0
 274    for jj in range(indptr[i], indptr[i + 1]):
 275      j = indices[jj]
 276      acc += data[jj] * src[j]
 277    dst[i] = acc
 278
 279
 280# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
 281#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
 282#                setitem=False, stencil=False, fusion=False)
 283@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[untyped-decorator]
 284def _numba_parallel_csr_dense_mult(
 285    indptr: _NDArray,
 286    indices: _NDArray,
 287    data: _NDArray,
 288    src: _NDArray,
 289    dst: _NDArray,
 290) -> None:
 291  """Faster version of scipy.sparse._sparsetools.csr_matvecs().
 292
 293  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
 294  The introduction of parallel omp threads provides another 2-4x speedup.
 295  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.
 296  """
 297  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
 298  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
 299  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
 300  # Default is static scheduling, which is fine.
 301  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
 302    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
 303    for jj in range(indptr[i], indptr[i + 1]):
 304      j = indices[jj]
 305      acc += data[jj] * src[j]
 306    dst[i] = acc
 307
 308
 309@dataclasses.dataclass
 310class _Arraylib(abc.ABC, Generic[_Array]):
 311  """Abstract base class for abstraction of array libraries."""
 312
 313  arraylib: str
 314  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""
 315
 316  array: _Array
 317
 318  @staticmethod
 319  @abc.abstractmethod
 320  def recognize(array: Any) -> bool:
 321    """Return True if `array` is recognized by this _Arraylib."""
 322
 323  @abc.abstractmethod
 324  def numpy(self) -> _NDArray:
 325    """Return a `numpy` version of `self.array`."""
 326
 327  @abc.abstractmethod
 328  def dtype(self) -> _DType:
 329    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""
 330
 331  @abc.abstractmethod
 332  def astype(self, dtype: _DTypeLike) -> _Array:
 333    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""
 334
 335  def reshape(self, shape: tuple[int, ...]) -> _Array:
 336    """Return the equivalent of `self.array.reshape(shape)`."""
 337    return self.array.reshape(shape)
 338
 339  def possibly_make_contiguous(self) -> _Array:
 340    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous."""
 341    return self.array
 342
 343  @abc.abstractmethod
 344  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _Array:
 345    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""
 346
 347  @abc.abstractmethod
 348  def square(self) -> _Array:
 349    """Return the equivalent of `np.square(self.array)`."""
 350
 351  @abc.abstractmethod
 352  def sqrt(self) -> _Array:
 353    """Return the equivalent of `np.sqrt(self.array)`."""
 354
 355  def getitem(self, indices: Any) -> _Array:
 356    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
 357    return self.array[indices]
 358
 359  @abc.abstractmethod
 360  def where(self, if_true: Any, if_false: Any) -> _Array:
 361    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""
 362
 363  @abc.abstractmethod
 364  def transpose(self, axes: Sequence[int]) -> _Array:
 365    """Return the equivalent of `np.transpose(self.array, axes)`."""
 366
 367  @abc.abstractmethod
 368  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 369    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""
 370
 371  @abc.abstractmethod
 372  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
 373    """Return the multiplication of the `sparse` matrix and `self.array`."""
 374
 375  @staticmethod
 376  @abc.abstractmethod
 377  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
 378    """Return the equivalent of `np.concatenate(arrays, axis)`."""
 379
 380  @staticmethod
 381  @abc.abstractmethod
 382  def einsum(subscripts: str, *operands: _Array) -> _Array:
 383    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 384
 385  @staticmethod
 386  @abc.abstractmethod
 387  def make_sparse_matrix(
 388      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 389  ) -> _Array:
 390    """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 391    However, the indices must be ordered and unique."""
 392
 393
 394class _NumpyArraylib(_Arraylib[_NDArray]):
 395  """Numpy implementation of the array abstraction."""
 396
 397  # pylint: disable=missing-function-docstring
 398
 399  def __init__(self, array: _NDArray) -> None:
 400    super().__init__(arraylib='numpy', array=np.asarray(array))
 401
 402  @staticmethod
 403  def recognize(array: Any) -> bool:
 404    return isinstance(array, (np.ndarray, np.number))
 405
 406  def numpy(self) -> _NDArray:
 407    return self.array
 408
 409  def dtype(self) -> _DType:
 410    dtype: _DType = self.array.dtype
 411    return dtype
 412
 413  def astype(self, dtype: _DTypeLike) -> _NDArray:
 414    return self.array.astype(dtype, copy=False)
 415
 416  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _NDArray:
 417    return self.array.clip(low, high, dtype=dtype)
 418
 419  def square(self) -> _NDArray:
 420    return np.square(self.array)
 421
 422  def sqrt(self) -> _NDArray:
 423    return np.sqrt(self.array)
 424
 425  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 426    condition = self.array
 427    return np.where(condition, if_true, if_false)
 428
 429  def transpose(self, axes: Sequence[int]) -> _NDArray:
 430    return np.transpose(self.array, tuple(axes))
 431
 432  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 433    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 434    # and (2) timings show preference to resizing dimensions with larger strides first.
 435    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 436    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 437    # np.einsum() does not support the sparse multiplications that we require here.)
 438    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 439    strides: Sequence[int] = self.array.strides
 440    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 441
 442    def priority(dim: int) -> float:
 443      scaling = dst_shape[dim] / src_shape[dim]
 444      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 445
 446    return sorted(range(len(src_shape)), key=priority)
 447
 448  def premult_with_sparse(
 449      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 450  ) -> _NDArray:
 451    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 452    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 453    if using_numba:
 454      num_threads2 = min(6, os.cpu_count() or 1) if num_threads == 'auto' else num_threads
 455      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 456      dtype = np.result_type(sparse.dtype, src.dtype)
 457      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 458      num_scalar_multiplies = len(sparse.data) * src.shape[1]
 459      is_small_size = num_scalar_multiplies < 200_000
 460      if is_small_size or num_threads2 == 1:
 461        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 462      else:
 463        numba.set_num_threads(num_threads2)
 464        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 465      return dst
 466
 467    # Note that sicpy.sparse does not use multithreading.  The "@" operation
 468    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 469    # scipy.sparse._sparsetools.csr_matvecs() in
 470    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/csr.h
 471    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 472    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/dense.h
 473    return sparse @ self.array
 474
 475  @staticmethod
 476  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 477    return np.concatenate(arrays, axis)
 478
 479  @staticmethod
 480  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 481    return np.einsum(subscripts, *operands, optimize=True)
 482
 483  @staticmethod
 484  def make_sparse_matrix(
 485      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 486  ) -> _NDArray:
 487    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 488
 489
 490class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
 491  """Tensorflow implementation of the array abstraction."""
 492
 493  def __init__(self, array: _NDArray) -> None:
 494    import tensorflow
 495
 496    self.tf = tensorflow
 497    super().__init__(arraylib='tensorflow', array=self.tf.convert_to_tensor(array))
 498
 499  @staticmethod
 500  def recognize(array: Any) -> bool:
 501    # Eager: tensorflow.python.framework.ops.Tensor
 502    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
 503    return type(array).__module__.startswith('tensorflow.')
 504
 505  def numpy(self) -> _NDArray:
 506    return self.array.numpy()
 507
 508  def dtype(self) -> _DType:
 509    return np.dtype(self.array.dtype.as_numpy_dtype)
 510
 511  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
 512    return self.tf.cast(self.array, dtype)
 513
 514  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
 515    return self.tf.reshape(self.array, shape)
 516
 517  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TensorflowTensor:
 518    array = self.array
 519    if dtype is not None:
 520      array = self.tf.cast(array, dtype)
 521    return self.tf.clip_by_value(array, low, high)
 522
 523  def square(self) -> _TensorflowTensor:
 524    return self.tf.square(self.array)
 525
 526  def sqrt(self) -> _TensorflowTensor:
 527    return self.tf.sqrt(self.array)
 528
 529  def getitem(self, indices: Any) -> _TensorflowTensor:
 530    if isinstance(indices, tuple):
 531      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
 532      if not basic:
 533        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
 534        assert all(isinstance(a, np.ndarray) for a in indices)
 535        assert all(a.ndim == indices[0].ndim for a in indices)
 536        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
 537        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
 538        return self.tf.gather_nd(self.array, indices_array)
 539    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
 540      indices = self.tf.cast(indices, np.int32)
 541    return self.tf.gather(self.array, indices)
 542
 543  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
 544    condition = self.array
 545    return self.tf.where(condition, if_true, if_false)
 546
 547  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
 548    return self.tf.transpose(self.array, tuple(axes))
 549
 550  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 551    # Note that a tensorflow.Tensor does not have strides.
 552    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
 553    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 554    dims = list(range(len(src_shape)))
 555    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 556      dims[:2] = [1, 0]
 557    return dims
 558
 559  def premult_with_sparse(
 560      self, sparse: 'tf.sparse.SparseTensor', num_threads: int | Literal['auto']
 561  ) -> _TensorflowTensor:
 562    import tensorflow as tf
 563
 564    del num_threads
 565    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 566      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
 567    return tf.sparse.sparse_dense_matmul(sparse, self.array)
 568
 569  @staticmethod
 570  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
 571    import tensorflow as tf
 572
 573    return tf.concat(arrays, axis)
 574
 575  @staticmethod
 576  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
 577    import tensorflow as tf
 578
 579    return tf.einsum(subscripts, *operands, optimize='greedy')
 580
 581  @staticmethod
 582  def make_sparse_matrix(
 583      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 584  ) -> _TensorflowTensor:
 585    import tensorflow as tf
 586
 587    indices = np.vstack((row_ind, col_ind)).T
 588    return tf.sparse.SparseTensor(indices, data, shape)
 589
 590
 591class _TorchArraylib(_Arraylib[_TorchTensor]):
 592  """Torch implementation of the array abstraction."""
 593
 594  # pylint: disable=missing-function-docstring
 595
 596  def __init__(self, array: _NDArray) -> None:
 597    import torch
 598
 599    self.torch = torch
 600    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))
 601
 602  @staticmethod
 603  def recognize(array: Any) -> bool:
 604    return type(array).__module__ == 'torch'
 605
 606  def numpy(self) -> _NDArray:
 607    return self.array.numpy()
 608
 609  def dtype(self) -> _DType:
 610    numpy_type = {
 611        self.torch.float32: np.float32,
 612        self.torch.float64: np.float64,
 613        self.torch.complex64: np.complex64,
 614        self.torch.complex128: np.complex128,
 615        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
 616        self.torch.int16: np.int16,
 617        self.torch.int32: np.int32,
 618        self.torch.int64: np.int64,
 619    }[self.array.dtype]
 620    return np.dtype(numpy_type)
 621
 622  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
 623    torch_type = {
 624        np.float32: self.torch.float32,
 625        np.float64: self.torch.float64,
 626        np.complex64: self.torch.complex64,
 627        np.complex128: self.torch.complex128,
 628        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
 629        np.int16: self.torch.int16,
 630        np.int32: self.torch.int32,
 631        np.int64: self.torch.int64,
 632    }[np.dtype(dtype).type]
 633    return self.array.type(torch_type)
 634
 635  def possibly_make_contiguous(self) -> _TorchTensor:
 636    return self.array.contiguous()
 637
 638  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TorchTensor:
 639    array = self.array
 640    array = _arr_astype(array, dtype) if dtype is not None else array
 641    return array.clip(low, high)
 642
 643  def square(self) -> _TorchTensor:
 644    return self.array.square()
 645
 646  def sqrt(self) -> _TorchTensor:
 647    return self.array.sqrt()
 648
 649  def getitem(self, indices: Any) -> _TorchTensor:
 650    if not isinstance(indices, tuple):
 651      indices = indices.type(self.torch.int64)
 652    return self.array[indices]  # pylint: disable=unsubscriptable-object
 653
 654  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
 655    condition = self.array
 656    return if_true.where(condition, if_false)
 657
 658  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
 659    return self.torch.permute(self.array, tuple(axes))
 660
 661  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 662    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
 663    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 664    strides: Sequence[int] = self.array.stride()
 665    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 666
 667    def priority(dim: int) -> float:
 668      scaling = dst_shape[dim] / src_shape[dim]
 669      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 670
 671    return sorted(range(len(src_shape)), key=priority)
 672
 673  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
 674    del num_threads
 675    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 676      sparse = _arr_astype(sparse, _arr_dtype(self.array))
 677    return sparse @ self.array  # Calls torch.sparse.mm().
 678
 679  @staticmethod
 680  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
 681    import torch
 682
 683    return torch.cat(tuple(arrays), axis)
 684
 685  @staticmethod
 686  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
 687    import torch
 688
 689    operands = tuple(torch.as_tensor(operand) for operand in operands)
 690    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
 691      operands = tuple(
 692          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
 693      )
 694    return torch.einsum(subscripts, *operands)
 695
 696  @staticmethod
 697  def make_sparse_matrix(
 698      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 699  ) -> _TorchTensor:
 700    import torch
 701
 702    indices = np.vstack((row_ind, col_ind))
 703    return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
 704    # .coalesce() is unnecessary because indices/data are already merged.
 705
 706
 707class _JaxArraylib(_Arraylib[_JaxArray]):
 708  """Jax implementation of the array abstraction."""
 709
 710  def __init__(self, array: _NDArray) -> None:
 711    import jax.numpy
 712
 713    self.jnp = jax.numpy
 714    super().__init__(arraylib='jax', array=self.jnp.asarray(array))
 715
 716  @staticmethod
 717  def recognize(array: Any) -> bool:
 718    # e.g., jaxlib.xla_extension.DeviceArray, jax.interpreters.ad.JVPTracer
 719    return type(array).__module__.startswith(('jaxlib.', 'jax.'))
 720
 721  def numpy(self) -> _NDArray:
 722    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
 723    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
 724    # np.array(array) always returns a writable array but the copy may be more costly.
 725    # return self.array.to_py()
 726    return np.asarray(self.array)
 727
 728  def dtype(self) -> _DType:
 729    return np.dtype(self.array.dtype)
 730
 731  def astype(self, dtype: _DTypeLike) -> _JaxArray:
 732    return self.array.astype(dtype)  # (copy=False is unavailable)
 733
 734  def possibly_make_contiguous(self) -> _JaxArray:
 735    return self.array.copy()
 736
 737  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _JaxArray:
 738    array = self.array
 739    if dtype is not None:
 740      array = array.astype(dtype)  # (copy=False is unavailable)
 741    return self.jnp.clip(array, low, high)
 742
 743  def square(self) -> _JaxArray:
 744    return self.jnp.square(self.array)
 745
 746  def sqrt(self) -> _JaxArray:
 747    return self.jnp.sqrt(self.array)
 748
 749  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
 750    condition = self.array
 751    return self.jnp.where(condition, if_true, if_false)
 752
 753  def transpose(self, axes: Sequence[int]) -> _JaxArray:
 754    return self.jnp.transpose(self.array, tuple(axes))
 755
 756  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 757    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
 758    # https://github.com/google/jax/discussions/7544#discussioncomment-1197038.
 759    # We use a heuristic similar to `_TensorflowArraylib`.
 760    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 761    dims = list(range(len(src_shape)))
 762    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 763      dims[:2] = [1, 0]
 764    return dims
 765
 766  def premult_with_sparse(
 767      self, sparse: 'jax.experimental.sparse.BCOO', num_threads: int | Literal['auto']
 768  ) -> _JaxArray:
 769    del num_threads
 770    return sparse @ self.array  # Calls jax.bcoo_multiply_dense().
 771
 772  @staticmethod
 773  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
 774    import jax.numpy as jnp
 775
 776    return jnp.concatenate(arrays, axis)
 777
 778  @staticmethod
 779  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
 780    import jax.numpy as jnp
 781
 782    return jnp.einsum(subscripts, *operands, optimize='greedy')
 783
 784  @staticmethod
 785  def make_sparse_matrix(
 786      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 787  ) -> _JaxArray:
 788    # https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
 789    import jax.experimental.sparse
 790
 791    indices = np.vstack((row_ind, col_ind)).T
 792    return jax.experimental.sparse.BCOO(
 793        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
 794    )
 795
 796
 797_CANDIDATE_ARRAYLIBS = {
 798    'numpy': _NumpyArraylib,
 799    'tensorflow': _TensorflowArraylib,
 800    'torch': _TorchArraylib,
 801    'jax': _JaxArraylib,
 802}
 803
 804
 805def is_available(arraylib: str) -> bool:
 806  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 807  # Faster than trying to import it.
 808  return importlib.util.find_spec(arraylib) is not None  # type: ignore[attr-defined]
 809
 810
 811_DICT_ARRAYLIBS = {
 812    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if is_available(arraylib)
 813}
 814
 815ARRAYLIBS = list(_DICT_ARRAYLIBS)
 816"""Array libraries supported automatically in the resize and resampling operations.
 817
 818- The library is selected automatically based on the type of the `array` function parameter.
 819
 820- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 821
 822- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 823"""
 824
 825
 826def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 827  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 828  for cls in _DICT_ARRAYLIBS.values():
 829    if cls.recognize(array):
 830      return cls(array)  # type: ignore[abstract]
 831  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 832
 833
 834def _arr_arraylib(array: _Array, /) -> str:
 835  """Return the name of the `Arraylib` representing `array`."""
 836  return _as_arr(array).arraylib
 837
 838
 839def _arr_numpy(array: _Array, /) -> _NDArray:
 840  """Return a `numpy` version of `array`."""
 841  return _as_arr(array).numpy()
 842
 843
 844def _arr_dtype(array: _Array, /) -> _DType:
 845  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 846  return _as_arr(array).dtype()
 847
 848
 849def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 850  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 851  return _as_arr(array).astype(dtype)
 852
 853
 854def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 855  """Return the equivalent of `array.reshape(shape)."""
 856  return _as_arr(array).reshape(shape)
 857
 858
 859def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 860  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 861  return _as_arr(array).possibly_make_contiguous()
 862
 863
 864def _arr_clip(array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike = None) -> _Array:
 865  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 866  return _as_arr(array).clip(low, high, dtype)
 867
 868
 869def _arr_square(array: _Array, /) -> _Array:
 870  """Return the equivalent of `np.square(array)`."""
 871  return _as_arr(array).square()
 872
 873
 874def _arr_sqrt(array: _Array, /) -> _Array:
 875  """Return the equivalent of `np.sqrt(array)`."""
 876  return _as_arr(array).sqrt()
 877
 878
 879def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 880  """Return the equivalent of `array[indices]`."""
 881  return _as_arr(array).getitem(indices)
 882
 883
 884def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 885  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 886  return _as_arr(condition).where(if_true, if_false)
 887
 888
 889def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 890  """Return the equivalent of `np.transpose(array, axes)`."""
 891  return _as_arr(array).transpose(axes)
 892
 893
 894def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 895  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 896  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 897
 898
 899def _arr_matmul_sparse_dense(
 900    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 901) -> _Array:
 902  """Return the multiplication of the `sparse` and `dense` matrices."""
 903  assert num_threads == 'auto' or num_threads >= 1
 904  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 905
 906
 907def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 908  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 909  arraylib = _arr_arraylib(arrays[0])
 910  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)
 911
 912
 913def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 914  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 915  arraylib = _arr_arraylib(operands[0])
 916  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)
 917
 918
 919def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 920  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 921  ndim = len(array.shape)
 922  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 923  axes = list(range(ndim))
 924  axes[axis1] = axis2
 925  axes[axis2] = axis1
 926  return _arr_transpose(array, axes)
 927
 928
 929def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 930  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 931  ndim = len(array.shape)
 932  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 933  axes = [n for n in range(ndim) if n != source]
 934  axes.insert(destination, source)
 935  return _arr_transpose(array, axes)
 936
 937
 938def _make_sparse_matrix(
 939    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 940) -> Any:
 941  """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 942  However, indices must be ordered and unique."""
 943  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 944
 945
 946def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 947  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 948  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
 949
 950
# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
 952# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
 953# e.g. dim=1 in an RGB image.
 954#
 955# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 956# tensor.  However, tf.reshape() is always efficient.
 957
 958
 959def _block_shape_with_min_size(
 960    shape: tuple[int, ...], min_size: int, compact: bool = True
 961) -> tuple[int, ...]:
 962  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 963  if math.prod(shape) < min_size:
 964    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 965  if compact:
 966    root = int(math.ceil(min_size ** (1 / len(shape))))
 967    block_shape = np.minimum(shape, root)
 968    for dim in range(len(shape)):
 969      if block_shape[dim] == 2 and block_shape.prod() >= min_size * 2:
 970        block_shape[dim] = 1
 971    for dim in range(len(shape) - 1, -1, -1):
 972      if block_shape.prod() < min_size:
 973        block_shape[dim] = shape[dim]
 974  else:
 975    block_shape = np.ones_like(shape)
 976    for dim in range(len(shape) - 1, -1, -1):
 977      if block_shape.prod() < min_size:
 978        block_shape[dim] = min(shape[dim], math.ceil(min_size / block_shape.prod()))
 979  return tuple(block_shape)
 980
 981
 982def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 983  """Split `array` into `num_sections` along `axis`."""
 984  assert 0 <= axis < len(array.shape)
 985  assert 1 <= num_sections <= array.shape[axis]
 986
 987  if 0:
 988    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 989
 990  else:
 991    # Adapted from https://github.com/numpy/numpy/blob/main/numpy/lib/shape_base.py#L739-L792.
 992    num_total = array.shape[axis]
 993    num_each, num_extra = divmod(num_total, num_sections)
 994    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
 995    div_points = np.array(section_sizes).cumsum()
 996    split = []
 997    tmp = _arr_swapaxes(array, axis, 0)
 998    for i in range(num_sections):
 999      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
1000
1001  return split
1002
1003
def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
  """Recursively partition `array` into nested lists of blocks no larger than `block_shape`.

  Dims from `start_axis` up to `len(block_shape)` each contribute one level of list nesting;
  any trailing array dims beyond `len(block_shape)` are left intact.
  """
  # See https://stackoverflow.com/a/50305924.  (If the block_shape is known to
  # exactly partition the array, see https://stackoverflow.com/a/16858283.)
  if len(block_shape) > len(array.shape):
    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
  if start_axis == len(block_shape):
    return array
  count = math.ceil(array.shape[start_axis] / block_shape[start_axis])
  return [
      _split_array_into_blocks(part, block_shape, start_axis + 1)
      for part in _array_split(array, start_axis, count)
  ]
1016
1017
1018def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
1019  """Apply `func` to each block in the nested lists of `blocks`."""
1020  if isinstance(blocks, list):
1021    return [_map_function_over_blocks(block, func) for block in blocks]
1022  return func(blocks)
1023
1024
1025def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
1026  """Merge an array from the nested lists of array blocks in `blocks`."""
1027  # More general than np.block() because the blocks can have additional dims.
1028  if isinstance(blocks, list):
1029    new_blocks = [_merge_array_from_blocks(block, axis + 1) for block in blocks]
1030    return _arr_concatenate(new_blocks, axis)
1031  return blocks
1032
1033
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract interface for a grid type, such as `'dual'` or `'primal'`.

  Resampling operations accept a grid type for the source domain (`src_gridtype`) and for the
  destination domain (`dst_gridtype`) separately, and either one may also vary per dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Both src and dst use `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Short identifier of the grid type."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples that this grid type permits."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of the domain measured in inter-sample spacings."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map sample indices in [0, size - 1] to domain coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Map domain coordinates in [0.0, 1.0] to continuous sample locations.

    A returned location of 0.0 is the first grid sample and size - 1.0 is the last one."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Fold integer sample indices into the interior by reflecting at the boundaries."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Fold integer sample indices into the interior by periodic wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Fold integer sample indices into the interior by reflecting once, then clamping."""
1080
1081
class DualGridtype(Gridtype):
  """Grid whose samples lie at the cell centers of a uniform partition of the domain.

  In a unit-domain dimension with N samples, sample 0 <= i < N sits at (i + 0.5) / N,
  e.g., [0.125, 0.375, 0.625, 0.875] for N = 4.
  """

  def __init__(self) -> None:
    super().__init__(name='dual')

  def min_size(self) -> int:
    return 1  # A single cell-centered sample suffices.

  def size_in_samples(self, size: int, /) -> int:
    return size  # Each of the `size` samples owns one cell of unit spacing.

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    return (index + 0.5) / size

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    return point * size - 0.5

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # Mirroring about both domain edges has period 2 * size; index i in the second half maps
    # to 2 * size - 1 - i.
    period = size * 2
    folded = np.mod(index, period)
    return np.where(folded < size, folded, period - 1 - folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    return np.mod(index, size)

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    mirrored = np.where(index < 0, -1 - index, index)
    return np.minimum(mirrored, size - 1)
1113
1114
class PrimalGridtype(Gridtype):
  """Grid whose samples lie at the cell vertices of a uniform partition of the domain.

  In a unit-domain dimension with N samples, sample 0 <= i < N sits at i / (N - 1),
  e.g., [0, 1/3, 2/3, 1] for N = 4.
  """

  def __init__(self) -> None:
    super().__init__(name='primal')

  def min_size(self) -> int:
    return 2  # Positions i / (N - 1) require N >= 2.

  def size_in_samples(self, size: int, /) -> int:
    return size - 1  # Number of cells spanned by the `size` vertex samples.

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    return index / (size - 1)

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    return point * (size - 1)

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The end samples lie on the mirror axes, so the reflection period is 2 * (size - 1).
    period = size * 2 - 2
    folded = np.mod(index, period)
    return np.where(folded < size, folded, period - folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    # The last sample is replaced by the first, so the period is size - 1.
    return np.mod(index, size - 1)

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    mirrored = np.abs(index)
    return np.minimum(mirrored, size - 1)
1146
1147
# The predefined grid types, keyed by their `name` ('dual' first, then 'primal').
_DICT_GRIDTYPES = {gridtype.name: gridtype for gridtype in (DualGridtype(), PrimalGridtype())}
1152
GRIDTYPES = list(_DICT_GRIDTYPES)
r"""Shortcut names for the two predefined grid types (specified per dimension):

| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
| --- |:---:|:---:|
| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual](https://github.com/hhoppe/resampler/raw/main/media/dual_grid_small.png) | ![Primal](https://github.com/hhoppe/resampler/raw/main/media/primal_grid_small.png) |
| Nesting of samples across resolutions | The sample positions do *not* nest. | The *even* samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.
"""
1166
1167
def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
  """Resolve `gridtype` to a `Gridtype` instance, looking up string names in `GRIDTYPES`."""
  if isinstance(gridtype, Gridtype):
    return gridtype
  return _DICT_GRIDTYPES[gridtype]
1171
1172
def _get_gridtypes(
    gridtype: str | Gridtype | None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    src_ndim: int,
    dst_ndim: int,
) -> tuple[list[Gridtype], list[Gridtype]]:
  """Resolve per-dim source and destination grid types from the combined parameters.

  `gridtype` sets both domains and may not be combined with `src_gridtype` or `dst_gridtype`.
  It defaults to `'dual'` when all three are None.  Each result is broadcast to the
  corresponding number of dims.

  Raises:
    ValueError: If `gridtype` is given together with `src_gridtype` or `dst_gridtype`.
  """
  if gridtype is None and src_gridtype is None and dst_gridtype is None:
    gridtype = 'dual'
  if gridtype is not None:
    if src_gridtype is not None:
      raise ValueError('Cannot have both gridtype and src_gridtype.')
    if dst_gridtype is not None:
      raise ValueError('Cannot have both gridtype and dst_gridtype.')
    src_gridtype = dst_gridtype = gridtype

  def resolve(spec: Any, ndim: int) -> list[Gridtype]:
    return [_get_gridtype(g) for g in np.broadcast_to(np.array(spec), ndim)]

  return resolve(src_gridtype, src_ndim), resolve(dst_gridtype, dst_ndim)
1192
1193
@dataclasses.dataclass(frozen=True)
class RemapCoordinates(abc.ABC):
  """Abstract rule that transforms the requested coordinates before the reconstruction kernels
  are evaluated."""

  @abc.abstractmethod
  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the remapped coordinates."""
1202
1203
class NoRemapCoordinates(RemapCoordinates):
  """Identity rule: coordinates pass through unchanged."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    unchanged = point
    return unchanged
1209
1210
class MirrorRemapCoordinates(RemapCoordinates):
  """Reflect coordinates across the domain boundaries into the unit interval.

  The resulting function is continuous but not smooth across the boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    folded = np.mod(point, 2.0)  # Reflection has period 2.
    return np.where(folded < 1.0, folded, 2.0 - folded)
1218
1219
class TileRemapCoordinates(RemapCoordinates):
  """Map coordinates into the unit interval by taking them modulo 1.0.

  The resulting function is generally discontinuous across the domain boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    fractional_part = np.mod(point, 1.0)
    return fractional_part
1226
1227
@dataclasses.dataclass(frozen=True)
class ExtendSamples(abc.ABC):
  """Abstract rule that replaces references to grid samples outside the unit domain.

  Each exterior reference becomes an affine combination of interior sample(s), possibly
  together with the constant value (`cval`)."""

  uses_cval: bool = False
  """Whether some exterior samples are expressed through `cval`, i.e., whether the computed
  weights may be non-affine."""

  @abc.abstractmethod
  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Rewrite exterior references so that only interior samples are used.

    Entries of `index` outside [0, size) are redirected to interior samples, and the associated
    entries of `weight` may be adjusted.  Returns `(new_index, new_weight)`."""
1244
1245
class ReflectExtendSamples(ExtendSamples):
  """Map each exterior sample to an interior one by reflecting across the domain boundaries."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.reflect(index, size), weight
1254
1255
class WrapExtendSamples(ExtendSamples):
  """Map each exterior sample to an interior one by periodic wrapping.

  On a `'primal'` grid the last sample is disregarded, since the first sample takes its place."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.wrap(index, size), weight
1265
1266
class ClampExtendSamples(ExtendSamples):
  """Replace each exterior sample by the nearest interior sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return index.clip(0, size - 1), weight
1275
1276
class ReflectClampExtendSamples(ExtendSamples):
  """Reflect the samples of [0, 1] into [-1, 0], and give every sample outside [-1, 1] the
  value of its nearest sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.reflect_clamp(index, size), weight
1286
1287
class BorderExtendSamples(ExtendSamples):
  """Give every exterior sample the constant value (`cval`)."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Exterior references get zero weight (so `cval` fills in) and a valid in-range index.
    # The input arrays are modified in place.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    return index, weight
1304
1305
class ValidExtendSamples(ExtendSamples):
  """Give interior samples weight 1 and exterior samples weight 0, then renormalize.

  The reconstruction is a weighted one, divided by the reconstructed weight."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Drop exterior references (zero weight, valid in-range index); arrays are modified in place.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    # Rescale each row of weights to sum to 1, leaving rows whose total is zero untouched.
    totals = weight.sum(axis=-1)[..., None]
    np.divide(weight, totals, out=weight, where=totals != 0.0)
    return index, weight
1326
1327
class LinearExtendSamples(ExtendSamples):
  """Linearly extrapolate beyond boundary samples.

  Each reference to an exterior sample is redistributed onto the two nearest boundary samples
  (indices 0 and 1 on the low side; size - 2 and size - 1 on the high side) using the weights of
  the line through those two samples.  Grids with fewer than 2 samples fall back to reflection.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `(new_index, new_weight)` with 4 extra trailing columns for extrapolation.

    Note that the input `index` and `weight` arrays are also modified in place.
    """
    if size < 2:
      # Too few samples to define a line; fall back to reflection.
      index = gridtype.reflect(index, size)
      return index, weight
    # For each boundary, define new columns in index and weight arrays to represent the last and
    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
    # duplicate index entries.
    low = index < 0
    high = index >= size
    # Extra columns, in order: samples 0, 1 (low side), then size - 2, size - 1 (high side).
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
    x = index
    # Low side: f(x) = (1 - x) * f(0) + x * f(1), summed over the exterior taps of each row.
    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
    # High side: same formula in a reversed frame where sample size - 1 is at x = 0 and sample
    # size - 2 is at x = 1.
    x = (size - 1) - index
    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
    # The exterior taps are now represented by the extra columns; neutralize the originals.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-4] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-4] = index
    # Let matrix (including zero values) be banded.
    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
    index = new_index
    return index, weight
1361
1362
class QuadraticExtendSamples(ExtendSamples):
  """Quadratically extrapolate beyond boundary samples.

  Each reference to an exterior sample is redistributed onto the three nearest boundary samples
  (0, 1, 2 on the low side; size - 3, size - 2, size - 1 on the high side) using the Lagrange
  weights of the parabola through those samples.  Grids with fewer than 3 samples fall back to
  reflection.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `(new_index, new_weight)` with 6 extra trailing columns for extrapolation.

    Note that the input `index` and `weight` arrays are also modified in place.
    """
    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
    # but it seems just quadratic.
    if size < 3:
      # Too few samples to define a parabola; fall back to reflection.
      index = gridtype.reflect(index, size)
      return index, weight
    low = index < 0
    high = index >= size
    # Extra columns, in order: samples 0, 1, 2 (low side), then size - 3, size - 2, size - 1.
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
    x = index
    # Lagrange basis at nodes 0, 1, 2: (x-1)(x-2)/2, -x(x-2), x(x-1)/2.
    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
    # High side: same basis in a reversed frame where sample size - 1 is at x = 0, size - 2 at
    # x = 1, and size - 3 at x = 2.
    x = (size - 1) - index
    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
    # The exterior taps are now represented by the extra columns; neutralize the originals.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-6] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-6] = index
    # Let matrix (including zero values) be banded.
    new_index[..., -6:] = np.where(
        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
    )
    index = new_index
    return index, weight
1399
1400
@dataclasses.dataclass(frozen=True)
class OverrideExteriorValue:
  """Abstract base class to set the value outside some domain extent to a
  constant value (`cval`).

  The subclasses here implement `__call__()` by computing a signed distance to their extent and
  delegating to `override_using_signed_distance()`, which modifies the weights in place.
  """

  boundary_antialiasing: bool = True
  """Antialias the pixel values adjacent to the boundary of the extent."""

  uses_cval: bool = False
  """Modify some weights to introduce references to the `cval` constant value."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For all `point` outside some extent, modify the weight to be zero."""

  def override_using_signed_distance(
      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
  ) -> None:
    """Reduce sample weights for "outside" values based on the signed distance function,
    to effectively assign the constant value `cval`.

    Args:
      weight: Filter weights, modified in place.  Presumably of shape
        (*point.shape, num_taps), since it is indexed/broadcast by per-point masks; confirm.
      point: Destination coordinates; their spacing (via `np.gradient`) converts distances into
        units of samples.
      signed_distance: Per-point signed distance to the extent; values > 0.0 are outside.
    """
    all_points_inside_domain = np.all(signed_distance <= 0.0)
    if all_points_inside_domain:
      return
    # `np.gradient` needs at least 2 samples along every axis.
    if self.boundary_antialiasing and min(point.shape) >= 2:
      # For discontinuous coordinate mappings, we may need to somehow ignore
      # the large finite differences computed across the map discontinuities.
      gradient = np.gradient(point)
      gradient_norm = np.linalg.norm(np.atleast_2d(gradient), axis=0)
      # Express the distance in units of the local sample spacing (epsilon avoids division by 0).
      signed_distance_in_samples = signed_distance / (gradient_norm + 1e-20)
      # Opacity is in linear space, which is correct if Gamma is set.
      # Ramp from 1 (half a sample or more inside) to 0 (half a sample or more outside).
      opacity = (0.5 - signed_distance_in_samples).clip(0.0, 1.0)
      weight *= opacity[..., None]
    else:
      # Hard cutoff: all weights of outside points become zero.
      is_outside = signed_distance > 0.0
      weight[is_outside, :] = 0.0
1435
1436
class NoOverrideExteriorValue(OverrideExteriorValue):
  """Leave the reconstructed values untouched everywhere."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    del weight, point  # Nothing to override.
1442
1443
class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the unit interval [0, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Signed distance to [0, 1]: negative inside, zero at 0.0 and 1.0, positive outside.
    distance = abs(point - 0.5) - 0.5
    self.override_using_signed_distance(weight, point, distance)
1453
1454
class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the interval [-1, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Signed distance to [-1, 1]: negative inside, zero at -1.0 and 1.0, positive outside.
    distance = abs(point) - 1.0
    self.override_using_signed_distance(weight, point, distance)
1464
1465
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Rules governing the reconstruction over the source domain near and beyond its boundaries.

  A separate rule may be given for each domain dimension."""

  name: str = ''
  """Name of the boundary rule."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Remapping applied to coordinates before the reconstruction kernels are evaluated."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Expresses each grid sample outside the unit domain as an affine combination of interior
  sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Replaces the value outside some extent by the constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """Whether the weights may be non-affine because they involve the constant value (`cval`)."""
    return any(rule.uses_cval for rule in (self.extend_samples, self.override_value))

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return `point` transformed by `coord_remap`, ahead of filter-kernel evaluation."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Redirect exterior samples to interior ones, then apply any exterior-value override."""
    new_index, new_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(new_weight, point)
    return new_index, new_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """Zero the weights of points outside an extent so that they take the value `cval`."""
    self.override_value(weight, point)
1506
1507
# Predefined boundary rules, keyed by name; each `Boundary.name` matches its dictionary key.
_DICT_BOUNDARIES = {
    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
    'tile': Boundary(
        'tile', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
    ),
    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
    'border': Boundary('border', extend_samples=BorderExtendSamples()),
    'natural': Boundary(
        'natural',
        extend_samples=ValidExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear_constant': Boundary(
        'linear_constant',
        extend_samples=LinearExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'quadratic_constant': Boundary(
        'quadratic_constant',
        extend_samples=QuadraticExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
    'constant': Boundary(
        'constant',
        extend_samples=ReflectExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
}
1540
BOUNDARIES = [*_DICT_BOUNDARIES]
"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):

| name                   | a.k.a. / comments |
|------------------------|-------------------|
| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
| `'linear'`             | extrapolate from 2 last samples |
| `'quadratic'`          | extrapolate from 3 last samples |
| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |

These boundary rules may be specified per dimension.  See the source code for extensibility
using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.

**Boundary rules illustrated in 1D:**

<center>
<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_1D.png" width="100%"/>
</center>

**Boundary rules illustrated in 2D:**

<center>
<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_2D.png" width="100%"/>
</center>
"""
1574
_OFTUSED_BOUNDARIES = [
    'reflect',
    'wrap',
    'tile',
    'clamp',
    'border',
    'natural',
    'linear_constant',
    'quadratic_constant',
]
"""A useful subset of `BOUNDARIES` for visualization in figures."""
1579
1580
def _get_boundary(boundary: str | Boundary, /) -> Boundary:
  """Resolve `boundary` into a `Boundary` object.

  Args:
    boundary: Either a `Boundary` instance (returned unchanged) or a name listed in `BOUNDARIES`.

  Raises:
    KeyError: If `boundary` is a string that names no predefined rule.
  """
  if isinstance(boundary, Boundary):
    return boundary
  return _DICT_BOUNDARIES[boundary]
1584
1585
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  Subclasses pass the descriptive fields below to `super().__init__(...)` and implement
  `__call__`.  Because this is a frozen dataclass, the generated `__eq__` and `__hash__` use only
  these fields; extra attributes that subclasses store (e.g. kernel parameters) do not take
  part, so kernels that differ should also differ in at least one field such as `name`.

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Return evaluation of filter kernel at locations x.

    Args:
      x: Location(s), relative to the kernel center, at which to evaluate the kernel.

    Returns:
      The kernel values; in the subclasses here, an array broadcast to the shape of `x`.
    """
1626
class ImpulseFilter(Filter):
  """Dirac delta kernel; see https://en.wikipedia.org/wiki/Dirac_delta_function.

  Its support is made vanishingly small (radius 1e-20), and it is neither continuous nor a
  partition of unity.  Evaluating it pointwise is not meaningful, so `__call__` always raises.
  """

  def __init__(self) -> None:
    super().__init__(name='impulse', radius=1e-20, partition_of_unity=False, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    del x  # There is no finite value to return at any location.
    message = 'The Impulse is infinitely narrow, so cannot be directly evaluated.'
    raise AssertionError(message)
1635
1636
class BoxFilter(Filter):
  """Unit box kernel; see https://en.wikipedia.org/wiki/Box_function.

  The kernel equals 1.0 on the half-open interval [-.5, .5) and 0.0 elsewhere.
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    # The half-open (asymmetric) interval is used.  A symmetric alternative would instead assign
    # 0.5 at abs(x) == 0.5.
    locations = np.asarray(x)
    inside = (-0.5 <= locations) & (locations < 0.5)
    return np.where(inside, 1.0, 0.0)
1653
1654
class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  Raises:
    ValueError: If `radius` is not None and lies outside (0.5, 1.0].

  This filter resembles `BoxFilter` but has linearly sloped sides: it equals 1.0 where
  abs(x) <= 1.0 - radius, falls linearly to 0.0 over 1.0 - radius <= abs(x) <= radius, and
  always passes through 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is None:
      # Placeholder with radius 0.0; `__call__` asserts 0.5 < radius <= 1.0, so it cannot be
      # evaluated until replaced.
      super().__init__(name='trapezoid', radius=0.0)
    elif 0.5 < radius <= 1.0:
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Line through (0.5, 0.5) with slope -1 / (2 * radius - 1), clipped to [0, 1].
    offset = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (offset - slope * distance).clip(0.0, 1.0)
1681
1682
class TriangleFilter(Filter):
  """Hat (tent) kernel; see https://en.wikipedia.org/wiki/Triangle_function.

  Convolving samples with it yields piecewise-linear (bilinear in 2D, trilinear in 3D, ...)
  interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    tent = 1.0 - np.abs(x)
    return tent.clip(0.0, 1.0)
1695
1696
class CubicFilter(Filter):
  """Two-parameter family of piecewise-cubic kernels with support [-2, 2].

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: Optional kernel name; defaults to `f'cubic_b{b}_c{c}'`.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    b, c = self.b, self.c
    t = np.abs(x)
    # Inner piece, |x| < 1 (it has no linear term).
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Outer piece, 1 <= |x| < 2.
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # Horner evaluation; np.polynomial.polynomial.polyval is almost twice as slow
    # (see also https://stackoverflow.com/questions/24065904).
    inner = ((f3 * t + f2) * t) * t + f0
    outer = ((g3 * t + g2) * t + g1) * t + g0
    return np.where(t < 1.0, inner, np.where(t < 2.0, outer, 0.0))
1734
1735
class CatmullRomFilter(CubicFilter):
  """Interpolating cubic with cubic precision (b=0, c=0.5); also known as the Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)
1750
1751
class MitchellFilter(CubicFilter):
  """Mitchell-Netravali cubic with b = c = 1/3; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    third = 1 / 3
    super().__init__(name='mitchell', b=third, c=third)
1761
1762
class SharpCubicFilter(CubicFilter):
  """Interpolating cubic (b=0, c=0.75), sharper than the Catmull-Rom filter.

  Used by some tools including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)
1774
1775
class LanczosFilter(Filter):
  """High-quality filter: a sinc function modulated by a (wider) sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      t = np.abs(x)
      # The window is window[n] = sinc(2*n/N - 1) for 0 <= n <= N; substituting n = x + N/2
      # (so -N/2 <= x <= N/2) gives the zero-phase form w_0(x) = sinc(x / radius).
      zero_phase_window = _sinc(t / radius)
      return np.where(t < radius, _sinc(t) * zero_phase_window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    evaluate = self._function
    return evaluate(x)
1803
1804
class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  Raises:
    ValueError: If `a0` lies outside the open interval (0.0, 1.0).

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    # Validate explicitly (like TrapezoidFilter/BsplineFilter); `assert` is stripped under -O.
    if not 0.0 < a0 < 1.0:
      raise ValueError(f'Parameter a0={a0} is outside the range (0.0, 1.0).')
    super().__init__(
        # The name includes a0 because the frozen-dataclass equality and hash of `Filter` use
        # only its fields; otherwise e.g. 'hamming3' and 'hann3' would compare (and hash) equal.
        name=f'hamming_{radius}_{a0}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
    return np.where(x < self.radius, _sinc(x) * window, 0.0)
1836
1837
class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level; must be
      non-negative.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `beta` is negative (or NaN).
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    # Validate explicitly rather than with `assert`, which is stripped under `python -O`.
    if not beta >= 0.0:
      raise ValueError(f'Parameter beta={beta} must be non-negative.')
    super().__init__(
        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    # The sampling range is rounded outward to integers, presumably so that it covers the whole
    # support when `radius` is fractional.
    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      # w(x) = I0(beta * sqrt(1 - (x / radius)**2)) / I0(beta); the clip keeps the sqrt argument
      # valid for |x| > radius, where np.where below zeroes the result anyway.
      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
      # The small tolerance keeps x == radius (up to rounding) inside the support.
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
1871
1872
class BsplineFilter(Filter):
  """B-spline of a non-negative degree.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  Raises:
    ValueError: If `degree` is negative.

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    # Only degrees 0 and 1 have f(0) == 1 and f(i) == 0 at all other integers.
    interpolating = degree <= 1
    super().__init__(
        name=f'bspline{degree}',
        radius=radius,
        interpolating=interpolating,
        # The degree-0 B-spline is a box with jumps at +/-0.5, hence not C^0 continuous.
        continuous=degree >= 1,
    )
    # Uniform knots 0, 1, ..., degree + 1 define the basis element on [0, degree + 1).
    t = list(range(degree + 2))
    self._bspline = scipy.interpolate.BSpline.basis_element(t)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Shift by `radius` to center the basis element (defined on [0, 2 * radius)) at the origin.
    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)
1898
1899
class CardinalBsplineFilter(Filter):
  """Interpolating B-spline, obtained with the help of a digital pre- or post-filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `degree` is negative.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    half_support = (degree + 1) / 2
    super().__init__(
        name=f'cardinal{degree}',
        radius=half_support,
        requires_digital_filter=degree >= 2,
        continuous=degree >= 1,
    )
    basis = scipy.interpolate.BSpline.basis_element(list(range(degree + 2)))

    @_cache_sampled_1d_function(xmin=-half_support, xmax=half_support, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      t = np.abs(x)
      # The basis element lives on [0, 2 * half_support); shift it to be centered at 0.
      return np.where(t < half_support, basis(t + half_support), 0.0)

    self._function = _eval
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    evaluate = self._function
    return evaluate(x)
1935
1936
class OmomsFilter(Filter):
  """OMOMS interpolating kernel, used together with a digital pre- or post-filter.

  Args:
    degree: The polynomial degree of the filter segments; only 3 and 5 are supported.

  Raises:
    ValueError: If `degree` is not 3 or 5.

  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
  interpolation of minimal support, 2001].
  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
  """

  def __init__(self, *, degree: int) -> None:
    if degree not in (3, 5):
      raise ValueError(f'Degree {degree} not supported.')
    half_support = (degree + 1) / 2
    super().__init__(name=f'omoms{degree}', radius=half_support, requires_digital_filter=True)
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    if self.degree == 3:
      piece0 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
      piece1 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
      return np.where(x < 1.0, piece0, np.where(x < 2.0, piece1, 0.0))
    if self.degree == 5:
      piece0 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
      piece1 = (
          (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
      ) * x + 839 / 2640
      piece2 = (
          (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
      ) * x + 5707 / 2640
      return np.where(
          x < 1.0, piece0, np.where(x < 2.0, piece1, np.where(x < 3.0, piece2, 0.0))
      )
    raise ValueError(self.degree)
1972
1973
class GaussianFilter(Filter):
  r"""Normalized Gaussian kernel; see https://en.wikipedia.org/wiki/Gaussian_function.

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.

  The kernel is truncated at radius ceil(8 * sigma).
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  https://cadxfem.org/inf/ResamplingFilters.pdf
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    truncation_radius = np.ceil(8.0 * standard_deviation)  # Sufficiently large.
    super().__init__(
        name=f'gaussian_{standard_deviation:.3f}',
        radius=truncation_radius,
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    sigma = self.standard_deviation
    t = np.abs(x)
    density = np.exp(np.square(t / sigma) / -2.0) / (np.sqrt(math.tau) * sigma)
    return np.where(t < self.radius, density, 0.0)
2004
2005
class NarrowBoxFilter(Filter):
  """Narrow box of height 1.0, used to visualize where grid samples are located.

  Args:
    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
        continuous=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    half_width = self.radius
    locations = np.asarray(x)
    # Half-open interval [-radius, radius), like BoxFilter.
    inside = (-half_width <= locations) & (locations < half_width)
    return np.where(inside, 1.0, 0.0)
2028
2029
# Name (a key of `_DICT_FILTERS`) of the default filter kernel.
_DEFAULT_FILTER = 'lanczos3'

# Predefined filter kernels in display order; `FILTERS` lists the keys preceding 'hamming1'.
_DICT_FILTERS = dict(
    impulse=ImpulseFilter(),
    box=BoxFilter(),
    trapezoid=TrapezoidFilter(),
    triangle=TriangleFilter(),
    cubic=CatmullRomFilter(),
    sharpcubic=SharpCubicFilter(),
    lanczos3=LanczosFilter(radius=3),
    lanczos5=LanczosFilter(radius=5),
    lanczos10=LanczosFilter(radius=10),
    cardinal3=CardinalBsplineFilter(degree=3),
    cardinal5=CardinalBsplineFilter(degree=5),
    omoms3=OmomsFilter(degree=3),
    omoms5=OmomsFilter(degree=5),
    hamming3=GeneralizedHammingFilter(radius=3, a0=25 / 46),
    kaiser3=KaiserFilter(radius=3.0, beta=7.12),
    gaussian=GaussianFilter(),
    bspline3=BsplineFilter(degree=3),
    mitchell=MitchellFilter(),
    narrowbox=NarrowBoxFilter(),
    # Not in FILTERS:
    hamming1=GeneralizedHammingFilter(radius=1, a0=0.54),
    hann3=GeneralizedHammingFilter(radius=3, a0=0.5),
    lanczos4=LanczosFilter(radius=4),
)
2057
# All keys of `_DICT_FILTERS` up to (excluding) 'hamming1'; later entries are not advertised.
FILTERS = list(itertools.takewhile(lambda x: x != 'hamming1', _DICT_FILTERS))
r"""Shortcut names for some predefined filter kernels (specified per dimension).
The names expand to:

| name           | `Filter`                      | a.k.a. / comments |
|----------------|-------------------------------|-------------------|
| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-2, 2], *GF* |
| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-3, 3], *GF* |
| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |

The comment label *GF* denotes a [generalized filter](https://hhoppe.com/proj/filtering/), formed
as the composition of a finitely supported kernel and a discrete inverse convolution.

**Some example filter kernels:**

<center>
<img src="https://github.com/hhoppe/resampler/raw/main/media/filter_summary.png" width="100%"/>
</center>

<br/>A more extensive set of filters is presented [here](#plots_of_filters) in the
[notebook](https://colab.research.google.com/github/hhoppe/resampler/blob/main/resampler_notebook.ipynb),
together with visualizations and analyses of the filter properties.
See the source code for extensibility.
"""
2098
2099
def _get_filter(filter: str | Filter, /) -> Filter:
  """Resolve `filter` into a `Filter` object.

  Args:
    filter: Either a `Filter` instance (returned unchanged) or a key of `_DICT_FILTERS`, such as
      a name listed in `FILTERS`.

  Raises:
    KeyError: If `filter` is a string that names no predefined kernel.
  """
  if isinstance(filter, Filter):
    return filter
  return _DICT_FILTERS[filter]
2103
2104
def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Convert `array` to the floating-point `dtype` with values in [0.0, 1.0].

  Unsigned 8/16/32-bit integers are divided by their type's maximum value; floating-point inputs
  are clipped to [0.0, 1.0].  Any other input type fails an assertion.
  """
  src_dtype = _arr_dtype(array)
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)
  if src_dtype.type not in (np.uint8, np.uint16, np.uint32):
    assert np.issubdtype(src_dtype, np.floating)
    return _arr_clip(array, 0.0, 1.0, dtype)
  max_value = np.iinfo(src_dtype).max
  if _arr_arraylib(array) == 'numpy':
    assert isinstance(array, np.ndarray)  # Help mypy.
    # Multiply directly into `dtype`, avoiding an intermediate conversion of the whole array.
    return np.multiply(array, 1 / max_value, dtype=dtype)
  return _arr_astype(array, dtype) / max_value
2119
2120
2121def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
2122  """Convert a float in range [0.0, 1.0] to uint or float type."""
2123  assert np.issubdtype(_arr_dtype(array), np.floating)
2124  dtype = np.dtype(dtype)
2125  match dtype.type:
2126    case np.uint8 | np.uint16:
2127      return cast(_Array, _arr_astype(array * np.float32(np.iinfo(dtype).max) + 0.5, dtype))
2128    case np.uint32:
2129      return cast(_Array, _arr_astype(array * np.float64(np.iinfo(dtype).max) + 0.5, dtype))
2130    case _:
2131      assert np.issubdtype(dtype, np.floating)
2132      return _arr_astype(array, dtype)
2133
2134
@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract base class for transfer functions on sample values.

  Image/video content is often stored using a color component transfer function.
  See https://en.wikipedia.org/wiki/Gamma_correction.

  Converts between integer types and [0.0, 1.0] internal value range.

  Subclasses pass `name` to `super().__init__` and implement `decode` and `encode`.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Decode source sample values into floating-point, possibly nonlinearly.

    Uint source values are mapped to the range [0.0, 1.0].

    Args:
      array: Stored (encoded) sample values.
      dtype: Floating-point type of the decoded values.
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Encode float signal into destination samples, possibly nonlinearly.

    Uint destination values are mapped from the range [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.

    Args:
      array: Floating-point signal values (linear space).
      dtype: Type of the encoded output samples.
    """
2164
2165
class IdentityGamma(Gamma):
  """Identity component transfer function.

  No nonlinear transform is applied; only the numeric representation changes.  Unsigned integer
  samples are mapped to and from the range [0.0, 1.0].
  """

  def __init__(self) -> None:
    super().__init__('identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Convert `array` to the inexact `dtype`, scaling unsigned integers to [0.0, 1.0]."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.inexact)
    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
      return _to_float_01(array, dtype)
    return _arr_astype(array, dtype)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Convert the float `array` to the numeric `dtype`.

    Unsigned integer outputs are clipped to [0.0, 1.0] and rescaled to the integer range.
    Signed integer outputs are rounded to the nearest integer (ties away from zero).
    Floating-point outputs are cast without clipping.
    """
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.number)
    if np.issubdtype(dtype, np.unsignedinteger):
      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
    if np.issubdtype(dtype, np.integer):
      # Float-to-int casts truncate toward zero, so offset by 0.5 away from zero to round to
      # nearest; a plain `array + 0.5` would misround negative values (e.g. -0.7 -> 0).
      return _arr_astype(_arr_where(array < 0.0, array - 0.5, array + 0.5), dtype)
    return _arr_astype(array, dtype)
2187
2188
class PowerGamma(Gamma):
  """Power-law transfer function: decode as `e**power`, encode as `l**(1 / power)`."""

  def __init__(self, power: float) -> None:
    super().__init__(name=f'power_{power}')
    self.power = power

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Map stored values to [0.0, 1.0] in the floating `dtype`, then raise to `self.power`."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    is_square = self.power == 2
    if _arr_dtype(array) == np.uint8 and not is_square:
      # Decode all 256 possible codes once and gather, presumably to avoid per-sample powers.
      arraylib = _arr_arraylib(array)
      table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
      return _arr_getitem(table, array)

    unit = _to_float_01(array, dtype)
    if is_square:
      return _arr_square(unit)
    return unit**self.power

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Clip to [0.0, 1.0], apply the inverse power, and convert to `dtype`."""
    clipped = _arr_clip(array, 0.0, 1.0)
    if self.power == 2:
      encoded = _arr_sqrt(clipped)
    else:
      encoded = clipped ** (1.0 / self.power)
    return _from_float(encoded, dtype)
2211
2212
class SrgbGamma(Gamma):
  """sRGB component transfer function; see https://en.wikipedia.org/wiki/SRGB."""

  def __init__(self) -> None:
    super().__init__(name='srgb')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Convert sRGB-encoded values to linear intensities in the floating `dtype`."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) != np.uint8:
      x = _to_float_01(array, dtype)
      # Piecewise: linear segment near zero, 2.4 power law above the threshold.
      return _arr_where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92)
    # 8-bit input: decode each of the 256 codes once and gather.
    arraylib = _arr_arraylib(array)
    table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
    return _arr_getitem(table, array)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Clip linear intensities to [0.0, 1.0], apply the sRGB encoding, and convert to `dtype`."""
    x = _arr_clip(array, 0.0, 1.0)
    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
    # pytype: disable=wrong-arg-types
    power_segment = x ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17)
    encoded = _arr_where(x > 0.0031308, power_segment, x * 12.92)
    # pytype: enable=wrong-arg-types
    return _from_float(encoded, dtype)
2237
2238
# Predefined transfer functions, keyed by the names listed in `GAMMAS`.
_DICT_GAMMAS = dict(
    identity=IdentityGamma(),
    power2=PowerGamma(2.0),
    power22=PowerGamma(2.2),
    srgb=SrgbGamma(),
)
2245
GAMMAS = [*_DICT_GAMMAS]
r"""Shortcut names for some predefined gamma-correction schemes:

| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
|---|---|:---:|:---:|
| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |

The `'srgb'` formulas above omit the short linear segment near zero, where
$l = e / 12.92$ for $e \le 0.04045$ and $e = 12.92 \, l$ for $l \le 0.0031308$.

See the source code for extensibility.
"""
2258
2259
def _get_gamma(gamma: str | Gamma, /) -> Gamma:
  """Resolve `gamma` into a `Gamma` object.

  Args:
    gamma: Either a `Gamma` instance (returned unchanged) or a name listed in `GAMMAS`.

  Raises:
    KeyError: If `gamma` is a string that names no predefined transfer function.
  """
  if isinstance(gamma, Gamma):
    return gamma
  return _DICT_GAMMAS[gamma]
2263
2264
def _get_src_dst_gamma(
    gamma: str | Gamma | None,
    src_gamma: str | Gamma | None,
    dst_gamma: str | Gamma | None,
    src_dtype: _DType,
    dst_dtype: _DType,
) -> tuple[Gamma, Gamma]:
  """Resolve the transfer functions for decoding the source and encoding the destination.

  Args:
    gamma: Transfer function shared by source and destination; mutually exclusive with
      `src_gamma` and `dst_gamma`.
    src_gamma: Transfer function of the source samples.
    dst_gamma: Transfer function of the destination samples.
    src_dtype: Source sample type; selects a default when no gamma is given.
    dst_dtype: Destination sample type; selects a default when no gamma is given.

  Returns:
    The pair `(src_gamma, dst_gamma)` as `Gamma` objects.

  Raises:
    ValueError: If no gamma is given and exactly one of the dtypes is an unsigned integer, if
      `gamma` is combined with `src_gamma` or `dst_gamma`, or if only one of `src_gamma` and
      `dst_gamma` is given.
  """
  if gamma is None and src_gamma is None and dst_gamma is None:
    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
    if src_uint and dst_uint:
      # The default might ideally be 'srgb' but that conversion is costlier.
      gamma = 'power2'
    elif not src_uint and not dst_uint:
      gamma = 'identity'
    else:
      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
  if gamma is not None:
    if src_gamma is not None:
      raise ValueError('Cannot specify both gamma and src_gamma.')
    if dst_gamma is not None:
      raise ValueError('Cannot specify both gamma and dst_gamma.')
    src_gamma = dst_gamma = gamma
  elif src_gamma is None or dst_gamma is None:
    # Validate explicitly rather than with `assert`, which gives no message and is stripped
    # under `python -O`.
    raise ValueError('Both src_gamma and dst_gamma must be specified if gamma is not.')
  return _get_gamma(src_gamma), _get_gamma(dst_gamma)
2292
2293
def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Compute affine weights for 1D resampling from `src_size` to `dst_size`.

  Compute a sparse matrix in which each row expresses a destination sample value as a combination
  of source sample values depending on the boundary rule.  If the combination is non-affine,
  the remainder (returned as `cval_weight`) is the contribution of the special constant value
  (cval) defined outside the domain.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix whose rows express output sample values as affine combinations of the
      source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.

  Raises:
    ValueError: If `src_size` or `dst_size` is smaller than the minimum size of its gridtype.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  prefilter = filter if prefilter is None else prefilter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  # Ratio of destination extent to source extent (in samples), including `scale`.
  scaling = dst_gridtype.size_in_samples(dst_size) / src_gridtype.size_in_samples(src_size) * scale
  is_minification = scaling < 1.0
  # The prefilter is used only when minifying; otherwise (same size or magnifying) `filter` is used.
  filter = prefilter if is_minification else filter
  if filter.name == 'trapezoid':
    # Adapt the trapezoid support to the scaling factor (radius is 1.0 when scaling == 1.0).
    radius = 0.5 + 0.5 * min(scaling, 1.0 / scaling)
    filter = TrapezoidFilter(radius=radius)
  radius = filter.radius
  # Number of source taps per destination sample; for minification the kernel is stretched by
  # 1 / scaling in source-index units.
  num_samples = int(np.ceil(radius * 2 / scaling) if is_minification else np.ceil(radius * 2))

  dst_index = np.arange(dst_size, dtype=dtype)
  # Destination sample locations in unit domain [0, 1].
  dst_position = dst_gridtype.point_from_index(dst_index, dst_size)

  # Map destination positions back into the source unit domain (inverse of scale and translate);
  # the result may lie outside [0, 1] before the boundary rule preprocesses it.
  src_position = (dst_position - translate) / scale
  src_position = boundary.preprocess_coordinates(src_position)

  # Continuous (fractional) source sample index for each destination sample.
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  # Index of the first of `num_samples` consecutive taps centered around `src_float_index`; an odd
  # tap count is centered on the nearest sample, an even one straddles the fractional index.
  src_first_index = (
      np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
      - (num_samples - 1) // 2
  )

  sample_index = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + sample_index  # (dst_size, num_samples)

  def get_weight_matrix() -> _NDArray:
    """Evaluate the filter at each tap offset; shape (dst_size, num_samples)."""
    if filter.name == 'impulse':
      return np.ones(src_index.shape, dtype)
    if is_minification:
      # Evaluate the kernel in destination units, then rescale so it integrates consistently.
      x = (src_float_index[:, None] - src_index.astype(dtype)) * scaling
      return filter(x) * scaling
    # Either same size or magnification.
    x = src_float_index[:, None] - src_index.astype(dtype)
    return filter(x)

  weight = get_weight_matrix().astype(dtype, copy=False)

  # Renormalize each row to sum to one, except for 'narrowbox' (whose row deficit becomes the cval
  # weight below) and for non-minifying filters that already form a partition of unity.
  if filter.name != 'narrowbox' and (is_minification or not filter.partition_of_unity):
    weight = weight / weight.sum(axis=-1)[..., None]

  # Let the boundary rule remap out-of-range tap indices and adjust their weights.
  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)
  shape = dst_size, src_size

  def prepare_sparse_resize_matrix() -> tuple[_NDArray, _NDArray, _NDArray]:
    """Return COO-style (data, row, col) entries, with duplicate (row, col) weights summed."""
    # Linear index row * src_size + col for each (row, tap) entry.
    linearized = (src_index + np.indices(src_index.shape)[0] * src_size).ravel()
    values = weight.ravel()
    # Remove the zero weights.
    nonzero = values != 0.0
    linearized, values = linearized[nonzero], values[nonzero]
    # Sort and merge the duplicate indices.  A 0/1 matrix maps each entry to its unique slot, so
    # multiplying it by `values` sums the weights of entries that share a (row, col) location.
    unique, unique_inverse = np.unique(linearized, return_inverse=True)
    data2 = np.ones(len(linearized), np.float32)
    row_ind2 = unique_inverse
    col_ind2 = np.arange(len(linearized))
    shape2 = len(unique), len(linearized)
    csr = scipy.sparse.csr_matrix((data2, (row_ind2, col_ind2)), shape=shape2)
    data = csr * values  # Merged values.
    row_ind, col_ind = unique // src_size, unique % src_size  # Merged indices.
    return data, row_ind, col_ind

  data, row_ind, col_ind = prepare_sparse_resize_matrix()
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, shape, arraylib)

  # Whatever weight a row lacks relative to one is attributed to the constant value `cval`.
  uses_cval = boundary.uses_cval or filter.name == 'narrowbox'
  cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib) if uses_cval else None

  return resize_matrix, cval_weight
2411
2412
def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution to the specified dimension of the array.

  Find the array coefficients such that convolution with the (continuous) filter (given
  gridtype and boundary) interpolates the original array values.

  The computation is always delegated to `_apply_digital_filter_1d_numpy`.  For 'tensorflow',
  'torch', and 'jax' arrays, that numpy routine is wrapped in a custom-gradient construct whose
  backward pass calls the same routine with `compute_backward=True`.

  Args:
    array: Array from any library in `ARRAYLIBS`.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule along `axis`.
    cval: Constant value used beyond the domain by boundary rules that use it.
    filter: A filter whose `requires_digital_filter` is set.
    axis: Dimension of `array` along which the inverse convolution is applied.

  Returns:
    An array from the same library as `array`, holding the filter coefficients.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  if arraylib == 'tensorflow':
    import tensorflow as tf

    def forward(x: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(x, gridtype, boundary, cval, filter, axis, False)

    def backward(grad_output: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(
          grad_output, gridtype, boundary, cval, filter, axis, True
      )

    @tf.custom_gradient  # type: ignore[untyped-decorator]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # Although `forward` accesses parameters gridtype, boundary, etc., it is not stateful
      # because the function is redefined on each invocation of _apply_digital_filter_1d.
      y = tf.numpy_function(forward, [x], x.dtype, stateful=False)

      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        return tf.numpy_function(backward, [grad_output], x.dtype, stateful=False)

      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    # The class is defined per call so that it closes over this call's gridtype, boundary, etc.
    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Differentiable wrapper for _apply_digital_filter_1d."""

      @staticmethod
      # pylint: disable-next=arguments-differ
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        # NOTE(review): `.numpy()` requires a CPU tensor; device handling is not done here.
        a = _apply_digital_filter_1d_numpy(
            x.detach().numpy(), gridtype, boundary, cval, filter, axis, False
        )
        return torch.as_tensor(a)

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        a = _apply_digital_filter_1d_numpy(
            grad_output.detach().numpy(), gridtype, boundary, cval, filter, axis, True
        )
        return torch.as_tensor(a)

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp
    # It seems rather difficult to implement this digital filter (inverse convolution) in Jax.
    # https://jax.readthedocs.io/en/latest/jax.scipy.html sadly omits scipy.signal.filtfilt().
    # To include a (non-traceable) numpy function in Jax requires jax.experimental.host_callback
    # and/or defining a new jax.core.Primitive (which allows differentiability).  See
    # https://github.com/google/jax/issues/1142#issuecomment-544286585
    # https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb  :-(
    # https://github.com/google/jax/issues/5934

    @jax.custom_gradient  # type: ignore[untyped-decorator]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # This function is not jax-traceable because it converts `x` to a concrete numpy array via
      # np.asarray() (formerly to_py()), so jit and grad fail.
      x_py = np.asarray(x)  # to_py() deprecated.
      a = _apply_digital_filter_1d_numpy(x_py, gridtype, boundary, cval, filter, axis, False)
      y = jnp.asarray(a)

      def grad(grad_output: _JaxArray) -> _JaxArray:
        grad_output_py = np.asarray(grad_output)  # to_py() deprecated.
        a = _apply_digital_filter_1d_numpy(
            grad_output_py, gridtype, boundary, cval, filter, axis, True
        )
        return jnp.asarray(a)

      return y, grad

    return jax_inverse_convolution(array)

  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return _apply_digital_filter_1d_numpy(array, gridtype, boundary, cval, filter, axis, False)
2515
2516
def _apply_digital_filter_1d_numpy(
    array: _NDArray,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    axis: int,
    compute_backward: bool,
    /,
) -> _NDArray:
  """Version of `_apply_digital_filter_1d` specialized to numpy array.

  Solves a (size x size) linear system whose rows sample `filter` at integer offsets, so that
  the returned coefficients, convolved with the filter, reproduce `array` along `axis`.

  Args:
    array: Numpy array of inexact (float or complex) dtype.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule along `axis`.
    cval: Constant value used beyond the domain when `boundary.uses_cval`.
    filter: Filter defining the convolution to invert.
    axis: Dimension of `array` to process.
    compute_backward: If True, solve with the transposed matrix (the gradient of the forward
      solve) and ignore `cval`.

  Returns:
    Array with the same shape as `array` containing the solved coefficients.
  """
  assert np.issubdtype(array.dtype, np.inexact)
  cval = np.asarray(cval).astype(array.dtype, copy=False)

  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter:
  mode = {
      ('reflect', 'dual'): 'reflect',
      ('reflect', 'primal'): 'mirror',
      ('wrap', 'dual'): 'grid-wrap',
      ('wrap', 'primal'): 'wrap',
  }.get((boundary.name, gridtype.name))
  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
  use_split_filter1d = filter_is_compatible and mode
  if use_split_filter1d:
    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
    assert filter.degree >= 2
    # compute_backward=True is same: matrix is symmetric and cval is unused.
    return scipy.ndimage.spline_filter1d(
        array, axis=axis, order=filter.degree, mode=mode, output=array.dtype
    )

  array_dim = np.moveaxis(array, axis, 0)
  # Discrete kernel: the filter sampled at integer offsets -l..l.
  l = original_l = math.ceil(filter.radius) - 1
  x = np.arange(-l, l + 1, dtype=array.real.dtype)
  values = filter(x)
  size = array_dim.shape[0]
  # Row i of the system references samples i-l..i+l before boundary remapping.
  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
  weight: _NDArray = np.full((size, len(values)), values)
  # NOTE(review): a constant dummy position is passed to the boundary rule; presumably only its
  # index/weight remapping matters here — confirm.
  src_position = np.broadcast_to(0.5, len(values))
  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
  if gridtype.name == 'primal' and boundary.name == 'wrap':
    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
    # matrix non-singular.
    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
  # Currently, matrix is always banded unless boundary.name == 'wrap'.

  data = weight.reshape(-1).astype(array.dtype, copy=False)
  row_ind = np.arange(size).repeat(src_index.shape[1])
  col_ind = src_index.reshape(-1)
  # Duplicate (row, col) entries produced by the boundary rule are summed by the csr constructor.
  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
  if compute_backward:
    matrix = matrix.T

  if boundary.uses_cval and not compute_backward:
    # Row weight missing from one is the cval contribution; move it to the right-hand side.
    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval

  array_flat = array_dim.reshape(array_dim.shape[0], -1)

  if is_banded:
    matrix = matrix.todia()
    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
    l, u = -matrix.offsets[0], matrix.offsets[-1]
    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
    # Reversing the DIA rows orders diagonals from offset u down to -l, the `ab` layout used by
    # scipy's banded solvers; the symmetric case keeps only the rows for offsets u..0.
    if _is_symmetric(matrix):
      array_flat = scipy.linalg.solveh_banded(matrix.data[-1 : l - 1 : -1], array_flat, **options)
    else:
      array_flat = scipy.linalg.solve_banded((l, u), matrix.data[::-1], array_flat, **options)

  else:
    # Non-banded (e.g., 'wrap' corner entries): sparse LU without column reordering.
    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
    array_flat = lu.solve(array_flat)

  array_dim = array_flat.reshape(array_dim.shape)
  return np.moveaxis(array_dim, 0, axis)
2599
2600
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # With unchanged shape and gridtype and unit scale, the scaling factor is exactly 1.0, so the
  # per-dimension resize uses `filter` (never `prefilter`); only an interpolating `filter` makes
  # the operation an identity.  The output must also already have the requested `dtype`.
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in filter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)
  cval = _arr_numpy(src_gamma2.decode(cval, precision))

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    # A dimension is an identity only if its sample positions are unchanged; a change of gridtype
    # (e.g., 'dual' to 'primal') moves the samples even when the sample count is the same.
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and src_gridtype2[dim] == dst_gridtype2[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Flatten to (size_along_dim, everything_else) so the 1D resize becomes a sparse matmul.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array
2853
2854
# Reference to the `resize` function object defined above.  The helpers below
# (`resize_in_arraylib`, `_resize_possibly_in_arraylib`, `_create_jaxjit_resize`) call it through
# this name.  NOTE(review): presumably this keeps them working if the public name `resize` is
# later rebound elsewhere — confirm.
_original_resize = resize
2856
2857
def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the specified array library from `ARRAYLIBS`.

  The numpy input is converted into `arraylib`, resized there, and converted back to numpy.

  Args:
    array: Source grid; must be a numpy array.
    *args: Positional arguments forwarded to `resize()`.
    arraylib: Library (an element of `ARRAYLIBS`) in which the resize is evaluated.
    **kwargs: Keyword arguments forwarded to `resize()`.

  Returns:
    The resized grid as a numpy array.
  """
  _check_eq(_arr_arraylib(array), 'numpy')
  converted = _make_array(array, arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
2862
2863
def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `numpy` library.

  Shorthand for `resize_in_arraylib(array, *args, arraylib='numpy', **kwargs)`.
  """
  library = 'numpy'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2867
2868
def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `tensorflow` library.

  Shorthand for `resize_in_arraylib(array, *args, arraylib='tensorflow', **kwargs)`.
  """
  library = 'tensorflow'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2872
2873
def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `torch` library.

  Shorthand for `resize_in_arraylib(array, *args, arraylib='torch', **kwargs)`.
  """
  library = 'torch'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2877
2878
def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `jax` library.

  Shorthand for `resize_in_arraylib(array, *args, arraylib='jax', **kwargs)`.
  """
  library = 'jax'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2882
2883
def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """Resize `array`, evaluating in `arraylib` only when the input is a numpy array.

  Non-numpy inputs are passed straight to `resize()` in their own library (and `arraylib` is
  ignored).  Numpy inputs are converted into `arraylib`, resized, and converted back to numpy.
  """
  if _arr_arraylib(array) != 'numpy':
    return _original_resize(array, *args, **kwargs)
  converted = _make_array(cast(_ArrayLike, array), arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
2893
2894
@functools.cache
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Lazily invoke `jax.jit` on `resize` (built once, then cached).

  The `shape` argument (position 1) and every keyword-only parameter of `resize` are declared
  static.
  """
  import jax

  keyword_only_names = list(_original_resize.__kwdefaults__ or [])
  jitted: Any = jax.jit(
      _original_resize, static_argnums=(1,), static_argnames=keyword_only_names
  )
  return jitted
2906
2907
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` but with resize function jitted using Jax.

  The jitted function is created on first use and reused afterwards.
  """
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable
2911
2912
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit: contain` (or `object-fit: cover`).
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is given a non-default value, if `object_fit` is not
      `'contain'` or `'cover'`, or if `shape` has zero or more dimensions than `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  # Use np.any() so that array-valued `scale`/`translate` give a clear error (rather than an
  # "ambiguous truth value" error) and all-default sequences are accepted.
  if np.any(np.asarray(scale) != 1.0) or np.any(np.asarray(translate) != 0.0):
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if object_fit not in ('contain', 'cover'):
    raise ValueError(f"{object_fit=} must be either 'contain' or 'cover'.")
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension scale factors that would map the source extent exactly onto the output extent.
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # 'contain' picks the smallest factor (entire source visible); 'cover' picks the largest.
  scale0 = min(raw_scales) if object_fit == 'contain' else max(raw_scales)
  scale2 = scale0 / np.array(raw_scales)
  translate = (1.0 - scale2) / 2  # Center the scaled source within the output domain.
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)
2986
2987
2988_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.
2989
2990
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width, 1`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  It is intended for minification (prefiltering); since prefiltering is not yet
      implemented, it is currently only validated and does not affect the result.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
      It must be nonnegative.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, `coords` is not floating, `coords.shape[-1]` exceeds
      `array.ndim`, a `trapezoid` filter is requested, or `max_block_size` is negative.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  if max_block_size < 0 and max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    raise ValueError(f'{max_block_size=} must be nonnegative.')
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # One-time preprocessing of the source (skipped in recursive invocations on blocks).
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      # The full-size `jacobian` cannot be broadcast onto a smaller block.  It has been validated
      # above and is not used by the computation below, so it is omitted here.
      # TODO: Split `jacobian` into blocks alongside `coords` once it affects the result.
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=None,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # NOTE(review): it has been reported that np.einsum() in numpy>=2.0 may output np.float64 even
  # when all inputs are np.float32 (unverified here).  If so, `dtype=precision` could be passed
  # explicitly to keep the intermediate precision.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The weight given to `cval` is 1 minus the product, over all grid dimensions, of the summed
    # per-dimension weights.  Iterate over `grid_ndim` (the length of `weight`), not over
    # `resampled_ndim`, since the two differ, e.g., when sampling points along a line or when
    # mapping through a 1D color map.
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array
3302
3303
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array at the points of an affinely mapped destination grid of `shape`.

  Each unit-domain destination point is mapped into the source unit domain either linearly,
    `source_point = matrix @ destination_point`,
  or affinely, where the final matrix column holds an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `matrix.shape[0]` dimensions form the source
      grid; any remaining dimensions describe each sample and are all linearly interpolated.
    shape: Dimensions of the desired destination grid.  The destination grid may have a different
      number of dimensions than the source grid.
    matrix: 2D array giving a linear or affine map from unit-domain destination points (with
      `len(shape)` coordinates) to unit-domain source points (with `matrix.shape[0]` coordinates).
      When the matrix has `len(shape) + 1` columns, its last column is the affine offset
      (translation).
    gridtype: Sample placement on every dimension of both the source and destination grids,
      given as a name in `GRIDTYPES` or a `Gridtype` instance.  `'dual'` is used when
      `gridtype`, `src_gridtype`, and `dst_gridtype` are all `None`.
    src_gridtype: Sample placement in the source grid for each dimension.
      It cannot be combined with `gridtype`.
    dst_gridtype: Sample placement in the destination grid for each dimension.
      It cannot be combined with `gridtype`.
    filter: Reconstruction kernel for each of the `matrix.shape[0]` source dimensions, given as a
      filter name in `FILTERS` or a `Filter` instance.
    prefilter: Prefilter kernel for each of the `len(shape)` destination dimensions, given as a
      filter name in `FILTERS` or a `Filter` instance.  It is meant for downsampling
      (minification).  When `None`, the value of `filter` is reused.
    precision: Inexact precision of intermediate computations; when `None`, it is derived from
      `array.dtype` and `dtype`.
    dtype: Data type of the output array; when `None`, `array.dtype` is used.  For a uint type,
      intermediate float values in [0.0, 1.0] are rescaled to the uint range.
    **kwargs: Further parameters forwarded to `resample`.

  Returns:
    An array of the same library type as `array` holding a grid of the given `shape` whose values
    are resampled from `array`; its shape is `shape + array.shape[matrix.shape[0]:]`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  is_affine = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not (is_affine or num_columns == dst_ndim):
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  src_filters = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  dst_prefilters = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples: first per axis, then as a full meshgrid.
  axis_points = [
      dst_gridtype2[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_points = np.meshgrid(*axis_points, indexing='ij')

  linear_part = matrix[:, :-1] if is_affine else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_points, 1), 0, -1)
  if is_affine:
    coords += matrix[:, -1]  # Apply the translation column.

  # TODO: From the source grid shape, `shape`, `linear_part`, and the prefilter, derive a
  # convolution prefilter and apply it to bandlimit `array`, using the boundary rule for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=src_filters,
      prefilter=dst_prefilters,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )
3416
3417
def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Implement `resize` through the more general `resample_affine`; intended as a debug tool.

  If `fallback` is set and the request involves minification or a `'trapezoid'` filter, the call
  is delegated to `_original_resize` instead.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  ndim = len(shape)
  scale = np.broadcast_to(scale, ndim)
  translate = np.broadcast_to(translate, ndim)
  # TODO: let resample() do prefiltering for proper downsampling.
  shrinks_grid = np.any(np.array(shape) < array.shape[:ndim])
  has_minification = shrinks_grid or np.any(scale < 1.0)
  filters = [_get_filter(f) for f in np.broadcast_to(np.array(filter), ndim)]
  has_auto_trapezoid = any(f.name == 'trapezoid' for f in filters)
  if fallback and (has_minification or has_auto_trapezoid):
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)
  # A destination point p maps to the source point (p - translate) / scale.
  inverse_scale = 1.0 / scale
  offset = -translate / scale
  matrix = np.column_stack([np.diag(inverse_scale), offset])
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)
3444
3445
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 homogeneous matrix mapping destination points into the source unit domain.

  The matrix accounts for possibly non-square source and destination grid shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; `src_shape` if omitted.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 array whose last row is `[0, 0, 1]`.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.identity(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag(np.append(factors, 1.0))

  def homogeneous_rotation(theta: float) -> _NDArray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]])

  center = np.full(2, 0.5)
  # Composed left to right: ((((T @ S_src) @ R) @ S_new) @ T^-1).
  factors = [
      homogeneous_translation(center),
      homogeneous_scaling(min(src_shape) / src_shape),
      homogeneous_rotation(angle),
      homogeneous_scaling(scale * new_shape / min(new_shape)),
      homogeneous_translation(-center),
  ]
  matrix = functools.reduce(np.matmul, factors)
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix
3492
3493
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The resampled image, whose first two dimensions are `new_shape` (or `image` itself if
    `num_rotations` is 0).
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # Rebuild the matrix from the *current* source shape: after the first rotation the image has
    # shape `new_shape`, which may differ from the original `image.shape[:2]`.
    matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image
3521
3522
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` with `PIL.Image.resize`, accepting the same parameters as `resize`.

  Only `boundary='natural'` is supported and `cval` is ignored.  The floating-point `array` may
  have 1 to 3 dimensions; a third dimension is treated as channels resized one at a time.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat a 1D signal as a single-row image.
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image  # type: ignore
  resampling = PIL.Image.Resampling
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  height, width = shape
  pil_size = width, height  # PIL expects (width, height).

  def resize_plane(plane: _NDArray) -> _NDArray:
    resized = PIL.Image.fromarray(plane).resize(pil_size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])
3566
3567
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` with OpenCV's `cv.resize`, accepting the same parameters as `resize`.

  Only `boundary='clamp'` is supported and `cval` is ignored.  `array` may have 1 to 3
  dimensions; a 1D input is processed as a single-row image.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return cv_resize(array[None], (1, *shape), filter=filter)[0]
  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  dsize = shape[::-1]  # OpenCV expects (width, height).
  resized = cv.resize(array, dsize, interpolation=filters[filter])
  lost_channel_axis = array.ndim == 3 and resized.ndim == 2
  if not lost_channel_axis:
    return resized
  # cv.resize() drops a trailing singleton channel dimension; restore it.
  assert array.shape[2] == 1
  return resized[..., None]
3604
3605
3606def scipy_ndimage_resize(
3607    array: _ArrayLike,
3608    /,
3609    shape: Iterable[int],
3610    *,
3611    filter: str,
3612    boundary: str = 'reflect',
3613    cval: float = 0.0,
3614    scale: float | Iterable[float] = 1.0,
3615    translate: float | Iterable[float] = 0.0,
3616) -> _NDArray:
3617  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
3618  array = np.asarray(array)
3619  shape = tuple(shape)
3620  assert 1 <= len(shape) <= array.ndim
3621  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
3622  if filter not in filters:
3623    raise ValueError(f'{filter=} not in {filters=}.')
3624  order = filters[filter]
3625  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
3626  if boundary not in boundaries:
3627    raise ValueError(f'{boundary=} not in {boundaries=}.')
3628  mode = boundaries[boundary]
3629  shape_all = shape + array.shape[len(shape) :]
3630  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
3631  coords[..., : len(shape)] = (
3632      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
3633  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
3634  coords = np.moveaxis(coords, -1, 0)
3635  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)
3636
3637
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` with `skimage.transform.resize`, accepting the same arguments as `resize`.

  The leading `len(shape)` dimensions are resized and any trailing dimensions keep their size.

  Raises:
    ValueError: If `filter` or `boundary` has no scikit-image counterpart.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim

  # Each supported filter corresponds to a spline interpolation order.
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  order = filters.get(filter)
  if order is None:
    raise ValueError(f'{filter=} not in {filters=}.')

  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  mode = boundaries.get(boundary)
  if mode is None:
    raise ValueError(f'{boundary=} not in {boundaries=}.')

  output_shape = (*shape, *array.shape[len(shape) :])
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  resized = skimage.transform.resize(  # type: ignore[no-untyped-call]
      array, output_shape, order=order, mode=mode, cval=cval, clip=False
  )
  return resized
3667
3668
3669_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER = {
3670    'impulse': 'nearest',
3671    'trapezoid': 'area',
3672    'triangle': 'bilinear',
3673    'mitchell': 'mitchellcubic',
3674    'cubic': 'bicubic',
3675    'lanczos3': 'lanczos3',
3676    'lanczos5': 'lanczos5',
3677    # GaussianFilter(0.5): 'gaussian',  # radius_4 > desired_radius_3.
3678}
3679
3680
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Resize `array` with `tf.image.resize`, accepting the same arguments as `resize`.

  Only the `'natural'` boundary rule is accepted and `cval` is ignored.  1D and 2D inputs are
  expanded to the channel-last image layout that `tf.image.resize` expects, then squeezed back.

  Raises:
    ValueError: If `filter` is not a key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER` or
      `boundary` is not `'natural'`.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval  # Not used by the 'natural' rule.
  tensor = tf.convert_to_tensor(array)
  ndim = len(tensor.shape)
  del array
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)

  if ndim == 1:
    # Resize a 1D signal as a single-row grayscale image.
    return tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)[0]
  if ndim == 2:
    # Append a singleton channel axis for tf.image.resize(), then remove it again.
    return tf_image_resize(tensor[..., None], shape, filter=filter, antialias=antialias)[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)
3713
3714
3715_TORCH_INTERPOLATE_MODE_FROM_FILTER = {
3716    'impulse': 'nearest-exact',  # ('nearest' matches buggy OpenCV's INTER_NEAREST)
3717    'trapezoid': 'area',
3718    'triangle': 'bilinear',
3719    'sharpcubic': 'bicubic',
3720}
3721
3722
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Resize `array` with `torch.nn.functional.interpolate`, accepting the same arguments as `resize`.

  Only the `'clamp'` boundary rule is accepted and `cval` is ignored.  The input is reshaped to
  the 4D (batch, channel, height, width) layout used by `interpolate` and restored afterwards.

  Raises:
    ValueError: If `filter` is not a key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER` or `boundary`
      is not `'clamp'`.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval  # Not used by the 'clamp' rule.
  tensor = torch.as_tensor(array)
  del array
  assert 1 <= tensor.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if tensor.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def interpolate_batch(batch: _TorchTensor, size: tuple[int, ...]) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batch, size, mode=mode, antialias=antialias)

  if tensor.ndim == 1:
    # A 1D signal becomes a 1-row, single-channel image in a batch of one.
    return interpolate_batch(tensor[None, None, None], (1, *shape))[0, 0, 0]
  if tensor.ndim == 2:
    return interpolate_batch(tensor[None, None], shape)[0, 0]
  # Move the trailing channel axis to the front (channel-first layout) and back afterwards.
  channels_first = tensor.moveaxis(2, 0)[None]
  return interpolate_batch(channels_first, shape)[0].moveaxis(0, 2)
3763
3764
3765def jax_image_resize(
3766    array: _ArrayLike,
3767    /,
3768    shape: Iterable[int],
3769    *,
3770    filter: str,
3771    boundary: str = 'natural',
3772    cval: float = 0.0,
3773    scale: float | Iterable[float] = 1.0,
3774    translate: float | Iterable[float] = 0.0,
3775) -> _JaxArray:
3776  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`."""
3777  import jax.image
3778  import jax.numpy as jnp
3779
3780  filters = 'triangle cubic lanczos3 lanczos5'.split()
3781  if filter not in filters:
3782    raise ValueError(f'{filter=} not in {filters=}.')
3783  if boundary != 'natural':
3784    raise ValueError(f"{boundary=} must equal 'natural'.")
3785  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
3786  # To be consistent, the parameter `cval` must be zero.
3787  if scale != 1.0 and cval != 0.0:
3788    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
3789  if translate != 0.0 and cval != 0.0:
3790    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
3791  array2 = jnp.asarray(array)
3792  del array
3793  shape = tuple(shape)
3794  assert len(shape) <= array2.ndim
3795  completed_shape = shape + (1,) * (array2.ndim - len(shape))
3796  spatial_dims = list(range(len(shape)))
3797  scale2 = np.broadcast_to(np.array(scale), len(shape))
3798  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
3799  translate2 = np.broadcast_to(np.array(translate), len(shape))
3800  translate2 = translate2 * np.array(shape)
3801  return jax.image.scale_and_translate(
3802      array2, completed_shape, spatial_dims, scale2, translate2, filter
3803  )
3804
3805
# Resizer implementations keyed by the library function each one invokes, starting with this
# module's own `resize`.  The first dotted component of each key names the library whose
# installation is checked by `_resizer_is_available`.
_CANDIDATE_RESIZERS = {
    'resampler.resize': resize,
    'PIL.Image.resize': pil_image_resize,
    'cv.resize': cv_resize,
    'scipy.ndimage.map_coordinates': scipy_ndimage_resize,
    'skimage.transform.resize': skimage_transform_resize,
    'tf.image.resize': tf_image_resize,
    'torch.nn.functional.interpolate': torch_nn_resize,
    'jax.image.scale_and_translate': jax_image_resize,
}
3816
3817
3818def _resizer_is_available(library_function: str) -> bool:
3819  """Return whether the resizer is available as an installed package."""
3820  top_name = library_function.split('.', 1)[0]
3821  module = {'PIL': 'Pillow', 'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
3822  return importlib.util.find_spec(module) is not None  # type: ignore[attr-defined]
3823
3824
# The subset of `_CANDIDATE_RESIZERS` whose underlying library is installed, in the same order.
_RESIZERS = {
    name: function
    for name, function in _CANDIDATE_RESIZERS.items()
    if _resizer_is_available(name)
}
3830
3831
def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`.

  The category names `'box_like'`, `'cubic_like'`, and `'high_quality'` are translated into a
  resizer-specific filter name, falling back to a per-category default for resizers without an
  override.  Any other `filter` is returned unchanged.
  """
  overrides: dict[Callable[..., Any], str]
  if filter == 'box_like':
    default = 'box'
    overrides = {
        cv_resize: 'trapezoid',
        skimage_transform_resize: 'box',
        tf_image_resize: 'trapezoid',
        torch_nn_resize: 'trapezoid',
    }
  elif filter == 'cubic_like':
    default = 'cubic'
    overrides = {
        cv_resize: 'sharpcubic',
        scipy_ndimage_resize: 'cardinal3',
        skimage_transform_resize: 'cardinal3',
        torch_nn_resize: 'sharpcubic',
    }
  elif filter == 'high_quality':
    default = 'lanczos5'
    overrides = {
        pil_image_resize: 'lanczos3',
        cv_resize: 'lanczos4',
        scipy_ndimage_resize: 'cardinal5',
        skimage_transform_resize: 'cardinal5',
        torch_nn_resize: 'sharpcubic',
    }
  else:
    return filter
  return overrides.get(resizer, default)
3859
3860
3861# For Emacs:
3862# Local Variables:
3863# fill-column: 100
3864# End:
using_numba = False
def is_available(arraylib: str) -> bool:
def is_available(arraylib: str) -> bool:
  """Return whether the array library (e.g. 'tensorflow') is available as an installed package.

  This is faster than trying to import the library.  For a dotted name (e.g. `'jax.numpy'`),
  `importlib.util.find_spec` imports the parent package; if that parent is not installed, the
  library is reported as unavailable instead of raising `ModuleNotFoundError`.
  """
  import importlib.util  # `import importlib` alone does not guarantee `importlib.util` is loaded.

  try:
    return importlib.util.find_spec(arraylib) is not None
  except ModuleNotFoundError:  # A parent package of a dotted name is missing.
    return False

Return whether the array library (e.g. 'tensorflow') is available as an installed package.

ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the array function parameter.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.

@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  In resampling operations, the grid-type may be specified separately as `src_gridtype` for the
  source domain and `dst_gridtype` for the destination domain.  Moreover, the grid-type may be
  specified per domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.

  Concrete subclasses must implement every abstract method below; this class itself cannot be
  instantiated.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples this grid-type requires."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of the domain measured in inter-sample spacings."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Convert sample indices in [0, size - 1] to domain coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Convert domain coordinates in [0.0, 1.0] to a continuous location x, where x == 0.0 is the
    first grid sample and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by reflecting about the boundary."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by wrapping around (periodically)."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using the reflect-clamp rule."""

Abstract base class for grid-types such as 'dual' and 'primal'.

In resampling operations, the grid-type may be specified separately as src_gridtype for the source domain and dst_gridtype for the destination domain. Moreover, the grid-type may be specified per domain dimension.

Examples:

resize(source, shape, gridtype='primal') # Sets both src and dst to be 'primal' grids.

resize(source, shape, src_gridtype=['dual', 'primal'], dst_gridtype='dual') # Source is 'dual' in dim0 and 'primal' in dim1.

GRIDTYPES: list[str] = ['dual', 'primal']

Shortcut names for the two predefined grid types (specified per dimension):

| gridtype | class |
|---|---|
| `'dual'` (default) | `DualGridtype()` |
| `'primal'` | `PrimalGridtype()` |

| | Dual | Primal |
|---|---|---|
| Sample positions in 2D and in 1D at different resolutions | | |
| Nesting of samples across resolutions | The sample positions do not nest. | The even samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Boundary:
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules.  These define the reconstruction over the source domain near and beyond
  the domain boundaries.  The rules may be specified separately for each domain dimension."""

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    return self.extend_samples.uses_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return `point` transformed by `coord_remap`, ahead of filter-kernel evaluation."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples.

    The (index, weight) pairs are first rewritten by `extend_samples`; the resulting weights are
    then modified in place by `override_reconstruction`.
    """
    extended_index, extended_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(extended_weight, point)
    return extended_index, extended_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight to zero to assign `cval`."""
    self.override_value(weight, point)

Domain boundary rules. These define the reconstruction over the source domain near and beyond the domain boundaries. The rules may be specified separately for each domain dimension.

BOUNDARIES: list[str] = ['reflect', 'wrap', 'tile', 'clamp', 'border', 'natural', 'linear_constant', 'quadratic_constant', 'reflect_clamp', 'constant', 'linear', 'quadratic']

Shortcut names for some predefined boundary rules (as defined by _DICT_BOUNDARIES):

| name | a.k.a. / comments |
|---|---|
| `'reflect'` | reflected, symm, symmetric, mirror, grid-mirror |
| `'wrap'` | periodic, repeat, grid-wrap |
| `'tile'` | like `'reflect'` within unit domain, then tile discontinuously |
| `'clamp'` | clamped, nearest, edge, clamp-to-edge, repeat last sample |
| `'border'` | grid-constant, use `cval` for samples outside unit domain |
| `'natural'` | renormalize using only interior samples, use `cval` outside domain |
| `'reflect_clamp'` | mirror-clamp-to-edge |
| `'constant'` | like `'reflect'` but replace by `cval` outside unit domain |
| `'linear'` | extrapolate from 2 last samples |
| `'quadratic'` | extrapolate from 3 last samples |
| `'linear_constant'` | like `'linear'` but replace by `cval` outside unit domain |
| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |

These boundary rules may be specified per dimension. See the source code for extensibility using the classes RemapCoordinates, ExtendSamples, and OverrideExteriorValue.

Boundary rules illustrated in 1D:

Boundary rules illustrated in 2D:

@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.

  Subclasses supply the kernel by implementing `__call__`.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the filter kernel at the locations `x`."""

Abstract base class for filter kernel functions.

Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support interval [-radius, radius]. (Some sites instead define kernels over the interval [0, N] where N = 2 * radius.)

Portions of this code are adapted from the C++ library in https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

See also https://hhoppe.com/proj/filtering/.

class ImpulseFilter(Filter):
class ImpulseFilter(Filter):
  """See https://en.wikipedia.org/wiki/Dirac_delta_function.

  The kernel is registered with a tiny placeholder radius (1e-20) and is discontinuous and not a
  partition of unity; it cannot be evaluated pointwise.
  """

  def __init__(self) -> None:
    super().__init__(
        name='impulse',
        radius=1e-20,
        continuous=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BoxFilter(Filter):
class BoxFilter(Filter):
  """See https://en.wikipedia.org/wiki/Box_function.

  The kernel function has value 1.0 over the half-open interval [-.5, .5).
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    # The half-open variant is selected; the symmetric alternative instead assigns 0.5 at
    # |x| == 0.5.
    use_asymmetric = True
    if not use_asymmetric:
      magnitude = np.abs(x)
      return np.where(magnitude < 0.5, 1.0, np.where(magnitude == 0.5, 0.5, 0.0))
    x = np.asarray(x)
    inside = (x >= -0.5) & (x < 0.5)
    return np.where(inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TrapezoidFilter(Filter):
class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  This filter is similar to the BoxFilter but with linearly sloped sides.  It has value 1.0
  in the interval abs(x) <= 1.0 - radius and decreases linearly to value 0.0 in the interval
  1.0 - radius <= abs(x) <= radius, always with value 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is not None and not 0.5 < radius <= 1.0:
      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
    if radius is None:
      # Placeholder: a zero radius that must be replaced before the kernel is evaluated.
      super().__init__(name='trapezoid', radius=0.0)
    else:
      super().__init__(name=f'trapezoid_{radius}', radius=radius)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Line through (1 - radius, 1.0), (0.5, 0.5), and (radius, 0.0), clipped to [0, 1].
    intercept = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (intercept - slope * x).clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TriangleFilter(Filter):
class TriangleFilter(Filter):
  """See https://en.wikipedia.org/wiki/Triangle_function.

  Also called the hat or tent function, this kernel yields piecewise-linear
  (or bilinear, or trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return np.clip(1.0 - np.abs(x), 0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CubicFilter(Filter):
class CubicFilter(Filter):
  """Family of cubic filters parameterized by two scalar parameters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    b, c = self.b, self.c
    # Coefficients of the cubic piece on [0, 1) (its linear term is zero).
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Coefficients of the cubic piece on [1, 2).
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # Horner evaluation.  (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
    # twice as slow; see also https://stackoverflow.com/questions/24065904)
    inner = ((f3 * x + f2) * x) * x + f0
    outer = ((g3 * x + g2) * x + g1) * x + g0
    return np.where(x < 1.0, inner, np.where(x < 2.0, outer, 0.0))

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CatmullRomFilter(CubicFilter):
class CatmullRomFilter(CubicFilter):
  """Cubic filter with cubic precision, i.e. the `CubicFilter` with b=0, c=0.5.

  Also known as the Keys filter; it is registered under the name `'cubic'`.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class MitchellFilter(CubicFilter):
class MitchellFilter(CubicFilter):
  """The Mitchell-Netravali cubic, i.e. the `CubicFilter` with b=1/3, c=1/3.

  See https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class SharpCubicFilter(CubicFilter):
class SharpCubicFilter(CubicFilter):
  """Cubic filter that is sharper than the Catmull-Rom filter: the `CubicFilter` with b=0, c=0.75.

  Used by some tools including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class LanczosFilter(Filter):
class LanczosFilter(Filter):
  """High-quality filter: sinc function modulated by a sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # The window is window[n] = sinc(2*n/N - 1) for 0 <= n <= N.  Substituting n = x + N/2
      # (so -N/2 <= x <= N/2) gives the zero-phase window w_0(x) = sinc(x / radius).
      envelope = _sinc(distance / radius)
      return np.where(distance < radius, _sinc(distance) * envelope, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GeneralizedHammingFilter(Filter):
class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  Raises:
    ValueError: If `a0` is not strictly between 0.0 and 1.0.

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    # Validate with an exception (like TrapezoidFilter and BsplineFilter) rather than `assert`,
    # which is stripped under `python -O`.
    if not 0.0 < a0 < 1.0:
      raise ValueError(f'Parameter {a0=} is outside the range (0.0, 1.0).')
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
    return np.where(x < self.radius, _sinc(x) * window, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class KaiserFilter(Filter):
class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level; must be
      non-negative.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `beta` is negative.
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    # Validate with an exception rather than `assert`, which is stripped under `python -O`.
    if not beta >= 0.0:
      raise ValueError(f'Parameter {beta=} must be non-negative.')
    super().__init__(
        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      # Kaiser window I0(beta * sqrt(1 - (x / radius)**2)) / I0(beta); the clip keeps the sqrt
      # argument non-negative for |x| > radius.
      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
      # NOTE(review): the 1e-6 slack presumably tolerates rounding at |x| == radius.
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BsplineFilter(Filter):
class BsplineFilter(Filter):
  """B-spline kernel of a non-negative degree, evaluated symmetrically about the origin.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    super().__init__(
        name=f'bspline{degree}',
        radius=(degree + 1) / 2,
        interpolating=degree <= 1,  # Only degrees 0 and 1 are interpolating.
    )
    # Integer knots 0, 1, ..., degree + 1 give a basis element supported on [0, degree + 1].
    knots = list(range(degree + 2))
    self._bspline = scipy.interpolate.BSpline.basis_element(knots)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    half_width = self.radius
    # The basis element is centered at `half_width`, so shift the symmetric argument onto it.
    return np.where(distance < half_width, self._bspline(distance + half_width), 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CardinalBsplineFilter(Filter):
class CardinalBsplineFilter(Filter):
  """Interpolating (cardinal) B-spline; degrees >= 2 rely on a digital pre- or post-filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    self.degree = degree
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    half_support = (degree + 1) / 2
    super().__init__(
        name=f'cardinal{degree}',
        radius=half_support,
        requires_digital_filter=degree >= 2,
        continuous=degree >= 1,
    )
    # Basis element on integer knots 0, 1, ..., degree + 1; it is centered at `half_support`.
    basis = scipy.interpolate.BSpline.basis_element(list(range(degree + 2)))

    @_cache_sampled_1d_function(xmin=-half_support, xmax=half_support, enable=sampled)
    def _kernel(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Shift the symmetric argument so that 0 maps onto the center of the basis element.
      return np.where(distance < half_support, basis(distance + half_support), 0.0)

    self._function = _kernel

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class OmomsFilter(Filter):
class OmomsFilter(Filter):
  """OMOMS interpolating kernel, used together with a digital pre- or post-filter.

  Args:
    degree: The polynomial degree of the filter segments; only 3 and 5 are supported.

  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
  interpolation of minimal support, 2001].
  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
  """

  def __init__(self, *, degree: int) -> None:
    if degree not in (3, 5):
      raise ValueError(f'Degree {degree} not supported.')
    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    t = np.abs(x)
    if self.degree == 3:
      # Cubic pieces (Horner form) on [0, 1) and [1, 2); zero beyond.
      inner = ((0.5 * t - 1.0) * t + 3 / 42) * t + 26 / 42
      outer = ((-7 / 42 * t + 1.0) * t - 85 / 42) * t + 58 / 42
      return np.where(t < 1.0, inner, np.where(t < 2.0, outer, 0.0))
    if self.degree == 5:
      # Quintic pieces (Horner form) on [0, 1), [1, 2), and [2, 3); zero beyond.
      seg0 = ((((-1 / 12 * t + 1 / 4) * t - 5 / 99) * t - 9 / 22) * t - 1 / 792) * t + 229 / 440
      seg1 = (
          (((1 / 24 * t - 3 / 8) * t + 505 / 396) * t - 83 / 44) * t + 1351 / 1584
      ) * t + 839 / 2640
      seg2 = (
          (((-1 / 120 * t + 1 / 8) * t - 299 / 396) * t + 101 / 44) * t - 27811 / 7920
      ) * t + 5707 / 2640
      return np.where(t < 1.0, seg0, np.where(t < 2.0, seg1, np.where(t < 3.0, seg2, 0.0)))
    raise ValueError(self.degree)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GaussianFilter(Filter):
class GaussianFilter(Filter):
  r"""Normalized Gaussian kernel; see https://en.wikipedia.org/wiki/Gaussian_function.

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  https://cadxfem.org/inf/ResamplingFilters.pdf
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    super().__init__(
        name=f'gaussian_{standard_deviation:.3f}',
        # Truncate the support at 8 standard deviations, rounded up to an integer.
        radius=np.ceil(8.0 * standard_deviation),
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    sigma = self.standard_deviation
    # Density exp(-(x / sigma)^2 / 2) / (sigma * sqrt(2 * pi)), which integrates to 1.
    density = np.exp(np.square(distance / sigma) / -2.0) / (np.sqrt(math.tau) * sigma)
    return np.where(distance < self.radius, density, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class NarrowBoxFilter(Filter):
class NarrowBoxFilter(Filter):
  """Narrow box kernel with a compact footprint, for visualizing grid sample locations.

  Args:
    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        continuous=False,
        unit_integral=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    values = np.asarray(x)
    half_width = self.radius
    # Value 1 on the half-open interval [-radius, radius), and 0 elsewhere.
    inside = (values >= -half_width) & (values < half_width)
    return np.where(inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

FILTERS: list[str] = ['impulse', 'box', 'trapezoid', 'triangle', 'cubic', 'sharpcubic', 'lanczos3', 'lanczos5', 'lanczos10', 'cardinal3', 'cardinal5', 'omoms3', 'omoms5', 'hamming3', 'kaiser3', 'gaussian', 'bspline3', 'mitchell', 'narrowbox']

Shortcut names for some predefined filter kernels (specified per dimension). The names expand to:

name Filter a.k.a. / comments
'impulse' ImpulseFilter() nearest
'box' BoxFilter() non-antialiased box, e.g. ImageMagick
'trapezoid' TrapezoidFilter() area antialiasing, e.g. cv.INTER_AREA
'triangle' TriangleFilter() linear (bilinear in 2D), spline order=1
'cubic' CatmullRomFilter() catmullrom, keys, bicubic
'sharpcubic' SharpCubicFilter() cv.INTER_CUBIC, torch 'bicubic'
'lanczos3' LanczosFilter(radius=3) support window [-3, 3]
'lanczos5' LanczosFilter(radius=5) [-5, 5]
'lanczos10' LanczosFilter(radius=10) [-10, 10]
'cardinal3' CardinalBsplineFilter(degree=3) spline interpolation, order=3, GF
'cardinal5' CardinalBsplineFilter(degree=5) spline interpolation, order=5, GF
'omoms3' OmomsFilter(degree=3) non-$C^1$, [-2, 2], GF
'omoms5' OmomsFilter(degree=5) non-$C^1$, [-3, 3], GF
'hamming3' GeneralizedHammingFilter(...) (radius=3, a0=25/46)
'kaiser3' KaiserFilter(radius=3.0, beta=7.12)
'gaussian' GaussianFilter() non-interpolating, default $\sigma=1.25/3$
'bspline3' BsplineFilter(degree=3) non-interpolating
'mitchell' MitchellFilter() mitchellcubic
'narrowbox' NarrowBoxFilter() for visualization of sample positions

The comment label GF denotes a generalized filter, formed as the composition of a finitely supported kernel and a discrete inverse convolution.

Some example filter kernels:


A more extensive set of filters is presented here in the notebook, together with visualizations and analyses of the filter properties. See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract interface for component transfer functions applied to sample values.

  Stored image/video content commonly uses a color component transfer function; see
  https://en.wikipedia.org/wiki/Gamma_correction.

  Implementations convert between integer types and the internal value range [0.0, 1.0].
  """

  name: str
  """Name of the component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Return the source samples decoded to floating-point, possibly via a nonlinear mapping.

    Uint source values are mapped to the range [0.0, 1.0].
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Return the float signal encoded as destination samples, possibly via a nonlinear mapping.

    Uint destination values are mapped from the range [0.0, 1.0].

    Non-integer destination types are not clipped to the range [0.0, 1.0]; if clipping is
    desired, apply it afterwards, e.g. `output.clip(0.0, 1.0)`.
    """

Abstract base class for transfer functions on sample values.

Image/video content is often stored using a color component transfer function. See https://en.wikipedia.org/wiki/Gamma_correction.

Converts between integer types and [0.0, 1.0] internal value range.

class IdentityGamma(Gamma):
class IdentityGamma(Gamma):
  """Identity component transfer function."""

  def __init__(self) -> None:
    super().__init__('identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Convert `array` to the inexact `dtype`; uint values are mapped to [0.0, 1.0]."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.inexact)
    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
      return _to_float_01(array, dtype)
    return _arr_astype(array, dtype)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Convert float `array` to the numeric `dtype`.

    For a uint `dtype`, values are clipped to [0.0, 1.0] and mapped to the uint range.  For a
    signed-integer `dtype`, values are rounded to the nearest integer (halves round up).  Other
    types are converted without clipping.
    """
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.number)
    if np.issubdtype(dtype, np.unsignedinteger):
      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
    if np.issubdtype(dtype, np.integer):
      # Casting to an integer type truncates toward zero, so casting `array + 0.5` alone would
      # mis-round negative values (e.g., -2.0 -> -1).  The floor division `// 1.0` computes
      # floor(array + 0.5) (numpy, tensorflow, jax, and torch >= 1.13 all floor for `//`).
      return _arr_astype((array + 0.5) // 1.0, dtype)
    return _arr_astype(array, dtype)

Gamma(name: 'str')

GAMMAS: list[str] = ['identity', 'power2', 'power22', 'srgb']

Shortcut names for some predefined gamma-correction schemes:

name Gamma Decoding function
(linear space from stored value)
Encoding function
(stored value from linear space)
'identity' IdentityGamma() $l = e$ $e = l$
'power2' PowerGamma(2.0) $l = e^{2.0}$ $e = l^{1/2.0}$
'power22' PowerGamma(2.2) $l = e^{2.2}$ $e = l^{1/2.2}$
'srgb' SrgbGamma() $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ $e = l^{1/2.4} * 1.055 - 0.055$

See the source code for extensibility.

def resize( array: Array, /, shape: Iterable[int], *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, precision: DTypeLike = None, dtype: DTypeLike = None, dim_order: Iterable[int] | None = None, num_threads: int | Literal['auto'] = 'auto') -> Array:
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.  Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.  If the operation is an identity (same shape, gridtypes, gamma, and
      dtype, no scale/translate, and interpolating filters), `array` itself is returned.

  Raises:
    ValueError: If `array` is not numeric, if `shape` is incompatible with `array.shape`, or if
      `dim_order` is not a permutation of `range(len(shape))`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # Return the input unchanged only when every dimension would be skipped by the loop below and
  # no value conversion is requested (same gamma and same output dtype).  With unchanged sizes and
  # gridtypes, `get_is_minification()` is False, so it is `filter` (not `prefilter`) that applies.
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in filter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)
  cval = _arr_numpy(src_gamma2.decode(cval, precision))

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    # `typing.cast` (rather than a bare `cast`) relies only on the module-level `import typing`.
    array = _downsample_in_2d_using_box_filter(array, typing.cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    # A dimension may be skipped only if resampling it is the identity: same size, same sample
    # placement (a dual and a primal grid of equal size have different sample positions), no
    # scale/translate, and an interpolating reconstruction filter.
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and src_gridtype2[dim] == dst_gridtype2[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Flatten so that the resized dimension is first and all other dimensions are merged.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array

Resample array (a grid of sample values) onto a grid with resolution shape.

The source array is any object recognized by ARRAYLIBS. It is interpreted as a grid with len(shape) domain coordinate dimensions, where each grid sample value has shape array.shape[len(shape):].

Some examples:

  • A grayscale image has array.shape = height, width and resizing it with len(shape) == 2 produces a new image of scalar values.
  • An RGB image has array.shape = height, width, 3 and resizing it with len(shape) == 2 produces a new image of RGB values.
  • A 3D grid of 3x3 Jacobians has array.shape = Z, Y, X, 3, 3 and resizing it with len(shape) == 3 produces a new 3D grid of Jacobians.

This function also allows scaling and translation from the source domain to the output domain through the parameters scale and translate. For more general transforms, see resample.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. Its first len(shape) dimensions are the domain coordinate dimensions. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto array.shape[len(shape):]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during upsampling (i.e., magnification).
  • prefilter: The prefilter kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter. The default 'lanczos3' is good for natural images. For vector graphics images, 'trapezoid' is better because it avoids ringing artifacts.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • scale: Scaling factor applied to each dimension of the source domain when it is mapped onto the destination domain.
  • translate: Offset applied to each dimension of the scaled source domain when it is mapped onto the destination domain.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • dim_order: Override the automatically selected order in which the grid dimensions are resized. Must contain a permutation of range(len(shape)).
  • num_threads: Used to determine multithread parallelism if array is from numpy. If set to 'auto', it is selected automatically. Otherwise, it must be a positive integer.
Returns:

An array of the same class as the source array, with shape shape + array.shape[len(shape):] and data type dtype.

Example of image upsampling:

>>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
>>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

Example of image downsampling:

>>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
>>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
>>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
>>> downsampled = resize(array, (24, 48))

Unit test:

>>> result = resize([1.0, 4.0, 5.0], shape=(4,))
>>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
def jaxjit_resize(array: Array, /, *args: Any, **kwargs: Any) -> Array:
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Like `resize`, but dispatched through a Jax-jitted version of the resize function."""
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable

Compute resize but with resize function jitted using Jax.

def uniform_resize( array: Array, /, shape: Iterable[int], *, object_fit: Literal['contain', 'cover'] = 'contain', gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'natural', scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, **kwargs: Any) -> Array:
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit: contain` (or `object-fit: cover` when
  `object_fit='cover'`).  The parameter `boundary` (whose default is changed to `'natural'`)
  determines the values assigned outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified (other than with its default value of all ones).
    translate: Parameter may not be specified (other than with its default value of all zeros).
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is given a non-default value, if `object_fit` is not
      `'contain'` or `'cover'`, or if `shape` is empty or has more dimensions than `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  # `scale` and `translate` are derived below.  Comparing via `np.asarray`/`np.any` handles both
  # scalars and per-dimension sequences/arrays (a bare `!=` on a numpy array cannot be used in `or`).
  if np.any(np.asarray(scale) != 1.0) or np.any(np.asarray(translate) != 0.0):
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if object_fit not in ('contain', 'cover'):
    raise ValueError(f"Parameter {object_fit=} must be either 'contain' or 'cover'.")
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension ratio of destination extent to source extent (in grid-sample units).
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # 'contain' uses the most restrictive ratio so the whole source fits; 'cover' the least.
  scale0 = min(raw_scales) if object_fit == 'contain' else max(raw_scales)
  scale2 = scale0 / np.array(raw_scales)
  translate = (1.0 - scale2) / 2  # Center the uniformly scaled source within the destination.
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)

Resample array onto a grid with resolution shape but with uniform scaling.

Calls function resize with scale and translate set such that the aspect ratio of array is preserved. The effect is similar to CSS object-fit: contain (or object-fit: cover when object_fit='cover'). The parameter boundary (whose default is changed to 'natural') determines the values assigned outside the source domain.

Arguments:
  • array: Regular grid of source sample values.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • object_fit: Like CSS object-fit. If 'contain', array is resized uniformly to fit within shape. If 'cover', array is resized to fully cover shape.
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The default is 'natural', which assigns cval to output points that map outside the source unit domain.
  • scale: Parameter may not be specified.
  • translate: Parameter may not be specified.
  • **kwargs: Additional parameters for resize function (including cval).
Returns:

An array with shape shape + array.shape[len(shape):].

>>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
array([[0., 1., 1., 0.],
       [0., 1., 1., 0.]])
>>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
       [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])
>>> a = np.arange(6.0).reshape(2, 3)
>>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
array([[0.5, 1.5],
       [3.5, 4.5]])
def resample( array: Array, /, coords: ArrayLike, *, gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual', boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, jacobian: ArrayLike | None = None, precision: DTypeLike = None, dtype: DTypeLike = None, max_block_size: int = 40000, debug: bool = False) -> Array:
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
      (In the current implementation, which does no prefiltering, its shape is only validated.)
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    # Validates the shape only; the Jacobian is not otherwise consumed (no prefiltering yet).
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # Gamma decoding and digital prefiltering are done once at the top level, not per block.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      # `jacobian` is intentionally not forwarded: it was broadcast to the full `resampled_shape`
      # above and cannot be broadcast onto the smaller `coord_block` shape (the recursive call
      # would raise).  It was already validated and is not otherwise consumed.
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=None,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # Starting in numpy 2.0, np.einsum() outputs np.float64 even with all np.float32 inputs;
  # GPT: "aligns np.einsum with other functions where intermediate calculations use higher
  # precision (np.float64) regardless of input type when floating-point arithmetic is involved."
  # We could explicitly add the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The cval weight is the complement of the total separable weight, i.e. 1 minus the product of
    # the per-dimension weight sums.  There is one weight array per *grid* dimension, so the
    # product must span `grid_ndim` (which may differ from `resampled_ndim`, e.g. when sampling
    # along a line, at a single point, or through a color map).
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array

Interpolate array (a grid of samples) at specified unit-domain coordinates coords.

The last dimension of coords contains unit-domain coordinates at which to interpolate the domain grid samples in array.

The number of coordinates (coords.shape[-1]) determines how to interpret array: its first coords.shape[-1] dimensions define the grid, and the remaining dimensions describe each grid sample (e.g., scalar, vector, tensor).

Concretely, the grid has shape array.shape[:coords.shape[-1]] and each grid sample has shape array.shape[coords.shape[-1]:].

Examples include:

  • Resample a grayscale image with array.shape = height, width onto a new grayscale image with new.shape = height2, width2 by using coords.shape = height2, width2, 2.

  • Resample an RGB image with array.shape = height, width, 3 onto a new RGB image with new.shape = height2, width2, 3 by using coords.shape = height2, width2, 2.

  • Sample an RGB image at num 2D points along a line segment by using coords.shape = num, 2.

  • Sample an RGB image at a single 2D point by using coords.shape = (2,).

  • Sample a 3D grid of 3x3 Jacobians with array.shape = nz, ny, nx, 3, 3 along a 2D plane by using coords.shape = height, width, 3.

  • Map a grayscale image through a color map by using array.shape = 256, 3 and coords.shape = height, width.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The coordinate dimensions appear first, and each grid sample may have an arbitrary shape. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • coords: Grid of points at which to resample array. The point coordinates are in the last dimension of coords. The domain associated with the source grid is a unit hypercube, i.e. with a range [0, 1] on each coordinate dimension. The output grid has shape coords.shape[:-1] and each of its grid samples has shape array.shape[coords.shape[-1]:].
  • gridtype: Placement of the samples in the source domain grid for each dimension, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual'.
  • boundary: The reconstruction boundary rule for each dimension in coords.shape[-1], specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto the shape array.shape[coords.shape[-1]:]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in coords.shape[-1], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in coords.shape[:-1], specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • jacobian: Optional array, which must be broadcastable onto the shape coords.shape[:-1] + (coords.shape[-1], coords.shape[-1]), storing for each point in the output grid the Jacobian matrix of the map from the unit output domain to the unit source domain. If omitted, it is estimated by computing finite differences on coords.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype, coords.dtype, and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • max_block_size: If nonzero, maximum number of grid points in coords before the resampling evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
  • debug: Show internal information.
Returns:

A new sample grid of shape coords.shape[:-1], represented as an array of shape coords.shape[:-1] + array.shape[coords.shape[-1]:], of the same array library type as the source array.

Example of resample operation:

For reference, the identity resampling for a scalar-valued grid with the default grid-type 'dual' is:

>>> array = np.random.default_rng(0).random((5, 7, 3))
>>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
>>> new_array = resample(array, coords)
>>> assert np.allclose(new_array, array)

It is more efficient to use the function resize for the special case where the coords are obtained as simple scaling and translation of a new regular grid over the source domain:

>>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
>>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
>>> coords = (coords - translate) / scale
>>> resampled = resample(array, coords)
>>> resized = resize(array, new_shape, scale=scale, translate=translate)
>>> assert np.allclose(resampled, resized)
def resample_affine( array: Array, /, shape: Iterable[int], matrix: ArrayLike, *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, precision: DTypeLike = None, dtype: DTypeLike = None, **kwargs: Any) -> Array:
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  Every destination sample position `p` (a unit-domain point of a grid with `len(shape)`
  dimensions) is mapped into the source unit domain by `matrix`.  The map is linear,
    `source_point = matrix @ p`,
  when `matrix` has `len(shape)` columns, or affine,
    `source_point = matrix @ (p, 1.0)`,
  when it has one extra column holding the offset vector.  The source is then sampled at the
  mapped positions by `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its leading `matrix.shape[0]` dimensions are the grid
      dimensions; any remaining dimensions hold the components of each sample value, which are
      all linearly interpolated.
    shape: Dimensions of the desired destination grid.  The destination may have a different
      number of grid dimensions than the source.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (with `len(shape)` coordinates) into unit-domain source points (with `matrix.shape[0]`
      coordinates).  When it has `len(shape) + 1` columns, the last column is the affine offset
      (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      given as a name in `GRIDTYPES` or a `Gridtype` instance.  If `gridtype`, `src_gridtype`,
      and `dst_gridtype` are all `None`, it defaults to `'dual'`.
    src_gridtype: Per-dimension placement of samples in the source domain grid.  Cannot be set
      together with `gridtype`.
    dst_gridtype: Per-dimension placement of samples in the output domain grid.  Cannot be set
      together with `gridtype`.
    filter: Reconstruction kernel for each of the `matrix.shape[0]` source dimensions, given as a
      filter name in `FILTERS` or a `Filter` instance.
    prefilter: Prefilter kernel for each of the `len(shape)` destination dimensions, given as a
      filter name in `FILTERS` or a `Filter` instance; it is used during downsampling (i.e.,
      minification).  When `None`, the value of `filter` is reused.
    precision: Inexact precision of intermediate computations.  When `None`, it is derived from
      `array.dtype` and `dtype`.
    dtype: Desired data type of the output array; `None` means `array.dtype`.  For a uint type,
      intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
    **kwargs: Additional parameters forwarded to `resample`.

  Returns:
    An array of the same class as the source `array`, representing a grid with the specified
    `shape` whose values are resampled from `array`.  Its shape is therefore
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or does not
      have either `len(shape)` or `len(shape) + 1` columns.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  is_affine = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not is_affine and num_columns != dst_ndim:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter  # By default, minification reuses the reconstruction kernel.
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter  # Guard against accidental reuse below.
  if dtype is None:
    dtype = _arr_dtype(array)
  else:
    dtype = np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples along each axis, expanded to an 'ij' mesh.
  axis_positions = [
      dst_gridtype2[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Map all destination positions into the source unit domain, moving the coordinate axis last.
  linear_part = matrix[:, :-1] if is_affine else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_position, 1), 0, -1)
  if is_affine:
    coords += matrix[:, -1]  # In-place add, so `coords` keeps the dtype produced by tensordot.

  # TODO: Based on the source grid shape `array.shape[:src_ndim]`, `shape`, `linear_part`, and
  # the prefilter, determine a convolution prefilter and apply it to bandlimit `array`, using
  # boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

Resample a source array using an affinely transformed grid of given shape.

The matrix transformation can be linear, source_point = matrix @ destination_point, or it can be affine where the last matrix column is an offset vector, source_point = matrix @ (destination_point, 1.0).

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The number of grid dimensions is determined from matrix.shape[0]; the remaining dimensions are for each sample value and are all linearly interpolated.
  • shape: Dimensions of the desired destination grid. The number of destination grid dimensions may be different from that of the source grid.
  • matrix: 2D array for a linear or affine transform from unit-domain destination points (in a space with len(shape) dimensions) into unit-domain source points (in a space with matrix.shape[0] dimensions). If the matrix has len(shape) + 1 columns, the last column is the affine offset (i.e., translation).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • filter: The reconstruction kernel for each dimension in matrix.shape[0], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in len(shape), specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • **kwargs: Additional parameters for resample function.
Returns:

An array of the same class as the source array, representing a grid with specified shape, where each grid value is resampled from array. Thus the shape of the returned array is shape + array.shape[matrix.shape[0]:].

def rotation_about_center_in_2d( src_shape: ArrayLike, /, angle: float, *, new_shape: ArrayLike | None = None, scale: float = 1.0) -> np.ndarray:
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping the destination unit domain into the source unit domain.

  The rotation (and scaling) is about the domain center `(0.5, 0.5)`, and the returned matrix
  accounts for the possibly non-square domain shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 homogeneous matrix whose last row is `[0, 0, 1]`; its first two rows form an affine
    transform usable as the `matrix` argument of `resample_affine`.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def translate(offset: _NDArray) -> _NDArray:
    transform = np.eye(len(offset) + 1)
    transform[:-1, -1] = offset
    return transform

  def stretch(factors: _NDArray) -> _NDArray:
    return np.diag(tuple(factors) + (1.0,))

  c, s = np.cos(angle), np.sin(angle)
  rotate = np.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])
  center = np.array([0.5, 0.5])
  # Applied right-to-left to a destination point: move the center to the origin, rescale by
  # `scale * new_shape / min(new_shape)`, rotate, rescale by `min(src_shape) / src_shape`, and
  # move the origin back to the center.  `reduce` folds left, matching the `a @ b @ c` chain.
  factors = [
      translate(center),
      stretch(min(src_shape) / src_shape),
      rotate,
      stretch(scale * new_shape / min(new_shape)),
      translate(-center),
  ]
  matrix = functools.reduce(np.matmul, factors)
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

Return the 3x3 matrix mapping the destination unit domain into the source unit domain.

The returned matrix accounts for the possibly non-square domain shapes.

Arguments:
  • src_shape: Resolution (ny, nx) of the source domain grid.
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the destination domain grid; it defaults to src_shape.
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
def rotate_image_about_center( image: np.ndarray, /, angle: float, *, new_shape: ArrayLike | None = None, scale: float = 1.0, num_rotations: int = 1, **kwargs: Any) -> np.ndarray:
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain
      (applied once per rotation).
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.  If zero, `image` itself is returned.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The resampled image, whose leading two dimensions are `new_shape` and whose remaining
    dimensions are those of `image.shape[2:]`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # Recompute the matrix from the current source shape: after the first resampling the image
    # has shape `new_shape`, which may differ from the original `image.shape[:2]`, and the
    # matrix's aspect-ratio correction depends on the source shape.
    matrix = rotation_about_center_in_2d(
        image.shape[:2], angle, new_shape=new_shape, scale=scale
    )
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image

Return a copy of image rotated about its center.

Arguments:
  • image: Source grid samples; the first two dimensions are spatial (ny, nx).
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the output grid; it defaults to image.shape[:2].
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
  • num_rotations: Number of rotations (each by angle). Successive resamplings are useful in analyzing the filtering quality.
  • **kwargs: Additional parameters for resample_affine.
def pil_image_resize(array: ArrayLike, /, shape: Iterable[int], *, filter: str) -> np.ndarray:
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Floating-point samples with 1 to 3 dimensions; a third dimension holds channels,
      which are resized independently.
    shape: Output resolution; it has 1 entry for a 1D `array` and 2 entries `(ny, nx)` otherwise.
    filter: One of 'impulse', 'box', 'triangle', 'hamming1', 'cubic', or 'lanczos3'.
    boundary: Must be 'natural'.
    cval: Ignored; accepted for signature compatibility with `resize`.

  Returns:
    The resized array, with the same dtype as `array`.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 1 if array.ndim < 2 else 2)
  if array.ndim == 1:
    # Resize a 1D signal as a single-row image.
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image  # type: ignore
  resampling = PIL.Image.Resampling
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  ny, nx = shape

  def resize_plane(plane: _NDArray) -> _NDArray:
    # PIL takes the output size as (width, height).
    resized = PIL.Image.fromarray(plane).resize((nx, ny), resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  return np.dstack([resize_plane(array[..., c]) for c in range(array.shape[-1])])

Invoke PIL.Image.resize using the same parameters as resize.

def cv_resize(array: ArrayLike, /, shape: Iterable[int], *, filter: str) -> np.ndarray:
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  Args:
    array: Samples with 1 to 3 dimensions; a third dimension holds channels.
    shape: Output resolution; it has 1 entry for a 1D `array` and 2 entries `(ny, nx)` otherwise.
    filter: One of 'impulse', 'triangle', 'trapezoid', 'sharpcubic', or 'lanczos4'.
    boundary: Must be 'clamp'.
    cval: Ignored; accepted for signature compatibility with `resize`.

  Returns:
    The resized array; a trailing channel dimension of size 1 is preserved.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 1 if array.ndim < 2 else 2)
  if array.ndim == 1:
    # Resize a 1D signal as a single-row image.
    return cv_resize(array[None], (1, *shape), filter=filter)[0]
  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  ny, nx = shape
  # OpenCV takes the output size as (width, height).
  result = cv.resize(array, (nx, ny), interpolation=filters[filter])
  channel_axis_dropped = array.ndim == 3 and result.ndim == 2
  if not channel_axis_dropped:
    return result
  assert array.shape[2] == 1
  return result[..., None]  # Restore the single-channel axis that cv.resize() drops.

Invoke cv.resize using the same parameters as resize.

def scipy_ndimage_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
3607def scipy_ndimage_resize(
3608    array: _ArrayLike,
3609    /,
3610    shape: Iterable[int],
3611    *,
3612    filter: str,
3613    boundary: str = 'reflect',
3614    cval: float = 0.0,
3615    scale: float | Iterable[float] = 1.0,
3616    translate: float | Iterable[float] = 0.0,
3617) -> _NDArray:
3618  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
3619  array = np.asarray(array)
3620  shape = tuple(shape)
3621  assert 1 <= len(shape) <= array.ndim
3622  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
3623  if filter not in filters:
3624    raise ValueError(f'{filter=} not in {filters=}.')
3625  order = filters[filter]
3626  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
3627  if boundary not in boundaries:
3628    raise ValueError(f'{boundary=} not in {boundaries=}.')
3629  mode = boundaries[boundary]
3630  shape_all = shape + array.shape[len(shape) :]
3631  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
3632  coords[..., : len(shape)] = (
3633      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
3634  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
3635  coords = np.moveaxis(coords, -1, 0)
3636  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)

Invoke scipy.ndimage.map_coordinates using the same parameters as resize.

def skimage_transform_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `skimage.transform.resize` using the same parameters as `resize`.

  Args:
    array: Source samples; its leading `len(shape)` dimensions are resized.
    shape: Output resolution of the leading dimensions.
    filter: 'box' (spline order 0), 'triangle' (order 1), or 'cardinal2' .. 'cardinal5'.
    boundary: One of 'reflect', 'wrap', 'clamp', or 'border'.
    cval: Value used outside the domain when `boundary='border'`.

  Returns:
    An array of shape `shape + array.shape[len(shape):]`.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1, **{f'cardinal{i}': i for i in range(2, 6)}}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  output_shape = shape + array.shape[len(shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      array,
      output_shape,
      order=filters[filter],
      mode=boundaries[boundary],
      cval=cval,
      clip=False,
  )  # type: ignore[no-untyped-call]

Invoke skimage.transform.resize using the same parameters as resize.

def tf_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, antialias: bool = True) -> TensorflowTensor:
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Invoke `tf.image.resize` using the same parameters as `resize`.

  Args:
    array: Samples with 1 to 3 dimensions; a third dimension holds channels.
    shape: Output resolution; it has 1 entry for a 1D `array` and 2 entries `(ny, nx)` otherwise.
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'.
    cval: Ignored; accepted for signature compatibility with `resize`.
    antialias: Forwarded to `tf.image.resize`.

  Returns:
    The resized tensor, with the same number of dimensions as `array`.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  ndim = len(tensor.shape)
  del array
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 1 if ndim < 2 else 2)
  if ndim == 1:
    # Resize a 1D signal as a single-row image.
    return tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)[0]
  if ndim == 2:
    # tf.image.resize expects a trailing channel dimension.
    return tf_image_resize(tensor[..., None], shape, filter=filter, antialias=antialias)[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)

Invoke tf.image.resize using the same parameters as resize.

def torch_nn_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, antialias: bool = False) -> TorchTensor:
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`.

  Args:
    array: Samples with 1 to 3 dimensions; a third dimension holds channels.
    shape: Output resolution; it has 1 entry for a 1D `array` and 2 entries `(ny, nx)` otherwise.
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'.
    cval: Ignored; accepted for signature compatibility with `resize`.
    antialias: Forwarded to `torch.nn.functional.interpolate`.

  Returns:
    The resized tensor, laid out like `array` (channels last for 3D input).
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  tensor = torch.as_tensor(array)
  del array
  assert 1 <= tensor.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 1 if tensor.ndim < 2 else 2)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def interpolate(batch: _TorchTensor, size: tuple[int, ...]) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batch, size, mode=mode, antialias=antialias)

  # `interpolate` expects a (batch, channel, *spatial) layout.
  if tensor.ndim == 1:
    return interpolate(tensor[None, None, None], (1, *shape))[0, 0, 0]
  if tensor.ndim == 2:
    return interpolate(tensor[None, None], shape)[0, 0]
  channels_first = tensor.moveaxis(2, 0)[None]
  return interpolate(channels_first, shape)[0].moveaxis(0, 2)

Invoke torch.nn.functional.interpolate using the same parameters as resize.

def jax_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> JaxArray:
def jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`.

  Args:
    array: Source samples; its leading `len(shape)` dimensions are resampled and any remaining
      dimensions are kept as-is.
    shape: Output resolution of the leading dimensions.
    filter: One of 'triangle', 'cubic', 'lanczos3', or 'lanczos5'.
    boundary: Must be 'natural'.
    cval: Must be zero whenever `scale` is not all ones or `translate` is not all zeros.
    scale: Unit-domain scale factor(s), as in `resize`; a scalar or one value per dimension
      (any iterable, including `np.ndarray`).
    translate: Unit-domain translation(s), as in `resize`; a scalar or one value per dimension.

  Returns:
    A jax array of shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: For an unsupported `filter` or `boundary`, or a nonzero `cval` combined with a
      non-unity `scale` or nonzero `translate`.
  """
  import jax.image
  import jax.numpy as jnp

  def to_array(value: float | Iterable[float]) -> _NDArray:
    # Materialize one-shot iterables (e.g., generators); scalars and arrays pass through.
    if isinstance(value, Iterable) and not isinstance(value, np.ndarray):
      value = tuple(value)
    return np.asarray(value, dtype=np.float64)

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  scale_array = to_array(scale)
  translate_array = to_array(translate)
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.  The comparisons are elementwise so that
  # array-valued parameters do not raise an ambiguous-truth-value error.
  if np.any(scale_array != 1.0) and cval != 0.0:
    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
  if np.any(translate_array != 0.0) and cval != 0.0:
    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  assert len(shape) <= array2.ndim
  ndim = len(shape)
  # Only the spatial entries of the output shape are resized; the others keep their input sizes.
  completed_shape = shape + tuple(array2.shape[ndim:])
  spatial_dims = list(range(ndim))
  in_sizes = np.array(array2.shape[:ndim])
  out_sizes = np.array(shape)
  # Convert the unit-domain `scale`/`translate` of `resize` into jax's pixel-based parameters.
  scale2 = np.broadcast_to(scale_array, ndim) / in_sizes * out_sizes
  translate2 = np.broadcast_to(translate_array, ndim) * out_sizes
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )

Invoke jax.image.scale_and_translate using the same parameters as resize.