resampler

resampler: fast differentiable resizing and warping of arbitrary grids.

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   2
   3.. include:: ../README.md
   4"""
   5
   6__docformat__ = 'google'
   7__version__ = '1.0.0'
   8__version_info__ = tuple(int(num) for num in __version__.split('.'))
   9
  10from collections.abc import Callable, Iterable, Sequence
  11import abc
  12import dataclasses
  13import functools
  14import importlib
  15import itertools
  16import math
  17import os
  18import sys
  19import types
  20import typing
  21from typing import Any, Generic, Literal, Union
  22
  23import numpy as np
  24import numpy.typing
  25import scipy.interpolate
  26import scipy.linalg
  27import scipy.ndimage
  28import scipy.sparse
  29import scipy.sparse.linalg
  30
  31
  32def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  33  """Return function decorated with no-operation; invocable with or without args."""
  34  if len(args) != 1 or not callable(args[0]) or kwargs:
  35    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  36  func: Callable[..., Any] = args[0]
  37  return func
  38
  39
try:
  import numba
except ModuleNotFoundError:
  # numba is optional.  Register a stand-in module whose `njit` is a no-op decorator, so that
  # the `@numba.njit(...)` decorations in this file still evaluate.
  numba = sys.modules['numba'] = types.ModuleType('numba')
  numba.njit = _noop_decorator
# The stand-in module defines only `njit`, so the presence of `jit` indicates the real package.
using_numba = hasattr(numba, 'jit')
  46
  47if TYPE_CHECKING:
  48  import jax.numpy
  49  import tensorflow as tf
  50  import torch
  51
  52  _DType = np.dtype[Any]  # (Requires Python 3.9 or TYPE_CHECKING.)
  53  _NDArray = numpy.NDArray[Any]
  54  _DTypeLike = numpy.DTypeLike
  55  _ArrayLike = numpy.ArrayLike
  56  _TensorflowTensor: TypeAlias = tf.Tensor
  57  _TorchTensor: TypeAlias = torch.Tensor
  58  _JaxArray: TypeAlias = jax.numpy.ndarray
  59
  60else:
  61  # Typically, create named types for use in the `pdoc` documentation.
  62  # But here, these are superseded by the declarations in __init__.pyi!
  63  _DType = Any
  64  _NDArray = Any
  65  _DTypeLike = Any
  66  _ArrayLike = Any
  67  _TensorflowTensor = Any
  68  _TorchTensor = Any
  69  _JaxArray = Any
  70
  71_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  72_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  73
  74
  75def _check_eq(a: Any, b: Any, /) -> None:
  76  """If the two values or arrays are not equal, raise an exception with a useful message."""
  77  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  78  if not are_equal:
  79    raise AssertionError(f'{a!r} == {b!r}')
  80
  81
  82def _real_precision(dtype: _DTypeLike, /) -> _DType:
  83  """Return the type of the real part of a complex number."""
  84  return np.array([], dtype).real.dtype
  85
  86
  87def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  88  """Return a complex type to represent a non-complex type."""
  89  return np.result_type(dtype, np.complex64)
  90
  91
  92def _get_precision(
  93    precision: _DTypeLike, dtypes: list[_DType], weight_dtypes: list[_DType], /
  94) -> _DType:
  95  """Return dtype based on desired precision or on data and weight types."""
  96  precision2 = np.dtype(
  97      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
  98  )
  99  if not np.issubdtype(precision2, np.inexact):
 100    raise ValueError(f'Precision {precision2} is not floating or complex.')
 101  check_complex = [precision2, *dtypes]
 102  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 103  if len(set(is_complex)) != 1:
 104    s_types = ','.join(str(dtype) for dtype in check_complex)
 105    raise ValueError(f'Types {s_types} must be all real or all complex.')
 106  return precision2
 107
 108
 109def _sinc(x: _ArrayLike, /) -> _NDArray:
 110  """Return the value `np.sinc(x)` but improved to:
 111  (1) ignore underflow that occurs at 0.0 for np.float32, and
 112  (2) output exact zero for integer input values.
 113
 114  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 115  array([0., 0., 0., 1.], dtype=float32)
 116
 117  >>> _sinc(np.array([-3, -2, -1, 0]))
 118  array([0., 0., 0., 1.])
 119
 120  >>> _sinc(0)
 121  1.0
 122  """
 123  x = np.asarray(x)
 124  x_is_scalar = x.ndim == 0
 125  with np.errstate(under='ignore'):
 126    result = np.sinc(np.atleast_1d(x))
 127    result[x == np.floor(x)] = 0.0
 128    result[x == 0] = 1.0
 129    return result.item() if x_is_scalar else result
 130
 131
 132def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 133  """Return True if the sparse matrix is symmetric."""
 134  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 135  return norm <= tol
 136
 137
 138def _cache_sampled_1d_function(
 139    xmin: float,
 140    xmax: float,
 141    *,
 142    num_samples: int = 3_600,
 143    enable: bool = True,
 144) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 145  """Function decorator to linearly interpolate cached function values."""
 146  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 147
 148  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 149    if not enable:
 150      return func
 151
 152    dx = (xmax - xmin) / num_samples
 153    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 154    samples_func = func(x)
 155    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 156
 157    @functools.wraps(func)
 158    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 159      x = np.asarray(x)
 160      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 161      index = index_float.astype(np.int64)
 162      frac = np.subtract(index_float, index, dtype=np.float32)
 163      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 164
 165    return interpolate_using_cached_samples
 166
 167  return wrap_it
 168
 169
 170class _DownsampleIn2dUsingBoxFilter:
 171  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 172
 173  def __init__(self) -> None:
 174    # Downsampling function for params (dtype, block_height, block_width, ch).
 175    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 176
 177  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 178    assert using_numba
 179    assert array.ndim in (2, 3), array.ndim
 180    _check_eq(len(shape), 2)
 181    dtype = array.dtype
 182    a = array[..., None] if array.ndim == 2 else array
 183    height, width, ch = a.shape
 184    new_height, new_width = shape
 185    if height % new_height != 0 or width % new_width != 0:
 186      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 187    block_height, block_width = height // new_height, width // new_width
 188
 189    def func(array: _NDArray) -> _NDArray:
 190      new_height = array.shape[0] // block_height
 191      new_width = array.shape[1] // block_width
 192      result = np.empty((new_height, new_width, ch), dtype)
 193      totals = np.empty(ch, dtype)
 194      factor = dtype.type(1.0 / (block_height * block_width))
 195      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 196        for x in range(new_width):
 197          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 198          if ch == 1:  # All the branches involve compile-time constants.
 199            total = dtype.type(0.0)
 200            for yy in range(block_height):
 201              for xx in range(block_width):
 202                total += array[y * block_height + yy, x * block_width + xx, 0]
 203            result[y, x, 0] = total * factor
 204          elif ch == 3:
 205            total0 = total1 = total2 = dtype.type(0.0)
 206            for yy in range(block_height):
 207              for xx in range(block_width):
 208                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 209                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 210                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 211            result[y, x, 0] = total0 * factor
 212            result[y, x, 1] = total1 * factor
 213            result[y, x, 2] = total2 * factor
 214          elif block_height * block_width >= 9:
 215            for c in range(ch):
 216              totals[c] = 0.0
 217            for yy in range(block_height):
 218              for xx in range(block_width):
 219                for c in range(ch):
 220                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 221            for c in range(ch):
 222              result[y, x, c] = totals[c] * factor
 223          else:
 224            for c in range(ch):
 225              total = dtype.type(0.0)
 226              for yy in range(block_height):
 227                for xx in range(block_width):
 228                  total += array[y * block_height + yy, x * block_width + xx, c]
 229              result[y, x, c] = total * factor
 230      return result
 231
 232    signature = dtype, block_height, block_width, ch
 233    jitted_function = self._jitted_function.get(signature)
 234    if not jitted_function:
 235      if 0:
 236        print(f'Creating numba jit-wrapper for {signature}.')
 237      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 238      self._jitted_function[signature] = jitted_function
 239
 240    try:
 241      result = jitted_function(a)
 242    except RuntimeError:
 243      message = (
 244          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 245          ' try deleting that directory.'
 246      )
 247      print(message, file=sys.stdout, flush=True)
 248      print(message, file=sys.stderr, flush=True)
 249      raise
 250
 251    return result[..., 0] if array.ndim == 2 else result
 252
 253
 254_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 255
 256
@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_serial_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  Multiplies a sparse matrix, given by its three 1D arrays (`indptr`, `indices`, `data`), with
  the dense 2D array `src`, and writes the product into the preallocated 2D array `dst`.

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.

  Args:
    indptr: 1D array of length `dst.shape[0] + 1`; the entries of output row `i` occupy
      positions `indptr[i]` up to (but excluding) `indptr[i + 1]` in `indices` and `data`.
    indices: 1D array giving, for each entry, the row of `src` that it multiplies.
    data: 1D array of the matrix entry values.
    src: Dense 2D input, with the same number of columns as `dst`.
    dst: Dense 2D output; every row is overwritten.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): `data[0]` and `src[0]` are read unconditionally; presumably callers never pass
  # a matrix without entries or an empty `src` -- confirm.
  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  for i in range(dst.shape[0]):
    acc[:] = 0  # Reset the single accumulator row, which is reused for every output row.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 278
 279
# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
#                setitem=False, stencil=False, fusion=False)
@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_parallel_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  Multiplies a sparse matrix, given by its three 1D arrays (`indptr`, `indices`, `data`), with
  the dense 2D array `src`, and writes the product into the preallocated 2D array `dst`.  The
  output rows are computed in a `numba.prange` loop.

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
  The introduction of parallel omp threads provides another 2-4x speedup.
  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.

  Args:
    indptr: 1D array of length `dst.shape[0] + 1`; the entries of output row `i` occupy
      positions `indptr[i]` up to (but excluding) `indptr[i + 1]` in `indices` and `data`.
    indices: 1D array giving, for each entry, the row of `src` that it multiplies.
    data: 1D array of the matrix entry values.
    src: Dense 2D input, with the same number of columns as `dst`.
    dst: Dense 2D output; every row is overwritten.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): `data[0]` and `src[0]` are read unconditionally; presumably callers never pass
  # a matrix without entries or an empty `src` -- confirm.
  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  # Default is static scheduling, which is fine.
  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
    # The accumulator is created inside the loop body, so it is not shared across iterations.
    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 307
 308
 309@dataclasses.dataclass
 310class _Arraylib(abc.ABC, Generic[_Array]):
 311  """Abstract base class for abstraction of array libraries."""
 312
 313  arraylib: str
 314  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""
 315
 316  array: _Array
 317
 318  @staticmethod
 319  @abc.abstractmethod
 320  def recognize(array: Any) -> bool:
 321    """Return True if `array` is recognized by this _Arraylib."""
 322
 323  @abc.abstractmethod
 324  def numpy(self) -> _NDArray:
 325    """Return a `numpy` version of `self.array`."""
 326
 327  @abc.abstractmethod
 328  def dtype(self) -> _DType:
 329    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""
 330
 331  @abc.abstractmethod
 332  def astype(self, dtype: _DTypeLike) -> _Array:
 333    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""
 334
 335  def reshape(self, shape: tuple[int, ...]) -> _Array:
 336    """Return the equivalent of `self.array.reshape(shape)`."""
 337    return self.array.reshape(shape)
 338
 339  def possibly_make_contiguous(self) -> _Array:
 340    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous."""
 341    return self.array
 342
 343  @abc.abstractmethod
 344  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _Array:
 345    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""
 346
 347  @abc.abstractmethod
 348  def square(self) -> _Array:
 349    """Return the equivalent of `np.square(self.array)`."""
 350
 351  @abc.abstractmethod
 352  def sqrt(self) -> _Array:
 353    """Return the equivalent of `np.sqrt(self.array)`."""
 354
 355  def getitem(self, indices: Any) -> _Array:
 356    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
 357    return self.array[indices]
 358
 359  @abc.abstractmethod
 360  def where(self, if_true: Any, if_false: Any) -> _Array:
 361    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""
 362
 363  @abc.abstractmethod
 364  def transpose(self, axes: Sequence[int]) -> _Array:
 365    """Return the equivalent of `np.transpose(self.array, axes)`."""
 366
 367  @abc.abstractmethod
 368  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 369    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""
 370
 371  @abc.abstractmethod
 372  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
 373    """Return the multiplication of the `sparse` matrix and `self.array`."""
 374
 375  @staticmethod
 376  @abc.abstractmethod
 377  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
 378    """Return the equivalent of `np.concatenate(arrays, axis)`."""
 379
 380  @staticmethod
 381  @abc.abstractmethod
 382  def einsum(subscripts: str, *operands: _Array) -> _Array:
 383    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 384
 385  @staticmethod
 386  @abc.abstractmethod
 387  def make_sparse_matrix(
 388      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 389  ) -> _Array:
 390    """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 391    However, the indices must be ordered and unique."""
 392
 393
 394class _NumpyArraylib(_Arraylib[_NDArray]):
 395  """Numpy implementation of the array abstraction."""
 396
 397  # pylint: disable=missing-function-docstring
 398
 399  def __init__(self, array: _NDArray) -> None:
 400    super().__init__(arraylib='numpy', array=np.asarray(array))
 401
 402  @staticmethod
 403  def recognize(array: Any) -> bool:
 404    return isinstance(array, (np.ndarray, np.number))
 405
 406  def numpy(self) -> _NDArray:
 407    return self.array
 408
 409  def dtype(self) -> _DType:
 410    dtype: _DType = self.array.dtype
 411    return dtype
 412
 413  def astype(self, dtype: _DTypeLike) -> _NDArray:
 414    return self.array.astype(dtype, copy=False)
 415
 416  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _NDArray:
 417    return self.array.clip(low, high, dtype=dtype)
 418
 419  def square(self) -> _NDArray:
 420    return np.square(self.array)
 421
 422  def sqrt(self) -> _NDArray:
 423    return np.sqrt(self.array)
 424
 425  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 426    condition = self.array
 427    return np.where(condition, if_true, if_false)
 428
 429  def transpose(self, axes: Sequence[int]) -> _NDArray:
 430    return np.transpose(self.array, tuple(axes))
 431
 432  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 433    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 434    # and (2) timings show preference to resizing dimensions with larger strides first.
 435    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 436    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 437    # np.einsum() does not support the sparse multiplications that we require here.)
 438    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 439    strides: Sequence[int] = self.array.strides
 440    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 441
 442    def priority(dim: int) -> float:
 443      scaling = dst_shape[dim] / src_shape[dim]
 444      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 445
 446    return sorted(range(len(src_shape)), key=priority)
 447
 448  def premult_with_sparse(
 449      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 450  ) -> _NDArray:
 451    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 452    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 453    if using_numba:
 454      num_threads2 = min(6, os.cpu_count()) if num_threads == 'auto' else num_threads
 455      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 456      dtype = np.result_type(sparse.dtype, src.dtype)
 457      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 458      num_scalar_multiplies = len(sparse.data) * src.shape[1]
 459      is_small_size = num_scalar_multiplies < 200_000
 460      if is_small_size or num_threads2 == 1:
 461        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 462      else:
 463        numba.set_num_threads(num_threads2)
 464        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 465      return dst
 466
 467    # Note that sicpy.sparse does not use multithreading.  The "@" operation
 468    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 469    # scipy.sparse._sparsetools.csr_matvecs() in
 470    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/csr.h
 471    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 472    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/dense.h
 473    return sparse @ self.array
 474
 475  @staticmethod
 476  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 477    return np.concatenate(arrays, axis)
 478
 479  @staticmethod
 480  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 481    return np.einsum(subscripts, *operands, optimize=True)
 482
 483  @staticmethod
 484  def make_sparse_matrix(
 485      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 486  ) -> _NDArray:
 487    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 488
 489
 490class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
 491  """Tensorflow implementation of the array abstraction."""
 492
 493  def __init__(self, array: _NDArray) -> None:
 494    import tensorflow
 495
 496    self.tf = tensorflow
 497    super().__init__(arraylib='tensorflow', array=self.tf.convert_to_tensor(array))
 498
 499  @staticmethod
 500  def recognize(array: Any) -> bool:
 501    # Eager: tensorflow.python.framework.ops.Tensor
 502    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
 503    return type(array).__module__.startswith('tensorflow.')
 504
 505  def numpy(self) -> _NDArray:
 506    return self.array.numpy()
 507
 508  def dtype(self) -> _DType:
 509    return np.dtype(self.array.dtype.as_numpy_dtype)
 510
 511  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
 512    return self.tf.cast(self.array, dtype)
 513
 514  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
 515    return self.tf.reshape(self.array, shape)
 516
 517  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TensorflowTensor:
 518    array = self.array
 519    if dtype is not None:
 520      array = self.tf.cast(array, dtype)
 521    return self.tf.clip_by_value(array, low, high)
 522
 523  def square(self) -> _TensorflowTensor:
 524    return self.tf.square(self.array)
 525
 526  def sqrt(self) -> _TensorflowTensor:
 527    return self.tf.sqrt(self.array)
 528
 529  def getitem(self, indices: Any) -> _TensorflowTensor:
 530    if isinstance(indices, tuple):
 531      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
 532      if not basic:
 533        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
 534        assert all(isinstance(a, np.ndarray) for a in indices)
 535        assert all(a.ndim == indices[0].ndim for a in indices)
 536        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
 537        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
 538        return self.tf.gather_nd(self.array, indices_array)
 539    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
 540      indices = self.tf.cast(indices, np.int32)
 541    return self.tf.gather(self.array, indices)
 542
 543  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
 544    condition = self.array
 545    return self.tf.where(condition, if_true, if_false)
 546
 547  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
 548    return self.tf.transpose(self.array, tuple(axes))
 549
 550  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 551    # Note that a tensorflow.Tensor does not have strides.
 552    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
 553    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 554    dims = list(range(len(src_shape)))
 555    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 556      dims[:2] = [1, 0]
 557    return dims
 558
 559  def premult_with_sparse(
 560      self, sparse: 'tf.sparse.SparseTensor', num_threads: int | Literal['auto']
 561  ) -> _TensorflowTensor:
 562    import tensorflow as tf
 563
 564    del num_threads
 565    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 566      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
 567    return tf.sparse.sparse_dense_matmul(sparse, self.array)
 568
 569  @staticmethod
 570  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
 571    import tensorflow as tf
 572
 573    return tf.concat(arrays, axis)
 574
 575  @staticmethod
 576  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
 577    import tensorflow as tf
 578
 579    return tf.einsum(subscripts, *operands, optimize='greedy')
 580
 581  @staticmethod
 582  def make_sparse_matrix(
 583      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 584  ) -> _TensorflowTensor:
 585    import tensorflow as tf
 586
 587    indices = np.vstack((row_ind, col_ind)).T
 588    return tf.sparse.SparseTensor(indices, data, shape)
 589
 590
 591class _TorchArraylib(_Arraylib[_TorchTensor]):
 592  """Torch implementation of the array abstraction."""
 593
 594  # pylint: disable=missing-function-docstring
 595
 596  def __init__(self, array: _NDArray) -> None:
 597    import torch
 598
 599    self.torch = torch
 600    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))
 601
 602  @staticmethod
 603  def recognize(array: Any) -> bool:
 604    return type(array).__module__ == 'torch'
 605
 606  def numpy(self) -> _NDArray:
 607    return self.array.numpy()
 608
 609  def dtype(self) -> _DType:
 610    numpy_type = {
 611        self.torch.float32: np.float32,
 612        self.torch.float64: np.float64,
 613        self.torch.complex64: np.complex64,
 614        self.torch.complex128: np.complex128,
 615        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
 616        self.torch.int16: np.int16,
 617        self.torch.int32: np.int32,
 618        self.torch.int64: np.int64,
 619    }[self.array.dtype]
 620    return np.dtype(numpy_type)
 621
 622  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
 623    torch_type = {
 624        np.float32: self.torch.float32,
 625        np.float64: self.torch.float64,
 626        np.complex64: self.torch.complex64,
 627        np.complex128: self.torch.complex128,
 628        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
 629        np.int16: self.torch.int16,
 630        np.int32: self.torch.int32,
 631        np.int64: self.torch.int64,
 632    }[np.dtype(dtype).type]
 633    return self.array.type(torch_type)
 634
 635  def possibly_make_contiguous(self) -> _TorchTensor:
 636    return self.array.contiguous()
 637
 638  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TorchTensor:
 639    array = self.array
 640    array = _arr_astype(array, dtype) if dtype is not None else array
 641    return array.clip(low, high)
 642
 643  def square(self) -> _TorchTensor:
 644    return self.array.square()
 645
 646  def sqrt(self) -> _TorchTensor:
 647    return self.array.sqrt()
 648
 649  def getitem(self, indices: Any) -> _TorchTensor:
 650    if not isinstance(indices, tuple):
 651      indices = indices.type(self.torch.int64)
 652    return self.array[indices]  # pylint: disable=unsubscriptable-object
 653
 654  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
 655    condition = self.array
 656    return if_true.where(condition, if_false)
 657
 658  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
 659    return self.torch.permute(self.array, tuple(axes))
 660
 661  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 662    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
 663    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 664    strides: Sequence[int] = self.array.stride()
 665    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 666
 667    def priority(dim: int) -> float:
 668      scaling = dst_shape[dim] / src_shape[dim]
 669      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 670
 671    return sorted(range(len(src_shape)), key=priority)
 672
 673  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
 674    del num_threads
 675    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 676      sparse = _arr_astype(sparse, _arr_dtype(self.array))
 677    return sparse @ self.array  # Calls torch.sparse.mm().
 678
 679  @staticmethod
 680  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
 681    import torch
 682
 683    return torch.cat(tuple(arrays), axis)
 684
 685  @staticmethod
 686  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
 687    import torch
 688
 689    operands = tuple(torch.as_tensor(operand) for operand in operands)
 690    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
 691      operands = tuple(
 692          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
 693      )
 694    return torch.einsum(subscripts, *operands)
 695
 696  @staticmethod
 697  def make_sparse_matrix(
 698      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 699  ) -> _TorchTensor:
 700    import torch
 701
 702    indices = np.vstack((row_ind, col_ind))
 703    return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
 704    # .coalesce() is unnecessary because indices/data are already merged.
 705
 706
 707class _JaxArraylib(_Arraylib[_JaxArray]):
 708  """Jax implementation of the array abstraction."""
 709
 710  def __init__(self, array: _NDArray) -> None:
 711    import jax.numpy
 712
 713    self.jnp = jax.numpy
 714    super().__init__(arraylib='jax', array=self.jnp.asarray(array))
 715
 716  @staticmethod
 717  def recognize(array: Any) -> bool:
 718    # e.g., jaxlib.xla_extension.DeviceArray, jax.interpreters.ad.JVPTracer
 719    return type(array).__module__.startswith(('jaxlib.', 'jax.'))
 720
 721  def numpy(self) -> _NDArray:
 722    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
 723    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
 724    # np.array(array) always returns a writable array but the copy may be more costly.
 725    # return self.array.to_py()
 726    return np.asarray(self.array)
 727
 728  def dtype(self) -> _DType:
 729    return np.dtype(self.array.dtype)
 730
 731  def astype(self, dtype: _DTypeLike) -> _JaxArray:
 732    return self.array.astype(dtype)  # (copy=False is unavailable)
 733
 734  def possibly_make_contiguous(self) -> _JaxArray:
 735    return self.array.copy()
 736
 737  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _JaxArray:
 738    array = self.array
 739    if dtype is not None:
 740      array = array.astype(dtype)  # (copy=False is unavailable)
 741    return self.jnp.clip(array, low, high)
 742
 743  def square(self) -> _JaxArray:
 744    return self.jnp.square(self.array)
 745
 746  def sqrt(self) -> _JaxArray:
 747    return self.jnp.sqrt(self.array)
 748
 749  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
 750    condition = self.array
 751    return self.jnp.where(condition, if_true, if_false)
 752
 753  def transpose(self, axes: Sequence[int]) -> _JaxArray:
 754    return self.jnp.transpose(self.array, tuple(axes))
 755
 756  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 757    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
 758    # https://github.com/google/jax/discussions/7544#discussioncomment-1197038.
 759    # We use a heuristic similar to `_TensorflowArraylib`.
 760    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 761    dims = list(range(len(src_shape)))
 762    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 763      dims[:2] = [1, 0]
 764    return dims
 765
 766  def premult_with_sparse(
 767      self, sparse: 'jax.experimental.sparse.BCOO', num_threads: int | Literal['auto']
 768  ) -> _JaxArray:
 769    del num_threads
 770    return sparse @ self.array  # Calls jax.bcoo_multiply_dense().
 771
 772  @staticmethod
 773  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
 774    import jax.numpy as jnp
 775
 776    return jnp.concatenate(arrays, axis)
 777
 778  @staticmethod
 779  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
 780    import jax.numpy as jnp
 781
 782    return jnp.einsum(subscripts, *operands, optimize='greedy')
 783
 784  @staticmethod
 785  def make_sparse_matrix(
 786      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 787  ) -> _JaxArray:
 788    # https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
 789    import jax.experimental.sparse
 790
 791    indices = np.vstack((row_ind, col_ind)).T
 792    return jax.experimental.sparse.BCOO(
 793        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
 794    )
 795
 796
# Every array library that has an `_Arraylib` implementation in this module, keyed by the module
# name that is passed to `is_available()`.
_CANDIDATE_ARRAYLIBS = {
    'numpy': _NumpyArraylib,
    'tensorflow': _TensorflowArraylib,
    'torch': _TorchArraylib,
    'jax': _JaxArraylib,
}
 803
 804
 805def is_available(arraylib: str) -> bool:
 806  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 807  return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.
 808
 809
# Subset of `_CANDIDATE_ARRAYLIBS` for which `is_available()` reports an installed package.
_DICT_ARRAYLIBS = {
    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if is_available(arraylib)
}

ARRAYLIBS = list(_DICT_ARRAYLIBS)
 815"""Array libraries supported automatically in the resize and resampling operations.
 816
 817- The library is selected automatically based on the type of the `array` function parameter.
 818
 819- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 820
 821- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 822"""
 823
 824
 825def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 826  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 827  for cls in _DICT_ARRAYLIBS.values():
 828    if cls.recognize(array):
 829      return cls(array)  # type: ignore[abstract]
 830  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 831
 832
 833def _arr_arraylib(array: _Array, /) -> str:
 834  """Return the name of the `Arraylib` representing `array`."""
 835  return _as_arr(array).arraylib
 836
 837
 838def _arr_numpy(array: _Array, /) -> _NDArray:
 839  """Return a `numpy` version of `array`."""
 840  return _as_arr(array).numpy()
 841
 842
 843def _arr_dtype(array: _Array, /) -> _DType:
 844  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 845  return _as_arr(array).dtype()
 846
 847
 848def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 849  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 850  return _as_arr(array).astype(dtype)
 851
 852
 853def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 854  """Return the equivalent of `array.reshape(shape)."""
 855  return _as_arr(array).reshape(shape)
 856
 857
 858def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 859  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 860  return _as_arr(array).possibly_make_contiguous()
 861
 862
 863def _arr_clip(array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike = None) -> _Array:
 864  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 865  return _as_arr(array).clip(low, high, dtype)
 866
 867
 868def _arr_square(array: _Array, /) -> _Array:
 869  """Return the equivalent of `np.square(array)`."""
 870  return _as_arr(array).square()
 871
 872
 873def _arr_sqrt(array: _Array, /) -> _Array:
 874  """Return the equivalent of `np.sqrt(array)`."""
 875  return _as_arr(array).sqrt()
 876
 877
 878def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 879  """Return the equivalent of `array[indices]`."""
 880  return _as_arr(array).getitem(indices)
 881
 882
 883def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 884  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 885  return _as_arr(condition).where(if_true, if_false)
 886
 887
 888def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 889  """Return the equivalent of `np.transpose(array, axes)`."""
 890  return _as_arr(array).transpose(axes)
 891
 892
 893def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 894  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 895  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 896
 897
 898def _arr_matmul_sparse_dense(
 899    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 900) -> _Array:
 901  """Return the multiplication of the `sparse` and `dense` matrices."""
 902  assert num_threads == 'auto' or num_threads >= 1
 903  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 904
 905
 906def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 907  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 908  arraylib = _arr_arraylib(arrays[0])
 909  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)  # type: ignore
 910
 911
 912def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 913  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 914  arraylib = _arr_arraylib(operands[0])
 915  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)  # type: ignore
 916
 917
 918def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 919  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 920  ndim = len(array.shape)
 921  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 922  axes = list(range(ndim))
 923  axes[axis1] = axis2
 924  axes[axis2] = axis1
 925  return _arr_transpose(array, axes)
 926
 927
 928def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 929  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 930  ndim = len(array.shape)
 931  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 932  axes = [n for n in range(ndim) if n != source]
 933  axes.insert(destination, source)
 934  return _arr_transpose(array, axes)
 935
 936
 937def _make_sparse_matrix(
 938    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 939) -> Any:
 940  """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 941  However, indices must be ordered and unique."""
 942  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 943
 944
 945def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 946  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 947  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
 948
 949
# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
 951# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
 952# e.g. dim=1 in an RGB image.
 953#
 954# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 955# tensor.  However, tf.reshape() is always efficient.
 956
 957
 958def _block_shape_with_min_size(
 959    shape: tuple[int, ...], min_size: int, compact: bool = True
 960) -> tuple[int, ...]:
 961  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 962  if math.prod(shape) < min_size:
 963    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 964  if compact:
 965    root = int(math.ceil(min_size ** (1 / len(shape))))
 966    block_shape = np.minimum(shape, root)
 967    for dim in range(len(shape)):
 968      if block_shape[dim] == 2 and block_shape.prod() >= min_size * 2:
 969        block_shape[dim] = 1
 970    for dim in range(len(shape) - 1, -1, -1):
 971      if block_shape.prod() < min_size:
 972        block_shape[dim] = shape[dim]
 973  else:
 974    block_shape = np.ones_like(shape)
 975    for dim in range(len(shape) - 1, -1, -1):
 976      if block_shape.prod() < min_size:
 977        block_shape[dim] = min(shape[dim], math.ceil(min_size / block_shape.prod()))
 978  return tuple(block_shape)
 979
 980
 981def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 982  """Split `array` into `num_sections` along `axis`."""
 983  assert 0 <= axis < len(array.shape)
 984  assert 1 <= num_sections <= array.shape[axis]
 985
 986  if 0:
 987    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 988
 989  else:
 990    # Adapted from https://github.com/numpy/numpy/blob/main/numpy/lib/shape_base.py#L739-L792.
 991    num_total = array.shape[axis]
 992    num_each, num_extra = divmod(num_total, num_sections)
 993    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
 994    div_points = np.array(section_sizes).cumsum()
 995    split = []
 996    tmp = _arr_swapaxes(array, axis, 0)
 997    for i in range(num_sections):
 998      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
 999
1000  return split
1001
1002
def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
  """Split `array` into nested lists of blocks of size at most `block_shape`.

  The nesting depth equals `len(block_shape) - start_axis`; dims of `array` beyond
  `len(block_shape)` are not split.

  Raises:
    ValueError: If `block_shape` has more dims than `array`.
  """
  # See https://stackoverflow.com/a/50305924.  (If the block_shape is known to
  # exactly partition the array, see https://stackoverflow.com/a/16858283.)
  block_ndim = len(block_shape)
  if block_ndim > len(array.shape):
    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
  if start_axis == block_ndim:
    return array

  num_sections = math.ceil(array.shape[start_axis] / block_shape[start_axis])
  next_axis = start_axis + 1
  return [
      _split_array_into_blocks(piece, block_shape, next_axis)
      for piece in _array_split(array, start_axis, num_sections)
  ]
1015
1016
1017def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
1018  """Apply `func` to each block in the nested lists of `blocks`."""
1019  if isinstance(blocks, list):
1020    return [_map_function_over_blocks(block, func) for block in blocks]
1021  return func(blocks)
1022
1023
def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
  """Reassemble a single array from the nested lists of array blocks in `blocks`.

  Each nesting level is concatenated along successive dims, starting at `axis`.
  """
  # More general than np.block() because the blocks can have additional dims.
  if not isinstance(blocks, list):
    return blocks
  merged_children = [_merge_array_from_blocks(child, axis + 1) for child in blocks]
  return _arr_concatenate(merged_children, axis)
1031
1032
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  In resampling operations, the grid-type may be specified separately as `src_gridtype` for the
  source domain and `dst_gridtype` for the destination domain.  Moreover, the grid-type may be
  specified per domain dimension.

  Instances are immutable (frozen dataclass).  A concrete subclass must implement all of the
  abstract methods below; each operates on a single domain dimension with `size` samples.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the necessary minimum number of grid samples."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the domain size in units of inter-sample spacing, for a grid of `size` samples."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Return [0.0, 1.0] coordinates given [0, size - 1] indices."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Return location x given coordinates [0.0, 1.0], where x == 0.0 is the first grid sample
    and x == size - 1.0 is the last grid sample.  (This is the inverse of `point_from_index`.)"""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using boundary reflection."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using reflect-clamp."""
1079
1080
class DualGridtype(Gridtype):
  """Grid whose samples lie at the centers of the cells of a uniform partition of the domain.

  For a unit-domain dimension with N samples, each sample 0 <= i < N has position (i + 0.5) / N,
  e.g., [0.125, 0.375, 0.625, 0.875] for N = 4.
  """

  def __init__(self) -> None:
    super().__init__(name='dual')

  def min_size(self) -> int:
    return 1

  def size_in_samples(self, size: int, /) -> int:
    return size

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    cell_center = index + 0.5
    return cell_center / size

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    scaled = point * size
    return scaled - 0.5

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The reflected signal is periodic with period 2 * size.
    folded = np.mod(index, size * 2)
    mirrored = 2 * size - 1 - folded
    return np.where(folded < size, folded, mirrored)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    wrapped = np.mod(index, size)
    return wrapped

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    # Negative indices are reflected; indices beyond the last sample are clamped to it.
    reflected = np.where(index < 0, -1 - index, index)
    return np.minimum(reflected, size - 1)
1112
1113
class PrimalGridtype(Gridtype):
  """Grid whose samples lie at the vertices of the cells of a uniform partition of the domain.

  For a unit-domain dimension with N samples, each sample 0 <= i < N has position i / (N - 1),
  e.g., [0, 1/3, 2/3, 1] for N = 4.
  """

  def __init__(self) -> None:
    super().__init__(name='primal')

  def min_size(self) -> int:
    return 2

  def size_in_samples(self, size: int, /) -> int:
    return size - 1

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    num_intervals = size - 1
    return index / num_intervals

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    num_intervals = size - 1
    return point * num_intervals

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The reflected signal is periodic with period 2 * size - 2.
    folded = np.mod(index, size * 2 - 2)
    mirrored = 2 * size - 2 - folded
    return np.where(folded < size, folded, mirrored)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    period = size - 1
    return np.mod(index, period)

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    # Negative indices are reflected about sample 0; larger indices are clamped to the last.
    reflected = np.abs(index)
    return np.minimum(reflected, size - 1)
1145
1146
# Predefined `Gridtype` instances, keyed by the shortcut name accepted by `_get_gridtype()`.
_DICT_GRIDTYPES = {
    'dual': DualGridtype(),
    'primal': PrimalGridtype(),
}

GRIDTYPES = list(_DICT_GRIDTYPES)
1153r"""Shortcut names for the two predefined grid types (specified per dimension):
1154
1155| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
1156| --- |:---:|:---:|
1157| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual](https://github.com/hhoppe/resampler/raw/main/media/dual_grid_small.png) | ![Primal](https://github.com/hhoppe/resampler/raw/main/media/primal_grid_small.png) |
| Nesting of samples across resolutions | The sample positions do *not* nest. | The *even* samples remain at coarser scale. |
1159| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
1160| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
1161| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |
1162
1163See the source code for extensibility.
1164"""
1165
1166
def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
  """Return `gridtype` itself if it is a `Gridtype`, else look up its name in `GRIDTYPES`."""
  if isinstance(gridtype, Gridtype):
    return gridtype
  return _DICT_GRIDTYPES[gridtype]
1170
1171
def _get_gridtypes(
    gridtype: str | Gridtype | None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    src_ndim: int,
    dst_ndim: int,
) -> tuple[list[Gridtype], list[Gridtype]]:
  """Resolve the grid-type parameters into per-dim source and destination `Gridtype` lists.

  If all three grid-type parameters are None, `'dual'` is used for both domains.  A non-None
  `gridtype` applies to both domains and excludes `src_gridtype` and `dst_gridtype`.

  Raises:
    ValueError: If `gridtype` is specified together with `src_gridtype` or `dst_gridtype`.
  """
  if all(g is None for g in (gridtype, src_gridtype, dst_gridtype)):
    gridtype = 'dual'
  if gridtype is not None:
    if src_gridtype is not None:
      raise ValueError('Cannot have both gridtype and src_gridtype.')
    if dst_gridtype is not None:
      raise ValueError('Cannot have both gridtype and dst_gridtype.')
    src_gridtype = dst_gridtype = gridtype

  def expand_per_dim(spec: Any, ndim: int) -> list[Gridtype]:
    # Broadcast a single grid-type (or a per-dim sequence) to exactly `ndim` entries.
    return [_get_gridtype(g) for g in np.broadcast_to(np.array(spec), ndim)]

  src_gridtype2 = expand_per_dim(src_gridtype, src_ndim)
  dst_gridtype2 = expand_per_dim(dst_gridtype, dst_ndim)
  return src_gridtype2, dst_gridtype2
1191
1192
@dataclasses.dataclass(frozen=True)
class RemapCoordinates(abc.ABC):
  """Abstract base class for modifying the specified coordinates prior to evaluating the
  reconstruction kernels."""

  @abc.abstractmethod
  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the remapped version of the coordinates in `point`."""
    ...
1201
1202
class NoRemapCoordinates(RemapCoordinates):
  """The coordinates are not remapped."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return `point` unchanged."""
    return point
1208
1209
class MirrorRemapCoordinates(RemapCoordinates):
  """Reflect the coordinates across the domain boundaries so that they lie in the unit
  interval.  The resulting function is continuous but not smooth across the boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    # The mirrored mapping has period 2.0; the second half of each period runs backwards.
    folded = np.mod(point, 2.0)
    mirrored = 2.0 - folded
    return np.where(folded >= 1.0, mirrored, folded)
1217
1218
class TileRemapCoordinates(RemapCoordinates):
  """Map the coordinates to the unit interval using a "modulo 1.0" operation.  The resulting
  function is generally discontinuous across the domain boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    fractional = np.mod(point, 1.0)
    return fractional
1225
1226
@dataclasses.dataclass(frozen=True)
class ExtendSamples(abc.ABC):
  """Abstract base class for replacing references to grid samples exterior to the unit domain by
  affine combinations of interior sample(s) and possibly the constant value (`cval`)."""

  uses_cval: bool = False
  """True if some exterior samples are defined in terms of `cval`, i.e., if the computed weight
  is non-affine."""

  @abc.abstractmethod
  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Detect references to exterior samples, i.e., entries of `index` that lie outside the
    interval [0, size), and update these indices (and possibly their associated weights) to
    reference only interior samples.  Return `new_index, new_weight`.

    Args:
      index: Sample indices; entries may lie outside [0, size).
      weight: Weights associated with the entries of `index`.
      size: Number of grid samples in the dimension.
      gridtype: Grid-type of the dimension, providing the index-mapping helpers.
    """
1243
1244
class ReflectExtendSamples(ExtendSamples):
  """Locate the interior sample by reflection across the domain boundaries."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through.
    return gridtype.reflect(index, size), weight
1253
1254
class WrapExtendSamples(ExtendSamples):
  """Repeat the interior samples periodically.  For a `'primal'` grid, the last
  sample is ignored as its value is replaced by the first sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through.
    return gridtype.wrap(index, size), weight
1264
1265
class ClampExtendSamples(ExtendSamples):
  """Replace each exterior sample by the nearest interior sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    last = size - 1
    clamped = index.clip(0, last)
    return clamped, weight
1274
1275
class ReflectClampExtendSamples(ExtendSamples):
  """Extend the grid samples from [0, 1] into [-1, 0] using reflection, and then define the grid
  samples outside [-1, 1] as that of the nearest sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through.
    return gridtype.reflect_clamp(index, size), weight
1285
1286
class BorderExtendSamples(ExtendSamples):
  """Give every exterior sample the constant value (`cval`)."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Note that `index` and `weight` are modified in place.
    # Exterior references get zero weight and are redirected to a valid boundary index.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    return index, weight
1303
1304
class ValidExtendSamples(ExtendSamples):
  """Give all domain samples weight 1 and all outside samples weight 0; then compute a
  weighted reconstruction and divide it by the reconstructed weight."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Note that `index` and `weight` are modified in place.
    # Exterior references get zero weight and are redirected to a valid boundary index.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    # Renormalize the remaining weights over the last axis, skipping entries whose sum is zero.
    total = weight.sum(axis=-1)[..., None]
    np.divide(weight, total, out=weight, where=(total != 0.0))
    return index, weight
1325
1326
class LinearExtendSamples(ExtendSamples):
  """Linearly extrapolate beyond boundary samples."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `index` and `weight` extended by 4 columns along the last axis.

    The extra columns hold the extrapolation weights for samples `[0, 1, size - 2, size - 1]`.
    The input `index` and `weight` arrays are also modified in place.
    """
    # With fewer than 2 samples there is no line to extrapolate; fall back to reflection.
    if size < 2:
      index = gridtype.reflect(index, size)
      return index, weight
    # For each boundary, define new columns in index and weight arrays to represent the last and
    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
    # duplicate index entries.
    low = index < 0
    high = index >= size
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
    # Low side: a sample at index x < 0 is extrapolated as (1 - x) * f[0] + x * f[1].
    x = index
    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
    # High side: same formula, with x measured backwards from the last sample.
    x = (size - 1) - index
    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
    # The original exterior entries no longer contribute; point them at valid indices.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-4] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-4] = index
    # Let matrix (including zero values) be banded.
    # (Extra columns with zero weight reuse the first index of their row.)
    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
    index = new_index
    return index, weight
1360
1361
class QuadraticExtendSamples(ExtendSamples):
  """Quadratically extrapolate beyond boundary samples."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `index` and `weight` extended by 6 columns along the last axis.

    The extra columns hold the extrapolation weights for samples
    `[0, 1, 2, size - 3, size - 2, size - 1]`.  The input `index` and `weight` arrays are also
    modified in place.
    """
    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
    # but it seems just quadratic.
    # With fewer than 3 samples there is no parabola to extrapolate; fall back to reflection.
    if size < 3:
      index = gridtype.reflect(index, size)
      return index, weight
    low = index < 0
    high = index >= size
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
    # Low side: the three factors are the quadratic Lagrange basis polynomials for nodes
    # 0, 1, 2, evaluated at the exterior index x.
    x = index
    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
    # High side: same polynomials, with x measured backwards from the last sample.
    x = (size - 1) - index
    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
    # The original exterior entries no longer contribute; point them at valid indices.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-6] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-6] = index
    # Let matrix (including zero values) be banded.
    # (Extra columns with zero weight reuse the first index of their row.)
    new_index[..., -6:] = np.where(
        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
    )
    index = new_index
    return index, weight
1398
1399
@dataclasses.dataclass(frozen=True)
class OverrideExteriorValue:
  """Base class for rules that set the value outside some domain extent to a
  constant value (`cval`)."""

  boundary_antialiasing: bool = True
  """Antialias the pixel values adjacent to the boundary of the extent."""

  uses_cval: bool = False
  """Modify some weights to introduce references to the `cval` constant value."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For all `point` outside some extent, modify the weight to be zero."""

  def override_using_signed_distance(
      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
  ) -> None:
    """Scale down (in place) the sample weights of points that lie outside the extent, so that
    they effectively take the constant value `cval`.

    Args:
      weight: Sample weights, modified in place.
      point: Coordinates at which the reconstruction is evaluated.
      signed_distance: Per-point distance to the extent boundary; positive means outside.
    """
    if np.all(signed_distance <= 0.0):
      return  # Every point lies inside the extent, so there is nothing to override.

    antialias = self.boundary_antialiasing and min(point.shape) >= 2
    if not antialias:
      # Hard cutoff: outside points lose all of their weight.
      weight[signed_distance > 0.0, :] = 0.0
      return

    # For discontinuous coordinate mappings, we may need to somehow ignore
    # the large finite differences computed across the map discontinuities.
    point_gradient = np.gradient(point)
    gradient_norm = np.linalg.norm(np.atleast_2d(point_gradient), axis=0)
    # Convert the distance to units of sample spacing (the tiny constant avoids division by 0).
    distance_in_samples = signed_distance / (gradient_norm + 1e-20)
    # Opacity is in linear space, which is correct if Gamma is set.
    opacity = (0.5 - distance_in_samples).clip(0.0, 1.0)
    weight *= opacity[..., None]
1434
1435
class NoOverrideExteriorValue(OverrideExteriorValue):
  """The function value is not overridden."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """Leave `weight` untouched."""
    pass
1441
1442
class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the unit interval [0, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Distance from the interval center 0.5, minus the half-width 0.5.
    offset_from_center = point - 0.5
    signed_distance = abs(offset_from_center) - 0.5  # Boundaries at 0.0 and 1.0.
    self.override_using_signed_distance(weight, point, signed_distance)
1452
1453
class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the interval [-1, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    magnitude = abs(point)
    signed_distance = magnitude - 1.0  # Boundaries at -1.0 and 1.0.
    self.override_using_signed_distance(weight, point, signed_distance)
1463
1464
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules, which define the reconstruction over the source domain near and
  beyond the domain boundaries.  The rules may be specified separately for each domain
  dimension."""

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    return self.extend_samples.uses_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return the coordinates as remapped by `coord_remap`, before the kernels are evaluated."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples, then apply the override.

    Returns the updated `index, weight` arrays.
    """
    new_index, new_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(new_weight, point)
    return new_index, new_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight (in place) to zero to assign `cval`."""
    self.override_value(weight, point)
1505
1506
# Predefined boundary rules, keyed by shortcut name.  Each `Boundary.name` equals its key, and
# the insertion order determines the order of the public `BOUNDARIES` list.
_DICT_BOUNDARIES = {
    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
    # The name was previously misspelled as 'title'.
    'tile': Boundary(
        'tile', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
    ),
    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
    'border': Boundary('border', extend_samples=BorderExtendSamples()),
    'natural': Boundary(
        'natural',
        extend_samples=ValidExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear_constant': Boundary(
        'linear_constant',
        extend_samples=LinearExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'quadratic_constant': Boundary(
        'quadratic_constant',
        extend_samples=QuadraticExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
    'constant': Boundary(
        'constant',
        extend_samples=ReflectExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
}
1539
BOUNDARIES = [*_DICT_BOUNDARIES]
1541"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):
1542
1543| name                   | a.k.a. / comments |
1544|------------------------|-------------------|
1545| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
1546| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
1547| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
1548| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
1549| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
1550| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
1551| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
1552| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
1553| `'linear'`             | extrapolate from 2 last samples |
1554| `'quadratic'`          | extrapolate from 3 last samples |
1555| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
1556| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |
1557
1558These boundary rules may be specified per dimension.  See the source code for extensibility
1559using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.
1560
1561**Boundary rules illustrated in 1D:**
1562
1563<center>
1564<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_1D.png" width="100%"/>
1565</center>
1566
1567**Boundary rules illustrated in 2D:**
1568
1569<center>
1570<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_2D.png" width="100%"/>
1571</center>
1572"""
1573
1574_OFTUSED_BOUNDARIES = (
1575    'reflect wrap tile clamp border natural linear_constant quadratic_constant'.split()
1576)
1577"""A useful subset of `BOUNDARIES` for visualization in figures."""
1578
1579
def _get_boundary(boundary: str | Boundary, /) -> Boundary:
  """Resolve `boundary` to a `Boundary` instance.

  A `Boundary` is returned unchanged; any other value is looked up as a key of
  `_DICT_BOUNDARIES` (the names listed in `BOUNDARIES`), so an unknown name raises `KeyError`.
  """
  if isinstance(boundary, Boundary):
    return boundary
  return _DICT_BOUNDARIES[boundary]
1583
1584
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  The class is a frozen dataclass: the fields below describe properties of the kernel and are
  set through the generated `__init__` (which subclasses invoke via `super().__init__(...)`).
  Concrete subclasses must implement `__call__` to evaluate the kernel.

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Return evaluation of filter kernel at locations x.

    Args:
      x: Locations (array-like) at which to evaluate the kernel.

    Returns:
      The kernel values at `x`.
    """
1624
1625
class ImpulseFilter(Filter):
  """Dirac delta kernel; see https://en.wikipedia.org/wiki/Dirac_delta_function.

  The kernel cannot be evaluated pointwise, so calling an instance always raises.
  """

  def __init__(self) -> None:
    super().__init__(
        name='impulse',
        radius=1e-20,
        continuous=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    message = 'The Impulse is infinitely narrow, so cannot be directly evaluated.'
    raise AssertionError(message)
1634
1635
class BoxFilter(Filter):
  """Box kernel; see https://en.wikipedia.org/wiki/Box_function.

  The kernel function has value 1.0 over the half-open interval [-.5, .5).
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    use_asymmetric = True
    if not use_asymmetric:
      # Symmetric alternative (currently disabled): value 0.5 exactly at abs(x) == 0.5.
      distance = np.abs(x)
      return np.where(distance < 0.5, 1.0, np.where(distance == 0.5, 0.5, 0.0))
    # Half-open interval: the left endpoint is included and the right endpoint is excluded.
    values = np.asarray(x)
    inside = (values >= -0.5) & (values < 0.5)
    return np.where(inside, 1.0, 0.0)
1652
1653
class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  This filter is similar to the BoxFilter but with linearly sloped sides.  It has value 1.0
  in the interval abs(x) <= 1.0 - radius and decreases linearly to value 0.0 in the interval
  1.0 - radius <= abs(x) <= radius, always with value 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is not None:
      if not 0.5 < radius <= 1.0:
        raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      # Placeholder instance; it is stored with radius 0.0 and cannot be evaluated.
      super().__init__(name='trapezoid', radius=0.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Line through (0.5, 0.5) that reaches 0.0 at abs(x) == radius, clipped to [0.0, 1.0].
    offset = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (offset - slope * distance).clip(0.0, 1.0)
1680
1681
class TriangleFilter(Filter):
  """Triangle kernel; see https://en.wikipedia.org/wiki/Triangle_function.

  Also known as the hat or tent function.  It is used for piecewise-linear
  (or bilinear, or trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    ramp = 1.0 - distance  # Equals 1.0 at the origin and reaches 0.0 at abs(x) == 1.0.
    return ramp.clip(0.0, 1.0)
1694
1695
class CubicFilter(Filter):
  """Family of cubic filters parameterized by two scalar parameters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: optional filter name; if omitted, a name is derived from `b` and `c`.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    b, c = self.b, self.c
    t = np.abs(x)
    # Polynomial coefficients on 0 <= t < 1 (there is no linear term) ...
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # ... and on 1 <= t < 2.
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # Horner evaluation.  (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
    # twice as slow; see also https://stackoverflow.com/questions/24065904)
    inner = ((f3 * t + f2) * t) * t + f0
    outer = ((g3 * t + g2) * t + g1) * t + g0
    beyond_inner = np.where(t < 2.0, outer, 0.0)
    return np.where(t < 1.0, inner, beyond_inner)
1733
1734
class CatmullRomFilter(CubicFilter):
  """Cubic filter with cubic precision (b=0, c=0.5).  Also known as Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)
1749
1750
class MitchellFilter(CubicFilter):
  """Cubic filter with b=1/3, c=1/3; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)
1760
1761
class SharpCubicFilter(CubicFilter):
  """Cubic filter (b=0, c=0.75) that is sharper than the Catmull-Rom filter.

  Used by some tools including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)
1773
1774
class LanczosFilter(Filter):
  """High-quality filter: sinc function modulated by a sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
      window = _sinc(distance / radius)  # Zero-phase function w_0(x).
      windowed = _sinc(distance) * window
      return np.where(distance < radius, windowed, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    evaluate = self._function
    return evaluate(x)
1802
1803
class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    assert 0.0 < a0 < 1.0
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    a0 = self.a0
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = a0 + (1.0 - a0) * np.cos(np.pi / self.radius * distance)
    windowed = _sinc(distance) * window
    return np.where(distance < self.radius, windowed, 0.0)
1835
1836
class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level.
    sampled: If True, use a discretized approximation for improved speed.
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    assert beta >= 0.0
    super().__init__(
        name=f'kaiser_{radius}_{beta}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )

    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Clip so that the square root stays real for abs(x) > radius.
      squared_term = (1.0 - np.square(distance / radius)).clip(0.0, 1.0)
      window = np.i0(beta * np.sqrt(squared_term)) / np.i0(beta)
      windowed = _sinc(distance) * window
      # The small tolerance keeps the sample at abs(x) == radius inside the support.
      return np.where(distance <= radius + 1e-6, windowed, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    evaluate = self._function
    return evaluate(x)
1867
1868  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1869    return self._function(x)
1870
1871
class BsplineFilter(Filter):
  """B-spline of a non-negative degree.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0; it is then
      flagged as not continuous.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  Raises:
    ValueError: If `degree` is negative.

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    super().__init__(
        name=f'bspline{degree}',
        radius=radius,
        interpolating=degree <= 1,
        # The degree-0 kernel is box-like and thus discontinuous (as in CardinalBsplineFilter).
        continuous=degree >= 1,
    )
    # Uniform integer knots 0, 1, ..., degree + 1; the basis element is supported on
    # [0, 2 * radius], so it is shifted by `radius` in `__call__` to be centered at the origin.
    t = list(range(degree + 2))
    self._bspline = scipy.interpolate.BSpline.basis_element(t)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Return the B-spline kernel evaluated at locations x (zero for abs(x) >= radius)."""
    x = np.abs(x)
    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)
1897
1898
class CardinalBsplineFilter(Filter):
  """Interpolating B-spline, achieved with aid of digital pre or post filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    self.degree = degree
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    needs_digital_filter = degree >= 2
    is_continuous = degree >= 1
    super().__init__(
        name=f'cardinal{degree}',
        radius=radius,
        requires_digital_filter=needs_digital_filter,
        continuous=is_continuous,
    )
    knots = list(range(degree + 2))
    bspline = scipy.interpolate.BSpline.basis_element(knots)

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # The basis element is defined over [0, 2 * radius]; shift it to be centered at zero.
      shifted = bspline(distance + radius)
      return np.where(distance < radius, shifted, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    evaluate = self._function
    return evaluate(x)
1934
1935
class OmomsFilter(Filter):
  """OMOMS interpolating filter, with aid of digital pre or post filter.

  Args:
    degree: The polynomial degree of the filter segments.  Only 3 and 5 are supported.

  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
  interpolation of minimal support, 2001].
  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
  """

  def __init__(self, *, degree: int) -> None:
    supported_degrees = (3, 5)
    if degree not in supported_degrees:
      raise ValueError(f'Degree {degree} not supported.')
    super().__init__(
        name=f'omoms{degree}',
        radius=(degree + 1) / 2,
        requires_digital_filter=True,
    )
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    t = np.abs(x)
    # Each piece below is a polynomial (in Horner form) valid on one unit interval of abs(x).
    if self.degree == 3:
      p01 = ((0.5 * t - 1.0) * t + 3 / 42) * t + 26 / 42
      p12 = ((-7 / 42 * t + 1.0) * t - 85 / 42) * t + 58 / 42
      return np.where(t < 1.0, p01, np.where(t < 2.0, p12, 0.0))
    if self.degree == 5:
      p01 = ((((-1 / 12 * t + 1 / 4) * t - 5 / 99) * t - 9 / 22) * t - 1 / 792) * t + 229 / 440
      p12 = (
          (((1 / 24 * t - 3 / 8) * t + 505 / 396) * t - 83 / 44) * t + 1351 / 1584
      ) * t + 839 / 2640
      p23 = (
          (((-1 / 120 * t + 1 / 8) * t - 299 / 396) * t + 101 / 44) * t - 27811 / 7920
      ) * t + 5707 / 2640
      tail = np.where(t < 3.0, p23, 0.0)
      return np.where(t < 1.0, p01, np.where(t < 2.0, p12, tail))
    raise ValueError(self.degree)
1971
1972
class GaussianFilter(Filter):
  r"""Gaussian kernel; see https://en.wikipedia.org/wiki/Gaussian_function.

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  https://cadxfem.org/inf/ResamplingFilters.pdf
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    kernel_name = f'gaussian_{standard_deviation:.3f}'
    support_radius = np.ceil(8.0 * standard_deviation)  # Sufficiently large.
    super().__init__(
        name=kernel_name,
        radius=support_radius,
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    sigma = self.standard_deviation
    # Normalized Gaussian density, truncated to zero for abs(x) >= radius.
    density = np.exp(np.square(distance / sigma) / -2.0) / (np.sqrt(math.tau) * sigma)
    return np.where(distance < self.radius, density, 0.0)
2003
2004
class NarrowBoxFilter(Filter):
  """Compact footprint, used for visualization of grid sample location.

  Args:
    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        continuous=False,
        partition_of_unity=False,
        unit_integral=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    values = np.asarray(x)
    half_width = self.radius
    # Half-open interval [-radius, radius), with constant value 1.0 inside.
    inside = (values >= -half_width) & (values < half_width)
    return np.where(inside, 1.0, 0.0)
2027
2028
# Key of `_DICT_FILTERS` naming the default filter.
_DEFAULT_FILTER = 'lanczos3'

# Registry of predefined filter kernels, keyed by shortcut name.
# The insertion order matters: the public list `FILTERS` consists of the keys that precede
# 'hamming1', so entries meant to be listed there must come before the "Not in FILTERS" marker.
_DICT_FILTERS = {
    'impulse': ImpulseFilter(),
    'box': BoxFilter(),
    'trapezoid': TrapezoidFilter(),
    'triangle': TriangleFilter(),
    'cubic': CatmullRomFilter(),
    'sharpcubic': SharpCubicFilter(),
    'lanczos3': LanczosFilter(radius=3),
    'lanczos5': LanczosFilter(radius=5),
    'lanczos10': LanczosFilter(radius=10),
    'cardinal3': CardinalBsplineFilter(degree=3),
    'cardinal5': CardinalBsplineFilter(degree=5),
    'omoms3': OmomsFilter(degree=3),
    'omoms5': OmomsFilter(degree=5),
    'hamming3': GeneralizedHammingFilter(radius=3, a0=25 / 46),
    'kaiser3': KaiserFilter(radius=3.0, beta=7.12),
    'gaussian': GaussianFilter(),
    'bspline3': BsplineFilter(degree=3),
    'mitchell': MitchellFilter(),
    'narrowbox': NarrowBoxFilter(),
    # Not in FILTERS:
    'hamming1': GeneralizedHammingFilter(radius=1, a0=0.54),
    'hann3': GeneralizedHammingFilter(radius=3, a0=0.5),
    'lanczos4': LanczosFilter(radius=4),
}
2056
FILTERS = list(itertools.takewhile(lambda name: name != 'hamming1', _DICT_FILTERS))
2058r"""Shortcut names for some predefined filter kernels (specified per dimension).
2059The names expand to:
2060
2061| name           | `Filter`                      | a.k.a. / comments |
2062|----------------|-------------------------------|-------------------|
2063| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
2064| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
2065| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
2066| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
2067| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
2068| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
2069| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
2070| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
2071| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
2072| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
2073| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-2, 2], *GF* |
| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-3, 3], *GF* |
2076| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
2077| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
2078| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
2079| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
2080| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
2081| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |
2082
2083The comment label *GF* denotes a [generalized filter](https://hhoppe.com/proj/filtering/), formed
2084as the composition of a finitely supported kernel and a discrete inverse convolution.
2085
2086**Some example filter kernels:**
2087
2088<center>
2089<img src="https://github.com/hhoppe/resampler/raw/main/media/filter_summary.png" width="100%"/>
2090</center>
2091
2092<br/>A more extensive set of filters is presented [here](#plots_of_filters) in the
2093[notebook](https://colab.research.google.com/github/hhoppe/resampler/blob/main/resampler_notebook.ipynb),
2094together with visualizations and analyses of the filter properties.
2095See the source code for extensibility.
2096"""
2097
2098
def _get_filter(filter: str | Filter, /) -> Filter:
  """Resolve `filter` to a `Filter` instance.

  A `Filter` is returned unchanged; any other value is looked up as a key of `_DICT_FILTERS`
  (which includes the names in `FILTERS`), so an unknown name raises `KeyError`.
  """
  if isinstance(filter, Filter):
    return filter
  return _DICT_FILTERS[filter]
2102
2103
def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Scale uint to the range [0.0, 1.0], and clip float to [0.0, 1.0].

  Args:
    array: Input array whose dtype is uint8, uint16, uint32, or a floating type.
    dtype: Floating-point dtype of the result.
  """
  src_dtype = _arr_dtype(array)
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  if src_dtype.type not in (np.uint8, np.uint16, np.uint32):
    assert np.issubdtype(src_dtype, np.floating)
    return _arr_clip(array, 0.0, 1.0, dtype)

  max_value = np.iinfo(src_dtype).max
  if _arr_arraylib(array) == 'numpy':
    assert isinstance(array, np.ndarray)  # Help mypy.
    # Multiply by the reciprocal, producing the result directly in the requested dtype.
    return np.multiply(array, 1 / max_value, dtype=dtype)
  return _arr_astype(array, dtype) / max_value
2118
2119
2120def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
2121  """Convert a float in range [0.0, 1.0] to uint or float type."""
2122  assert np.issubdtype(_arr_dtype(array), np.floating)
2123  dtype = np.dtype(dtype)
2124  match dtype.type:
2125    case np.uint8 | np.uint16:
2126      return cast(_Array, _arr_astype(array * np.float32(np.iinfo(dtype).max) + 0.5, dtype))
2127    case np.uint32:
2128      return cast(_Array, _arr_astype(array * np.float64(np.iinfo(dtype).max) + 0.5, dtype))
2129    case _:
2130      assert np.issubdtype(dtype, np.floating)
2131      return _arr_astype(array, dtype)
2132
2133
@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract base class for transfer functions on sample values.

  Image/video content is often stored using a color component transfer function.
  See https://en.wikipedia.org/wiki/Gamma_correction.

  Converts between integer types and [0.0, 1.0] internal value range.

  The class is a frozen dataclass with the single field `name`.  Concrete subclasses must
  implement both `decode` and `encode`.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Decode source sample values into floating-point, possibly nonlinearly.

    Uint source values are mapped to the range [0.0, 1.0].

    Args:
      array: Source sample values.
      dtype: Dtype of the decoded result (defaults to `np.float32`).
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Encode float signal into destination samples, possibly nonlinearly.

    Uint destination values are mapped from the range [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.

    Args:
      array: Float signal to encode.
      dtype: Dtype of the destination samples.
    """
2163
2164
class IdentityGamma(Gamma):
  """Identity component transfer function."""

  def __init__(self) -> None:
    super().__init__(name='identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.inexact)
    source_is_uint = np.issubdtype(_arr_dtype(array), np.unsignedinteger)
    if not source_is_uint:
      return _arr_astype(array, dtype)
    # Unsigned integers are rescaled to the range [0.0, 1.0].
    return _to_float_01(array, dtype)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.number)
    if np.issubdtype(dtype, np.unsignedinteger):
      clipped = _arr_clip(array, 0.0, 1.0)
      return _from_float(clipped, dtype)
    # Other integer destinations get a 0.5 offset before conversion; floats pass through.
    shifted = array + 0.5 if np.issubdtype(dtype, np.integer) else array
    return _arr_astype(shifted, dtype)
2186
2187
class PowerGamma(Gamma):
  """Gamma correction using a power function.

  Decoding raises values to `power`; encoding raises them to `1 / power`.
  """

  def __init__(self, power: float) -> None:
    super().__init__(name=f'power_{power}')
    self.power = power

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8 and self.power != 2:
      # Decode all 256 possible uint8 codes once, then index that table with the input.
      arraylib = _arr_arraylib(array)
      codes = np.arange(256, dtype=dtype) / 255
      decode_table = _make_array(self.decode(codes), arraylib)
      return _arr_getitem(decode_table, array)

    normalized = _to_float_01(array, dtype)
    if self.power == 2:
      return _arr_square(normalized)
    return normalized**self.power

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    clipped = _arr_clip(array, 0.0, 1.0)
    if self.power == 2:
      encoded = _arr_sqrt(clipped)
    else:
      encoded = clipped ** (1.0 / self.power)
    return _from_float(encoded, dtype)
2210
2211
class SrgbGamma(Gamma):
  """sRGB component transfer function; see https://en.wikipedia.org/wiki/SRGB."""

  def __init__(self) -> None:
    super().__init__(name='srgb')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8:
      # Decode all 256 possible uint8 codes once, then index that table with the input.
      arraylib = _arr_arraylib(array)
      codes = np.arange(256, dtype=dtype) / 255
      decode_table = _make_array(self.decode(codes), arraylib)
      return _arr_getitem(decode_table, array)

    x = _to_float_01(array, dtype)
    # Piecewise definition: a power curve above the threshold and a linear segment below it.
    power_part = ((x + 0.055) / 1.055) ** 2.4
    linear_part = x / 12.92
    return _arr_where(x > 0.04045, power_part, linear_part)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    clipped = _arr_clip(array, 0.0, 1.0)
    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
    # pytype: disable=wrong-arg-types
    power_part = clipped ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17)
    encoded = _arr_where(clipped > 0.0031308, power_part, clipped * 12.92)
    # pytype: enable=wrong-arg-types
    return _from_float(encoded, dtype)
2236
2237
# Registry of predefined gamma-correction schemes, keyed by the shortcut names that make up
# the public list `GAMMAS`.
_DICT_GAMMAS = {
    'identity': IdentityGamma(),
    'power2': PowerGamma(2.0),
    'power22': PowerGamma(2.2),
    'srgb': SrgbGamma(),
}
2244
GAMMAS = [*_DICT_GAMMAS]
2246r"""Shortcut names for some predefined gamma-correction schemes:
2247
2248| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
2249|---|---|:---:|:---:|
2250| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
2251| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
2252| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
2253| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |
2254
2255See the source code for extensibility.
2256"""
2257
2258
def _get_gamma(gamma: str | Gamma, /) -> Gamma:
  """Resolve `gamma` to a `Gamma` instance.

  A `Gamma` is returned unchanged; any other value is looked up as a key of `_DICT_GAMMAS`
  (the names listed in `GAMMAS`), so an unknown name raises `KeyError`.
  """
  if isinstance(gamma, Gamma):
    return gamma
  return _DICT_GAMMAS[gamma]
2262
2263
def _get_src_dst_gamma(
    gamma: str | Gamma | None,
    src_gamma: str | Gamma | None,
    dst_gamma: str | Gamma | None,
    src_dtype: _DType,
    dst_dtype: _DType,
) -> tuple[Gamma, Gamma]:
  """Resolve the source and destination transfer functions.

  Args:
    gamma: Transfer function used for both source and destination.  Mutually exclusive with
      `src_gamma` and `dst_gamma`.
    src_gamma: Transfer function used to decode the source.
    dst_gamma: Transfer function used to encode the destination.
    src_dtype: Dtype of the source samples; only consulted if all three gammas are `None`.
    dst_dtype: Dtype of the destination samples; only consulted if all three gammas are `None`.

  Returns:
    The tuple `(src_gamma, dst_gamma)` of `Gamma` instances.

  Raises:
    ValueError: If no gamma is given and exactly one of the dtypes is an unsigned integer, if
      `gamma` is combined with `src_gamma` or `dst_gamma`, or if only one of `src_gamma` and
      `dst_gamma` is given.
  """
  if gamma is None and src_gamma is None and dst_gamma is None:
    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
    if src_uint and dst_uint:
      # The default might ideally be 'srgb' but that conversion is costlier.
      gamma = 'power2'
    elif not src_uint and not dst_uint:
      gamma = 'identity'
    else:
      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
  if gamma is not None:
    if src_gamma is not None:
      raise ValueError('Cannot specify both gamma and src_gamma.')
    if dst_gamma is not None:
      raise ValueError('Cannot specify both gamma and dst_gamma.')
    src_gamma = dst_gamma = gamma
  # Raise (rather than assert) so that the check survives `python -O` and matches the other
  # argument errors above.
  if src_gamma is None or dst_gamma is None:
    raise ValueError('Must specify both src_gamma and dst_gamma if gamma is not specified.')
  return _get_gamma(src_gamma), _get_gamma(dst_gamma)
2291
2292
def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Compute affine weights for 1D resampling from `src_size` to `dst_size`.

  Compute a sparse matrix in which each row expresses a destination sample value as a combination
  of source sample values depending on the boundary rule.  If the combination is non-affine,
  the remainder (returned as `cval_weight`) is the contribution of the special constant value
  (cval) defined outside the domain.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix whose rows express output sample values as affine combinations of the
      source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.

  Raises:
    ValueError: If `src_size` or `dst_size` is smaller than the minimum size reported by the
      corresponding gridtype.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  prefilter = filter if prefilter is None else prefilter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  # Ratio of destination to source extent (measured in samples), including the user `scale`.
  scaling = dst_gridtype.size_in_samples(dst_size) / src_gridtype.size_in_samples(src_size) * scale
  is_minification = scaling < 1.0
  # From here on, `filter` denotes the kernel actually used: the prefilter when minifying.
  filter = prefilter if is_minification else filter
  if filter.name == 'trapezoid':
    # Instantiate the trapezoid with a radius derived from the scaling factor.
    radius = 0.5 + 0.5 * min(scaling, 1.0 / scaling)
    filter = TrapezoidFilter(radius=radius)
  radius = filter.radius
  # Number of source samples considered per destination sample (columns of `src_index`); when
  # minifying, the kernel support is widened by the factor 1 / scaling.
  num_samples = int(np.ceil(radius * 2 / scaling) if is_minification else np.ceil(radius * 2))

  dst_index = np.arange(dst_size, dtype=dtype)
  # Destination sample locations in unit domain [0, 1].
  dst_position = dst_gridtype.point_from_index(dst_index, dst_size)

  # Undo the scale/translate mapping to obtain the corresponding source-domain positions.
  src_position = (dst_position - translate) / scale
  src_position = boundary.preprocess_coordinates(src_position)

  # Sample positions mapped back to source unit domain [0, 1].
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  # First source index of each window: for an odd window, round to the nearest sample and step
  # back by half the window; for an even window, take the floor instead.
  src_first_index = (
      np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
      - (num_samples - 1) // 2
  )

  sample_index = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + sample_index  # (dst_size, num_samples)

  def get_weight_matrix() -> _NDArray:
    # Evaluate the kernel at the offsets between each destination position and the source
    # samples in its window; the result has the same shape as `src_index`.
    if filter.name == 'impulse':
      return np.ones(src_index.shape, dtype)
    if is_minification:
      # The kernel is stretched by 1 / scaling, so offsets and amplitudes are scaled by `scaling`.
      x = (src_float_index[:, None] - src_index.astype(dtype)) * scaling
      return filter(x) * scaling
    # Either same size or magnification.
    x = src_float_index[:, None] - src_index.astype(dtype)
    return filter(x)

  weight = get_weight_matrix().astype(dtype, copy=False)

  # Renormalize each row to sum to one, except for 'narrowbox' and except when magnifying with
  # a kernel that already forms a partition of unity.
  if filter.name != 'narrowbox' and (is_minification or not filter.partition_of_unity):
    weight = weight / weight.sum(axis=-1)[..., None]

  # Let the boundary rule rewrite the (index, weight) pairs.
  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)
  shape = dst_size, src_size

  def prepare_sparse_resize_matrix() -> tuple[_NDArray, _NDArray, _NDArray]:
    # Encode each (row, column) pair as the single integer `row * src_size + column`.
    linearized = (src_index + np.indices(src_index.shape)[0] * src_size).ravel()
    values = weight.ravel()
    # Remove the zero weights.
    nonzero = values != 0.0
    linearized, values = linearized[nonzero], values[nonzero]
    # Sort and merge the duplicate indices.
    unique, unique_inverse = np.unique(linearized, return_inverse=True)
    # Build a 0/1 matrix with one row per unique index and one column per entry; multiplying it
    # by `values` sums the entries that share the same index.
    data2 = np.ones(len(linearized), np.float32)
    row_ind2 = unique_inverse
    col_ind2 = np.arange(len(linearized))
    shape2 = len(unique), len(linearized)
    csr = scipy.sparse.csr_matrix((data2, (row_ind2, col_ind2)), shape=shape2)
    data = csr * values  # Merged values.
    row_ind, col_ind = unique // src_size, unique % src_size  # Merged indices.
    return data, row_ind, col_ind

  data, row_ind, col_ind = prepare_sparse_resize_matrix()
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, shape, arraylib)

  # The constant value contributes whatever weight is missing from each row.
  uses_cval = boundary.uses_cval or filter.name == 'narrowbox'
  cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib) if uses_cval else None

  return resize_matrix, cval_weight
2410
2411
def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution along one dimension of `array`.

  Solve for the coefficients which, when convolved with the (continuous) `filter` under the
  given `gridtype` and `boundary`, interpolate the original array values.  The computation is
  always delegated to `_apply_digital_filter_1d_numpy`; for `tensorflow`, `torch`, and `jax`
  arrays that call is wrapped in the library's custom-gradient mechanism, where the gradient
  is obtained from the same solver invoked with `compute_backward=True`.

  Args:
    array: Input array from one of the supported array libraries.
    gridtype: Placement of the samples along the filtered dimension.
    boundary: Boundary rule used to build the linear system.
    cval: Constant value forwarded to the numpy solver.
    filter: Kernel; its `requires_digital_filter` attribute must be true.
    axis: Dimension of `array` along which to filter.

  Returns:
    The filtered array, from the same array library as `array`.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  def solve(values: _NDArray, compute_backward: bool) -> _NDArray:
    """Invoke the numpy solver with the parameters captured from the enclosing call."""
    return _apply_digital_filter_1d_numpy(
        values, gridtype, boundary, cval, filter, axis, compute_backward
    )

  if arraylib == 'tensorflow':
    import tensorflow as tf

    def solve_forward(values: _NDArray) -> _NDArray:
      return solve(values, False)

    def solve_backward(upstream_grad: _NDArray) -> _NDArray:
      return solve(upstream_grad, True)

    @tf.custom_gradient  # type: ignore[misc]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # The callbacks close over gridtype, boundary, etc., yet `stateful=False` is fine because
      # they are redefined on each invocation of _apply_digital_filter_1d.
      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        return tf.numpy_function(solve_backward, [grad_output], x.dtype, stateful=False)

      y = tf.numpy_function(solve_forward, [x], x.dtype, stateful=False)
      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Autograd function whose forward and backward passes both run the numpy solver."""

      @staticmethod
      # pylint: disable-next=arguments-differ
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        return torch.as_tensor(solve(x.detach().numpy(), False))

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        return torch.as_tensor(solve(grad_output.detach().numpy(), True))

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp
    # Implementing this digital filter (inverse convolution) natively in Jax seems difficult:
    # https://jax.readthedocs.io/en/latest/jax.scipy.html sadly omits scipy.signal.filtfilt().
    # Including a (non-traceable) numpy function in Jax requires jax.experimental.host_callback
    # and/or defining a new jax.core.Primitive (which allows differentiability).  See
    # https://github.com/google/jax/issues/1142#issuecomment-544286585
    # https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb  :-(
    # https://github.com/google/jax/issues/5934

    @jax.custom_gradient  # type: ignore[misc]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # The conversion to a numpy array makes this function non-traceable, so jit and grad fail.
      def grad(grad_output: _JaxArray) -> _JaxArray:
        # np.asarray() is used because to_py() is deprecated.
        return jnp.asarray(solve(np.asarray(grad_output), True))

      # np.asarray() is used because to_py() is deprecated.
      y = jnp.asarray(solve(np.asarray(x), False))
      return y, grad

    return jax_inverse_convolution(array)

  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return solve(array, False)
2514
2515
def _apply_digital_filter_1d_numpy(
    array: _NDArray,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    axis: int,
    compute_backward: bool,
    /,
) -> _NDArray:
  """Version of `_apply_digital_filter_1d` specialized to numpy array.

  Args:
    array: Input array with an inexact (floating or complex) dtype.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule used when assembling the linear system.
    cval: Constant value; converted to `array.dtype` and used only if `boundary.uses_cval` and
      `compute_backward` is false.
    filter: Kernel that is sampled at integer offsets to form the rows of the system matrix.
    axis: Dimension of `array` along which the system is solved.
    compute_backward: If true, solve with the transposed matrix (and skip the `cval` term).

  Returns:
    An array with the same shape as `array`.
  """
  assert np.issubdtype(array.dtype, np.inexact)
  cval = np.asarray(cval).astype(array.dtype, copy=False)

  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter:
  # (The table maps (boundary name, gridtype name) to a scipy.ndimage `mode` string.)
  mode = {
      ('reflect', 'dual'): 'reflect',
      ('reflect', 'primal'): 'mirror',
      ('wrap', 'dual'): 'grid-wrap',
      ('wrap', 'primal'): 'wrap',
  }.get((boundary.name, gridtype.name))
  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
  use_split_filter1d = filter_is_compatible and mode
  if use_split_filter1d:
    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
    assert filter.degree >= 2
    # compute_backward=True is same: matrix is symmetric and cval is unused.
    return scipy.ndimage.spline_filter1d(
        array, axis=axis, order=filter.degree, mode=mode, output=array.dtype
    )

  # General path: assemble the (size x size) system matrix explicitly and solve it.
  array_dim = np.moveaxis(array, axis, 0)
  # `l` is the number of integer offsets on each side of zero at which the kernel is sampled.
  l = original_l = math.ceil(filter.radius) - 1
  x = np.arange(-l, l + 1, dtype=array.real.dtype)
  values = filter(x)
  size = array_dim.shape[0]
  # Row i initially references columns i - l, ..., i + l, each row holding the same `values`.
  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
  weight = np.full((size, len(values)), values)
  # NOTE(review): `src_position` has length len(values) whereas `src_index` has `size` rows;
  # confirm that boundary.apply() handles this as intended.
  src_position = np.broadcast_to(0.5, len(values))
  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
  if gridtype.name == 'primal' and boundary.name == 'wrap':
    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
    # matrix non-singular.
    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
  # Largest distance of any referenced column from the diagonal.
  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
  # Currently, matrix is always banded unless boundary.name == 'wrap'.

  data = weight.reshape(-1).astype(array.dtype, copy=False)
  row_ind = np.arange(size).repeat(src_index.shape[1])
  col_ind = src_index.reshape(-1)
  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
  if compute_backward:
    matrix = matrix.T

  if boundary.uses_cval and not compute_backward:
    # Move the constant-value contribution (one minus each row sum, times cval) to the
    # right-hand side of the system.
    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval

  # Flatten all trailing dimensions so that each column is an independent right-hand side.
  array_flat = array_dim.reshape(array_dim.shape[0], -1)

  if is_banded:
    matrix = matrix.todia()
    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
    # Recompute the lower and upper bandwidths from the stored diagonals.
    l, u = -matrix.offsets[0], matrix.offsets[-1]
    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
    if _is_symmetric(matrix):
      # Pass only the diagonals with offsets u, ..., 0 (rows from the last down to index l).
      array_flat = scipy.linalg.solveh_banded(matrix.data[-1 : l - 1 : -1], array_flat, **options)
    else:
      # Reverse the row order so that the diagonals run from offset u down to -l.
      array_flat = scipy.linalg.solve_banded((l, u), matrix.data[::-1], array_flat, **options)

  else:
    # Non-banded matrix: use a sparse LU factorization.
    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
    array_flat = lu.solve(array_flat)

  # Restore the original shape and axis order.
  array_dim = array_flat.reshape(array_dim.shape)
  return np.moveaxis(array_dim, 0, axis)
2598
2599
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - An 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.

  Raises:
    ValueError: If `array` is not numeric, if `shape` has zero length or more entries than
      `array` has dimensions, or if `dim_order` is not a permutation of `range(len(shape))`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  # Broadcast the per-dimension parameters to one entry per resized dimension.
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  # Drop the raw parameters so that only their normalized counterparts can be used below.
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      # Returning `array` unchanged is only valid if no data type conversion was requested;
      # otherwise the result would not have the documented data type `dtype`.
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  # Decode both the samples and the constant value using the source gamma.
  array = src_gamma2.decode(array, precision)
  cval = _arr_numpy(src_gamma2.decode(cval, precision))

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  # Resize one dimension at a time, each as a sparse-matrix product on a flattened 2D view.
  for dim in dim_order:
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Bring the current dimension to the front and flatten all the others.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    # Restore the trailing dimensions and move the resized dimension back into place.
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array
2852
2853
# Alias for the `resize` function defined above; the helper functions below call it through this
# name.  (Presumably this keeps them bound to this implementation if the public name `resize`
# is rebound later -- TODO confirm.)
_original_resize = resize
2855
2856
def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` using the array library `arraylib` (from `ARRAYLIBS`).

  The input (checked to be a numpy array) is converted to `arraylib`, resized, and the result
  is converted back to a numpy array.
  """
  _check_eq(_arr_arraylib(array), 'numpy')
  converted = _make_array(array, arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
2861
2862
def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array`, performing the computation with `numpy`."""
  result = resize_in_arraylib(array, *args, arraylib='numpy', **kwargs)
  return result
2866
2867
def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array`, performing the computation with `tensorflow`."""
  result = resize_in_arraylib(array, *args, arraylib='tensorflow', **kwargs)
  return result
2871
2872
def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array`, performing the computation with `torch`."""
  result = resize_in_arraylib(array, *args, arraylib='torch', **kwargs)
  return result
2876
2877
def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array`, performing the computation with `jax`."""
  result = resize_in_arraylib(array, *args, arraylib='jax', **kwargs)
  return result
2881
2882
def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """Evaluate `resize()`, routing numpy input through the array library `arraylib`.

  A non-numpy `array` is resized directly (and `arraylib` is ignored).  A numpy `array` is
  converted to `arraylib` (an element of `ARRAYLIBS`), resized, and converted back to numpy.
  """
  if _arr_arraylib(array) != 'numpy':
    return _original_resize(array, *args, **kwargs)
  converted = _make_array(cast(_ArrayLike, array), arraylib)
  return _arr_numpy(_original_resize(converted, *args, **kwargs))
2892
2893
@functools.cache
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Return a `jax.jit`-compiled `resize`, created on first use and cached thereafter."""
  import jax

  # Everything except the array itself is static: positional argument 1 and every parameter
  # that has a keyword-only default.
  keyword_names = list(_original_resize.__kwdefaults__)
  jitted: Any = jax.jit(_original_resize, static_argnums=(1,), static_argnames=keyword_names)
  return jitted
2903
2904
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` using the (lazily created) Jax-jitted version of the function."""
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)
2908
2909
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  The effect is similar to CSS `object-fit: contain` (or `object-fit: cover` if
  `object_fit='cover'`).
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, or if `len(shape)` is zero or exceeds the
      number of dimensions of `array`.
    KeyError: If `object_fit` is neither `'contain'` nor `'cover'`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension (non-uniform) scale that would exactly map the source extent onto `shape`.
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # The single uniform scale: the smallest one fits inside, the largest one covers.
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  # Express the uniform scale relative to the per-dimension scale already implied by `resize`.
  scale2 = scale0 / np.array(raw_scales)
  # Center the scaled source within the unit output domain.
  translate = (1.0 - scale2) / 2
  # Forward the resolved grid types; otherwise `resize` would use its own defaults, which may
  # disagree with the grid types used above to compute the scales.
  return resize(
      array,
      shape,
      src_gridtype=src_gridtype2,
      dst_gridtype=dst_gridtype2,
      boundary=boundary,
      scale=scale2,
      translate=translate,
      **kwargs,
  )
2983
2984
# Sentinel passed as `max_block_size` when `resample` re-invokes itself on one block of `coords`.
# When `resample` sees this value, it skips the source-gamma decoding and digital-filter steps.
_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.
2986
2987
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, if `coords` is not floating-point, if `coords` has more
      coordinate dimensions than `array` has dimensions, or if the `'trapezoid'` filter is used.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  # For a 1D source, interpret a 1D `coords` with several entries as a list of scalar points.
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  # Drop the raw parameters so that only the resolved `*2` versions can be used below.
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # Decode and digitally prefilter the source only once, i.e., not again within each block.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      # The source is already decoded, so recurse with an identity source gamma.
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=jacobian,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # Starting in numpy 2.0, np.einsum() outputs np.float64 even with all np.float32 inputs;
  # GPT: "aligns np.einsum with other functions where intermediate calculations use higher
  # precision (np.float64) regardless of input type when floating-point arithmetic is involved."
  # we could explicitly add the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The weight given to `cval` is whatever the product of per-dimension weight sums leaves
    # unassigned.  There is one weight array per *grid* dimension (not per output dimension),
    # so the product must range over `grid_ndim`.
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array
3299
3300
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array onto a destination grid of given shape through an affine map.

  The `matrix` maps destination points to source points.  It is either linear,
    `source_point = matrix @ destination_point`,
  or affine, in which case its last column is an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or has a number
      of columns other than `len(shape)` or `len(shape) + 1`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  dst_ndim = len(shape)
  src_ndim, num_columns = matrix.shape
  # An extra (last) column holds the translation of an affine map.
  is_affine = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not (is_affine or num_columns == dst_ndim):
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter
  if dtype is None:
    dtype = _arr_dtype(array)
  else:
    dtype = np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples, first along each axis, then as a full grid.
  axis_positions = [
      dst_gridtype2[dim].point_from_index(
          np.arange(shape[dim], dtype=weight_precision), shape[dim]
      )
      for dim in range(dst_ndim)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part to every destination point; place the coordinates in the last axis.
  linear_matrix = matrix[:, :-1] if is_affine else matrix
  coords = np.moveaxis(np.tensordot(linear_matrix, dst_position, 1), 0, -1)
  if is_affine:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )
3413
3414
def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Implement `resize` via the more general `resample_affine`, as a debug tool.

  Args:
    array: Source grid of samples.
    shape: Number of samples in each dimension of the output grid.
    scale: Per-dimension scaling, broadcast to `len(shape)`.
    translate: Per-dimension translation, broadcast to `len(shape)`.
    filter: Reconstruction filter(s), one per dimension of `shape` after broadcasting.
    fallback: If True, defer to `_original_resize` whenever the operation involves minification
      or uses the `'trapezoid'` filter.
    **kwargs: Additional parameters forwarded to the function that performs the resize.

  Returns:
    The resized array.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  ndim = len(shape)
  scale = np.broadcast_to(scale, ndim)
  translate = np.broadcast_to(translate, ndim)

  # TODO: let resample() do prefiltering for proper downsampling.
  shrinks_grid = np.any(np.array(shape) < array.shape[:ndim])
  has_minification = shrinks_grid or np.any(scale < 1.0)
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), ndim)]
  has_auto_trapezoid = any(f.name == 'trapezoid' for f in filter2)
  if fallback and (has_minification or has_auto_trapezoid):
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)

  # Affine map from destination to source: source = destination / scale - translate / scale.
  offset = -translate / scale
  matrix = np.column_stack([np.diag(1.0 / scale), offset])
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)
3441
3442
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping destination into a source unit domain.

  The matrix is a product of homogeneous transforms: a translation moving the domain center
  (0.5, 0.5) to the origin, a scaling by the destination aspect ratio (times `scale`), a
  rotation by `angle`, a scaling by the inverse source aspect ratio, and a translation back to
  the center.  It thereby accounts for the possibly non-square domain shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 matrix whose last row is `[0, 0, 1]`.
  """

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag(tuple(factors) + (1.0,))

  def homogeneous_rotation(radians: float) -> _NDArray:
    c, s = np.cos(radians), np.sin(radians)
    return np.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])

  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))
  center = np.array([0.5, 0.5])

  # Accumulate the product left to right; the rightmost factor acts first on a destination point.
  matrix = homogeneous_translation(center)
  matrix = matrix @ homogeneous_scaling(min(src_shape) / src_shape)
  matrix = matrix @ homogeneous_rotation(angle)
  matrix = matrix @ homogeneous_scaling(scale * new_shape / min(new_shape))
  matrix = matrix @ homogeneous_translation(-center)
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix
3489
3490
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The image obtained after `num_rotations` successive resamplings; `image` itself if
    `num_rotations` is zero.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  src_shape = tuple(image.shape[:2])
  matrix = rotation_about_center_in_2d(src_shape, angle, new_shape=new_shape, scale=scale)
  for _ in range(num_rotations):
    # After the first pass the image has resolution `new_shape`.  If that differs from the
    # original resolution, the matrix (which depends on the source aspect ratio) is stale and
    # must be rebuilt for the new source shape.
    current_shape = tuple(image.shape[:2])
    if current_shape != src_shape:
      src_shape = current_shape
      matrix = rotation_about_center_in_2d(src_shape, angle, new_shape=new_shape, scale=scale)
    # Drop the homogeneous last row to obtain the 2x3 affine matrix.
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image
3518
3519
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Floating-point array with 1, 2, or 3 dimensions; for 3 dimensions, the last one holds
      channels that are resized independently.
    shape: Output resolution; one value for a 1D `array`, otherwise two values.
    filter: One of `'impulse'`, `'box'`, `'triangle'`, `'hamming1'`, `'cubic'`, `'lanczos3'`.
    boundary: Must be `'natural'`.
    cval: Ignored.

  Returns:
    The resized array, with the same dtype as `array`.

  Raises:
    ValueError: If `boundary` is not `'natural'` or `filter` is not supported.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat the 1D signal as a single-row image, then strip the added row.
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
  resampling = PIL.Image.Resampling
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  pil_size = shape[::-1]  # Reversed order of `shape`, as passed to `PIL.Image.Image.resize`.

  def resize_plane(plane: Any) -> Any:
    """Resize a single 2D plane and convert the result back to the input dtype."""
    resized = PIL.Image.fromarray(plane).resize(pil_size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  # Resize each channel separately and restack the channels along the last axis.
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])
3564
3565
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  Args:
    array: Array with 1, 2, or 3 dimensions.
    shape: Output resolution; one value for a 1D `array`, otherwise two values.
    filter: One of `'impulse'`, `'triangle'`, `'trapezoid'`, `'sharpcubic'`, `'lanczos4'`.
    boundary: Must be `'clamp'`.
    cval: Ignored.

  Returns:
    The resized array, with the same number of dimensions as `array`.

  Raises:
    ValueError: If `boundary` is not `'clamp'` or `filter` is not supported.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat the 1D signal as a single-row image, then strip the added row.
    return cv_resize(array[None], (1, *shape), filter=filter)[0]
  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  result = cv.resize(array, shape[::-1], interpolation=filters[filter])
  lost_channel_axis = array.ndim == 3 and result.ndim == 2
  if not lost_channel_axis:
    return result
  assert array.shape[2] == 1
  return result[..., None]  # Add back the last dimension dropped by cv.resize().
3602
3603
def scipy_ndimage_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _NDArray:
  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`.

  Args:
    array: Input grid; its leading `len(shape)` dimensions are resampled and any remaining
      trailing dimensions are sampled at their own integer indices.
    shape: Output size of the resampled leading dimensions (`1 <= len(shape) <= array.ndim`).
    filter: 'box' (order 0), 'triangle' (order 1), or 'cardinal2' .. 'cardinal5' (spline order).
    boundary: One of 'reflect', 'wrap', 'clamp', or 'border'; translated to the scipy `mode`.
    cval: Constant value forwarded to `map_coordinates`.
    scale: Scaling applied to the unit-domain destination coordinates (scalar or per-dimension).
    translate: Offset applied to the unit-domain destination coordinates (scalar or
      per-dimension).

  Returns:
    The array returned by `scipy.ndimage.map_coordinates`, with shape
    `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `filter` or `boundary` is not one of the supported names.
  """
  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]
  shape_all = shape + array.shape[len(shape) :]
  # The sample coordinates are fractional (half-integer offsets), so they must be stored in a
  # floating-point array.  Reusing an integer `array.dtype` would truncate them on assignment.
  coords_dtype = array.dtype if array.dtype.kind == 'f' else np.float64
  coords = np.moveaxis(np.indices(shape_all, coords_dtype), 0, -1)
  # Map destination sample centers to the unit domain, apply the inverse of (scale, translate),
  # and convert to source index space.
  coords[..., : len(shape)] = (
      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
  coords = np.moveaxis(coords, -1, 0)
  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)
3634
3635
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` with `skimage.transform.resize`, accepting the same parameters as `resize`.

  Args:
    array: Input grid; its leading `len(shape)` dimensions are resized and the trailing
      dimensions keep their sizes.
    shape: Output size of the leading dimensions (`1 <= len(shape) <= array.ndim`).
    filter: 'box' (order 0), 'triangle' (order 1), or 'cardinal2' .. 'cardinal5' (spline order).
    boundary: One of 'reflect', 'wrap', 'clamp', or 'border'; translated to the skimage `mode`.
    cval: Constant value forwarded to `skimage.transform.resize`.

  Returns:
    The array returned by `skimage.transform.resize`.

  Raises:
    ValueError: If `filter` or `boundary` is not one of the supported names.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1}
  filters.update((f'cardinal{i}', i) for i in range(2, 6))
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = dict(reflect='symmetric', wrap='wrap', clamp='edge', border='constant')
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  trailing_dims = array.shape[len(shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      array,
      shape + trailing_dims,
      order=filters[filter],
      mode=boundaries[boundary],
      cval=cval,
      clip=False,
  )  # type: ignore[no-untyped-call]
3665
3666
# For each supported `resampler` filter name, the `method` string passed to `tf.image.resize`.
_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER = dict(
    impulse='nearest',
    trapezoid='area',
    triangle='bilinear',
    mitchell='mitchellcubic',
    cubic='bicubic',
    lanczos3='lanczos3',
    lanczos5='lanczos5',
    # GaussianFilter(0.5): 'gaussian',  # radius_4 > desired_radius_3.
)
3677
3678
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Resize `array` with `tf.image.resize`, accepting the same parameters as `resize`.

  Args:
    array: Input of rank 1, 2, or 3 (asserted below); it is converted to a TensorFlow tensor.
    shape: Output size; its length is checked (via `_check_eq`) to be 1 for a rank-1 input and
      2 otherwise.
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'.
    cval: Accepted for signature compatibility and ignored.
    antialias: Forwarded to `tf.image.resize`.

  Returns:
    The resized tensor.

  Raises:
    ValueError: If `filter` is unsupported or `boundary` is not 'natural'.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval  # Unused.
  array2 = tf.convert_to_tensor(array)
  ndim = len(array2.shape)
  del array
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)
  if ndim == 1:
    # Add a leading row axis, resize as a single-row 2D grid, then remove the row axis.
    as_row = tf_image_resize(array2[None], (1, *shape), filter=filter, antialias=antialias)
    return as_row[0]
  if ndim == 2:
    # Add a trailing channel axis, resize as a 3D image, then remove the channel axis.
    with_channel = tf_image_resize(array2[..., None], shape, filter=filter, antialias=antialias)
    return with_channel[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(array2, shape, method=method, antialias=antialias)
3711
3712
# For each supported `resampler` filter name, the `mode` string passed to
# `torch.nn.functional.interpolate`.
_TORCH_INTERPOLATE_MODE_FROM_FILTER = dict(
    impulse='nearest-exact',  # ('nearest' matches buggy OpenCV's INTER_NEAREST)
    trapezoid='area',
    triangle='bilinear',
    sharpcubic='bicubic',
)
3719
3720
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Resize `array` with `torch.nn.functional.interpolate`, using the parameters of `resize`.

  Args:
    array: Input of rank 1, 2, or 3 (asserted below); it is converted with `torch.as_tensor`.
    shape: Output size; its length is checked (via `_check_eq`) to be 1 for a rank-1 input and
      2 otherwise.
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'.
    cval: Accepted for signature compatibility and ignored.
    antialias: Forwarded to `torch.nn.functional.interpolate`.

  Returns:
    The resized tensor, with the same rank as the input.

  Raises:
    ValueError: If `filter` is unsupported or `boundary` is not 'clamp'.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval  # Unused.
  a = torch.as_tensor(array)
  del array
  assert 1 <= a.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if a.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def interpolate(batched: _TorchTensor, size: tuple[int, ...]) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batched, size, mode=mode, antialias=antialias)

  # Each case adds leading axes so that a 4D tensor with two trailing spatial axes is passed
  # to interpolate(), and then removes the added axes from the result.
  if a.ndim == 1:
    return interpolate(a[None, None, None], (1, *shape))[0, 0, 0]
  if a.ndim == 2:
    return interpolate(a[None, None], shape)[0, 0]
  channels_first = a.moveaxis(2, 0)[None]
  return interpolate(channels_first, shape)[0].moveaxis(0, 2)
3761
3762
def jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`.

  Args:
    array: Input grid; its leading `len(shape)` dimensions are resized.
    shape: Output size of the leading dimensions (`len(shape) <= array.ndim`).
    filter: One of 'triangle', 'cubic', 'lanczos3', or 'lanczos5'.
    boundary: Must be 'natural'.
    cval: Only validated (it must be zero if `scale` or `translate` is non-trivial); it is not
      passed to JAX.
    scale: Scalar or per-dimension scaling; broadcast to `len(shape)` entries.
    translate: Scalar or per-dimension translation; broadcast to `len(shape)` entries.

  Returns:
    The array returned by `jax.image.scale_and_translate`.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported, or if `cval` is nonzero while any
      element of `scale` differs from 1 or any element of `translate` differs from 0.
  """
  import jax.image
  import jax.numpy as jnp

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  # Convert once so that the checks below are element-wise.  (Comparing a tuple or list
  # directly against a float is always "unequal", and a multi-element numpy array has no
  # single truth value.)
  scale_array = np.asarray(scale)
  translate_array = np.asarray(translate)
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.
  if cval != 0.0:
    if (scale_array != 1.0).any():
      raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
    if (translate_array != 0.0).any():
      raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  assert len(shape) <= array2.ndim
  # The output shape must have one entry per array dimension; only the leading
  # `len(shape)` dimensions are listed in `spatial_dims`.
  completed_shape = shape + (1,) * (array2.ndim - len(shape))
  spatial_dims = list(range(len(shape)))
  # Convert unit-domain scale and translation into the destination-sample units given to JAX.
  scale2 = np.broadcast_to(scale_array, len(shape))
  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
  translate2 = np.broadcast_to(translate_array, len(shape))
  translate2 = translate2 * np.array(shape)
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )
3802
3803
# Maps the dotted name of each library resize function to the callable in this module that
# invokes it (`resize` itself, or a `*_resize` wrapper accepting the same parameters as `resize`).
# The first component of each key is what `_resizer_is_available()` inspects.
_CANDIDATE_RESIZERS = {
    'resampler.resize': resize,
    'PIL.Image.resize': pil_image_resize,
    'cv.resize': cv_resize,
    'scipy.ndimage.map_coordinates': scipy_ndimage_resize,
    'skimage.transform.resize': skimage_transform_resize,
    'tf.image.resize': tf_image_resize,
    'torch.nn.functional.interpolate': torch_nn_resize,
    'jax.image.scale_and_translate': jax_image_resize,
}
3814
3815
def _resizer_is_available(library_function: str) -> bool:
  """Return whether the resizer is available as an installed package.

  Args:
    library_function: Dotted name such as 'cv.resize'; only its first component is examined.

  Returns:
    True if the top-level module backing `library_function` can be located by the import system.
  """
  # `import importlib` alone does not guarantee that the `util` submodule is loaded.
  import importlib.util

  top_name = library_function.split('.', 1)[0]
  # Translate abbreviated prefixes to importable module names.  'PIL' needs no entry: 'Pillow'
  # is only the distribution name, whereas the importable module is `PIL` itself.
  module = {'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
  return importlib.util.find_spec(module) is not None
3821
3822
# The entries of `_CANDIDATE_RESIZERS` (in the same order) for which `_resizer_is_available()`
# reports that the backing library is installed.
_RESIZERS = dict(
    candidate
    for candidate in _CANDIDATE_RESIZERS.items()
    if _resizer_is_available(candidate[0])
)
3828
3829
def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`.

  Args:
    filter: Either one of the generic categories 'box_like', 'cubic_like', or 'high_quality',
      or any other filter name (which is returned unchanged).
    resizer: The resize function for which a concrete filter name is desired.

  Returns:
    The per-resizer filter name for a generic category (or the category's fallback name if
    `resizer` has no specific entry); otherwise `filter` itself.
  """
  if filter == 'box_like':
    per_resizer = {
        cv_resize: 'trapezoid',
        skimage_transform_resize: 'box',
        tf_image_resize: 'trapezoid',
        torch_nn_resize: 'trapezoid',
    }
    fallback = 'box'
  elif filter == 'cubic_like':
    per_resizer = {
        cv_resize: 'sharpcubic',
        scipy_ndimage_resize: 'cardinal3',
        skimage_transform_resize: 'cardinal3',
        torch_nn_resize: 'sharpcubic',
    }
    fallback = 'cubic'
  elif filter == 'high_quality':
    per_resizer = {
        pil_image_resize: 'lanczos3',
        cv_resize: 'lanczos4',
        scipy_ndimage_resize: 'cardinal5',
        skimage_transform_resize: 'cardinal5',
        torch_nn_resize: 'sharpcubic',
    }
    fallback = 'lanczos5'
  else:
    # Not a generic category: the name is already concrete.
    return filter
  return per_resizer.get(resizer, fallback)
3857
3858
3859# For Emacs:
3860# Local Variables:
3861# fill-column: 100
3862# End:
using_numba = False
def is_available(arraylib: str) -> bool:
def is_available(arraylib: str) -> bool:
  """Return whether the array library (e.g. 'tensorflow') is available as an installed package.

  Args:
    arraylib: Importable module name.  A dotted name (e.g. 'jax.numpy') is also accepted.

  Returns:
    True if the import system can locate the module; False otherwise, including when the parent
    package of a dotted name is not installed.
  """
  # `import importlib` alone does not guarantee that the `util` submodule is loaded.
  import importlib.util

  try:
    return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.
  except ModuleNotFoundError:
    # For a dotted name, find_spec() imports the parent package and raises if it is missing.
    return False

Return whether the array library (e.g. 'tensorflow') is available as an installed package.

ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the array function parameter.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.

@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract interface for a grid-type, e.g., `'dual'` or `'primal'`.

  A resampling operation can use one grid-type for the source domain (`src_gridtype`) and
  another for the destination domain (`dst_gridtype`).  A grid-type can additionally be chosen
  independently for each domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Name of the grid-type."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples that is required."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the size of the domain, measured in units of the spacing between samples."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Convert sample indices in [0, size - 1] into coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Convert coordinates in [0.0, 1.0] into a location x, such that x == 0.0 is the first
    grid sample and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Use boundary reflection to map integer sample indices to interior ones."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Use wrapping to map integer sample indices to interior ones."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Use reflect-clamp to map integer sample indices to interior ones."""

Abstract base class for grid-types such as 'dual' and 'primal'.

In resampling operations, the grid-type may be specified separately as src_gridtype for the source domain and dst_gridtype for the destination domain. Moreover, the grid-type may be specified per domain dimension.

Examples:

resize(source, shape, gridtype='primal') # Sets both src and dst to be 'primal' grids.

resize(source, shape, src_gridtype=['dual', 'primal'], dst_gridtype='dual') # Source is 'dual' in dim0 and 'primal' in dim1.

GRIDTYPES: list[str] = ['dual', 'primal']

Shortcut names for the two predefined grid types (specified per dimension):

gridtype 'dual'
DualGridtype()
(default)
'primal'
PrimalGridtype()
 
Sample positions in 2D
and in 1D at different resolutions
Dual Primal
Nesting of samples across resolutions The sample positions do not nest. The even samples remain at coarser scale.
Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ $N_\ell=2^\ell$ $N_\ell=2^\ell+1$
Position of sample index $i$ within domain $[0, 1]$ $\frac{i + 0.5}{N}$ ("half-integer" coordinates) $\frac{i}{N-1}$
Image resolutions ($N_\ell\times N_\ell$) for dyadic scales $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$

See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Boundary:
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Rules for the domain boundary.

  The rules determine how the source domain is reconstructed near and beyond its boundaries.
  A separate `Boundary` may be given for each domain dimension.
  """

  name: str = ''
  """Name of the boundary rule."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Remaps the given coordinates before the reconstruction kernels are evaluated."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Expresses each grid sample outside the unit domain as an affine combination of interior
  sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Assigns the constant value (`cval`) outside some extent."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    samples_use_cval = self.extend_samples.uses_cval
    return samples_use_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return `point` after applying `coord_remap`, prior to evaluating the filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples.

    The `(index, weight)` pair is first processed by `extend_samples`; then
    `override_reconstruction()` is given the resulting weight together with `point`.
    """
    extended = self.extend_samples(index, weight, size, gridtype)
    index, weight = extended
    self.override_reconstruction(weight, point)
    return index, weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight to zero to assign `cval`.

    This delegates to `override_value` and returns nothing.
    """
    self.override_value(weight, point)

Domain boundary rules. These define the reconstruction over the source domain near and beyond the domain boundaries. The rules may be specified separately for each domain dimension.

BOUNDARIES: list[str] = ['reflect', 'wrap', 'tile', 'clamp', 'border', 'natural', 'linear_constant', 'quadratic_constant', 'reflect_clamp', 'constant', 'linear', 'quadratic']

Shortcut names for some predefined boundary rules (as defined by _DICT_BOUNDARIES):

name a.k.a. / comments
'reflect' reflected, symm, symmetric, mirror, grid-mirror
'wrap' periodic, repeat, grid-wrap
'tile' like 'reflect' within unit domain, then tile discontinuously
'clamp' clamped, nearest, edge, clamp-to-edge, repeat last sample
'border' grid-constant, use cval for samples outside unit domain
'natural' renormalize using only interior samples, use cval outside domain
'reflect_clamp' mirror-clamp-to-edge
'constant' like 'reflect' but replace by cval outside unit domain
'linear' extrapolate from 2 last samples
'quadratic' extrapolate from 3 last samples
'linear_constant' like 'linear' but replace by cval outside unit domain
'quadratic_constant' like 'quadratic' but replace by cval outside unit domain

These boundary rules may be specified per dimension. See the source code for extensibility using the classes RemapCoordinates, ExtendSamples, and OverrideExteriorValue.

Boundary rules illustrated in 1D:

Boundary rules illustrated in 2D:

@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract interface for a filter kernel function.

  Every kernel is assumed to be a zero-phase filter, meaning that it is symmetric over its
  support interval [-radius, radius].  (Some references instead define kernels over the
  interval [0, N] where N = 2 * radius.)

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Name of the filter kernel."""

  radius: float
  """Largest absolute value of x at which self(x) is nonzero."""

  interpolating: bool = True
  """Whether self(0) == 1.0 and self(i) == 0.0 at every nonzero integer i."""

  continuous: bool = True
  """Whether the kernel function is $C^0$ continuous."""

  partition_of_unity: bool = True
  """Whether convolving the kernel with a Dirac comb reproduces the unity function."""

  unit_integral: bool = True
  """Whether the kernel function integrates to 1."""

  requires_digital_filter: bool = False
  """Whether interpolation with this filter needs a pre/post digital filter."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the filter kernel at the locations x and return the result."""

Abstract base class for filter kernel functions.

Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support interval [-radius, radius]. (Some sites instead define kernels over the interval [0, N] where N = 2 * radius.)

Portions of this code are adapted from the C++ library in https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

See also https://hhoppe.com/proj/filtering/.

class ImpulseFilter(Filter):
class ImpulseFilter(Filter):
  """Impulse kernel; see https://en.wikipedia.org/wiki/Dirac_delta_function.

  The kernel cannot be evaluated: calling an instance always raises `AssertionError`.
  """

  def __init__(self) -> None:
    super().__init__(
        name='impulse',
        radius=1e-20,
        continuous=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    message = 'The Impulse is infinitely narrow, so cannot be directly evaluated.'
    raise AssertionError(message)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BoxFilter(Filter):
class BoxFilter(Filter):
  """Box kernel; see https://en.wikipedia.org/wiki/Box_function.

  The kernel evaluates to 1.0 on the half-open interval [-.5, .5) and to 0.0 elsewhere.
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    use_asymmetric = True
    if not use_asymmetric:
      # Symmetric alternative (disabled): value 0.5 exactly at abs(x) == 0.5.
      x = np.abs(x)
      return np.where(x < 0.5, 1.0, np.where(x == 0.5, 0.5, 0.0))
    x = np.asarray(x)
    inside = (x >= -0.5) & (x < 0.5)
    return np.where(inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TrapezoidFilter(Filter):
class TrapezoidFilter(Filter):
  """Kernel for antialiased "area-based" filtering.

  Args:
    radius: The filter support is [-radius, radius], with 0.5 < radius <= 1.0.
      Passing `radius = None` creates a placeholder, which indicates that the filter is to be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  The shape resembles that of BoxFilter, except that the sides slope linearly.  The value is
  1.0 where abs(x) <= 1.0 - radius, and it falls linearly to 0.0 over the interval
  1.0 - radius <= abs(x) <= radius, so that the value at x = 0.5 is always 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is not None:
      if not 0.5 < radius <= 1.0:
        raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      # Placeholder instance; it cannot be evaluated (see the assertion in __call__).
      super().__init__(name='trapezoid', radius=0.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    ramp_half_width = self.radius - 0.5
    # Line with value 0.5 at x == 0.5 and value 0.0 at x == radius, clipped to [0.0, 1.0].
    line = (0.5 + 0.25 / ramp_half_width) - (0.5 / ramp_half_width) * x
    return line.clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TriangleFilter(Filter):
class TriangleFilter(Filter):
  """Triangle kernel; see https://en.wikipedia.org/wiki/Triangle_function.

  It is also called the hat or tent function, and it yields piecewise-linear
  (or bilinear, or trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    tent = 1.0 - np.abs(x)
    return tent.clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CubicFilter(Filter):
class CubicFilter(Filter):
  """Two-parameter family of piecewise-cubic filters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: Optional filter name; it defaults to `f'cubic_b{b}_c{c}'`.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    b, c = self.b, self.c
    # Polynomial coefficients of the piece used for x < 1 (it has no linear term).
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Polynomial coefficients of the piece used for 1 <= x < 2.
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
    # twice as slow; see also https://stackoverflow.com/questions/24065904)
    inner = ((f3 * x + f2) * x) * x + f0
    outer = ((g3 * x + g2) * x + g1) * x + g0
    beyond_inner = np.where(x < 2.0, outer, 0.0)
    return np.where(x < 1.0, inner, beyond_inner)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CatmullRomFilter(CubicFilter):
class CatmullRomFilter(CubicFilter):
  """Cubic filter (b=0, c=0.5) with cubic precision; it is also called the Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class MitchellFilter(CubicFilter):
class MitchellFilter(CubicFilter):
  """Cubic filter with b=1/3 and c=1/3; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class SharpCubicFilter(CubicFilter):
class SharpCubicFilter(CubicFilter):
  """Cubic filter (b=0, c=0.75) that is sharper than the Catmull-Rom filter.

  It is used by some tools, including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class LanczosFilter(Filter):
class LanczosFilter(Filter):
  """High-quality filter: a sinc function multiplied by a (wider) sinc window.

  Args:
    radius: The filter is nonzero only within the support window [-radius, radius].
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
      # Thus the zero-phase window function is w_0(x) = sinc(x / radius).
      windowed_sinc = _sinc(x) * _sinc(x / radius)
      return np.where(x < radius, windowed_sinc, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GeneralizedHammingFilter(Filter):
class GeneralizedHammingFilter(Filter):
  """Sinc function multiplied by a Hamming window.

  Args:
    radius: The filter is nonzero only within the support window [-radius, radius].
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    assert 0.0 < a0 < 1.0
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    cosine = np.cos(np.pi / self.radius * x)
    window = self.a0 + (1.0 - self.a0) * cosine
    return np.where(x < self.radius, _sinc(x) * window, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class KaiserFilter(Filter):
class KaiserFilter(Filter):
  """Sinc function multiplied by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level.
    sampled: If True, use a discretized approximation for improved speed.
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    assert beta >= 0.0
    super().__init__(
        name=f'kaiser_{radius}_{beta}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )
    # The sampled range is rounded outward to integer bounds because radius may be fractional.
    sample_bound = math.ceil(radius)

    @_cache_sampled_1d_function(xmin=-sample_bound, xmax=sample_bound, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      profile = (1.0 - np.square(x / radius)).clip(0.0, 1.0)
      window = np.i0(beta * np.sqrt(profile)) / np.i0(beta)
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BsplineFilter(Filter):
1873class BsplineFilter(Filter):
1874  """B-spline of a non-negative degree.
1875
1876  Args:
1877    degree: The polynomial degree of the B-spline segments.
1878      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
1879      With `degree=1`, it is identical to `TriangleFilter`.
1880      With `degree >= 2`, it is no longer interpolating.
1881
1882  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
1883  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
1884  """
1885
1886  def __init__(self, *, degree: int) -> None:
1887    if degree < 0:
1888      raise ValueError(f'Bspline of degree {degree} is invalid.')
1889    radius = (degree + 1) / 2
1890    interpolating = degree <= 1
1891    super().__init__(name=f'bspline{degree}', radius=radius, interpolating=interpolating)
1892    t = list(range(degree + 2))
1893    self._bspline = scipy.interpolate.BSpline.basis_element(t)
1894
1895  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1896    x = np.abs(x)
1897    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CardinalBsplineFilter(Filter):
1900class CardinalBsplineFilter(Filter):
1901  """Interpolating B-spline, achieved with aid of digital pre or post filter.
1902
1903  Args:
1904    degree: The polynomial degree of the B-spline segments.
1905    sampled: If True, use a discretized approximation for improved speed.
1906
1907  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
1908  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
1909  1991].
1910  """
1911
1912  def __init__(self, *, degree: int, sampled: bool = True) -> None:
1913    self.degree = degree
1914    if degree < 0:
1915      raise ValueError(f'Bspline of degree {degree} is invalid.')
1916    radius = (degree + 1) / 2
1917    super().__init__(
1918        name=f'cardinal{degree}',
1919        radius=radius,
1920        requires_digital_filter=degree >= 2,
1921        continuous=degree >= 1,
1922    )
1923    t = list(range(degree + 2))
1924    bspline = scipy.interpolate.BSpline.basis_element(t)
1925
1926    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1927    def _eval(x: _ArrayLike) -> _NDArray:
1928      x = np.abs(x)
1929      return np.where(x < radius, bspline(x + radius), 0.0)
1930
1931    self._function = _eval
1932
1933  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1934    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class OmomsFilter(Filter):
1937class OmomsFilter(Filter):
1938  """OMOMS interpolating filter, with aid of digital pre or post filter.
1939
1940  Args:
1941    degree: The polynomial degree of the filter segments.
1942
1943  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
1944  interpolation of minimal support, 2001].
1945  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
1946  """
1947
1948  def __init__(self, *, degree: int) -> None:
1949    if degree not in (3, 5):
1950      raise ValueError(f'Degree {degree} not supported.')
1951    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
1952    self.degree = degree
1953
1954  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1955    x = np.abs(x)
1956    match self.degree:
1957      case 3:
1958        v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
1959        v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
1960        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1961      case 5:
1962        v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
1963        v12 = (
1964            (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
1965        ) * x + 839 / 2640
1966        v23 = (
1967            (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
1968        ) * x + 5707 / 2640
1969        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
1970      case _:
1971        raise ValueError(self.degree)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GaussianFilter(Filter):
1974class GaussianFilter(Filter):
1975  r"""See https://en.wikipedia.org/wiki/Gaussian_function.
1976
1977  Args:
1978    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
1979      creates a kernel that is as-close-as-possible to a partition of unity.
1980  """
1981
1982  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
1983  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
1984  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
1985  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
1986  resampling tasks, 1990] for kernels with a support of 3 pixels.
1987  https://cadxfem.org/inf/ResamplingFilters.pdf
1988  """
1989
1990  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
1991    super().__init__(
1992        name=f'gaussian_{standard_deviation:.3f}',
1993        radius=np.ceil(8.0 * standard_deviation),  # Sufficiently large.
1994        interpolating=False,
1995        partition_of_unity=False,
1996    )
1997    self.standard_deviation = standard_deviation
1998
1999  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2000    x = np.abs(x)
2001    sdv = self.standard_deviation
2002    v0r = np.exp(np.square(x / sdv) / -2.0) / (np.sqrt(math.tau) * sdv)
2003    return np.where(x < self.radius, v0r, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class NarrowBoxFilter(Filter):
2006class NarrowBoxFilter(Filter):
2007  """Compact footprint, used for visualization of grid sample location.
2008
2009  Args:
2010    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
2011      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
2012  """
2013
2014  def __init__(self, *, radius: float = 0.199) -> None:
2015    super().__init__(
2016        name='narrowbox',
2017        radius=radius,
2018        continuous=False,
2019        unit_integral=False,
2020        partition_of_unity=False,
2021    )
2022
2023  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2024    radius = self.radius
2025    magnitude = 1.0
2026    x = np.asarray(x)
2027    return np.where((-radius <= x) & (x < radius), magnitude, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

FILTERS: list[str] = ['impulse', 'box', 'trapezoid', 'triangle', 'cubic', 'sharpcubic', 'lanczos3', 'lanczos5', 'lanczos10', 'cardinal3', 'cardinal5', 'omoms3', 'omoms5', 'hamming3', 'kaiser3', 'gaussian', 'bspline3', 'mitchell', 'narrowbox']

Shortcut names for some predefined filter kernels (specified per dimension). The names expand to:

name Filter a.k.a. / comments
'impulse' ImpulseFilter() nearest
'box' BoxFilter() non-antialiased box, e.g. ImageMagick
'trapezoid' TrapezoidFilter() area antialiasing, e.g. cv.INTER_AREA
'triangle' TriangleFilter() linear (bilinear in 2D), spline order=1
'cubic' CatmullRomFilter() catmullrom, keys, bicubic
'sharpcubic' SharpCubicFilter() cv.INTER_CUBIC, torch 'bicubic'
'lanczos3' LanczosFilter(radius=3) support window [-3, 3]
'lanczos5' LanczosFilter(radius=5) [-5, 5]
'lanczos10' LanczosFilter(radius=10) [-10, 10]
'cardinal3' CardinalBsplineFilter(degree=3) spline interpolation, order=3, GF
'cardinal5' CardinalBsplineFilter(degree=5) spline interpolation, order=5, GF
'omoms3' OmomsFilter(degree=3) non-$C^1$, [-2, 2], GF
'omoms5' OmomsFilter(degree=5) non-$C^1$, [-3, 3], GF
'hamming3' GeneralizedHammingFilter(...) (radius=3, a0=25/46)
'kaiser3' KaiserFilter(radius=3.0, beta=7.12)
'gaussian' GaussianFilter() non-interpolating, default $\sigma=1.25/3$
'bspline3' BsplineFilter(degree=3) non-interpolating
'mitchell' MitchellFilter() mitchellcubic
'narrowbox' NarrowBoxFilter() for visualization of sample positions

The comment label GF denotes a generalized filter, formed as the composition of a finitely supported kernel and a discrete inverse convolution.

Some example filter kernels:


A more extensive set of filters is presented here in the notebook, together with visualizations and analyses of the filter properties. See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
2135@dataclasses.dataclass(frozen=True)
2136class Gamma(abc.ABC):
2137  """Abstract base class for transfer functions on sample values.
2138
2139  Image/video content is often stored using a color component transfer function.
2140  See https://en.wikipedia.org/wiki/Gamma_correction.
2141
2142  Converts between integer types and [0.0, 1.0] internal value range.
2143  """
2144
2145  name: str
2146  """Name of component transfer function."""
2147
2148  @abc.abstractmethod
2149  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2150    """Decode source sample values into floating-point, possibly nonlinearly.
2151
2152    Uint source values are mapped to the range [0.0, 1.0].
2153    """
2154
2155  @abc.abstractmethod
2156  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2157    """Encode float signal into destination samples, possibly nonlinearly.
2158
2159    Uint destination values are mapped from the range [0.0, 1.0].
2160
2161    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
2162    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
2163    """

Abstract base class for transfer functions on sample values.

Image/video content is often stored using a color component transfer function. See https://en.wikipedia.org/wiki/Gamma_correction.

Converts between integer types and [0.0, 1.0] internal value range.

class IdentityGamma(Gamma):
2166class IdentityGamma(Gamma):
2167  """Identity component transfer function."""
2168
2169  def __init__(self) -> None:
2170    super().__init__('identity')
2171
2172  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2173    dtype = np.dtype(dtype)
2174    assert np.issubdtype(dtype, np.inexact)
2175    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
2176      return _to_float_01(array, dtype)
2177    return _arr_astype(array, dtype)
2178
2179  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2180    dtype = np.dtype(dtype)
2181    assert np.issubdtype(dtype, np.number)
2182    if np.issubdtype(dtype, np.unsignedinteger):
2183      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
2184    if np.issubdtype(dtype, np.integer):
2185      return _arr_astype(array + 0.5, dtype)
2186    return _arr_astype(array, dtype)

Gamma(name: 'str')

GAMMAS: list[str] = ['identity', 'power2', 'power22', 'srgb']

Shortcut names for some predefined gamma-correction schemes:

name Gamma Decoding function (linear space from stored value) Encoding function (stored value from linear space)
'identity' IdentityGamma() $l = e$ $e = l$
'power2' PowerGamma(2.0) $l = e^{2.0}$ $e = l^{1/2.0}$
'power22' PowerGamma(2.2) $l = e^{2.2}$ $e = l^{1/2.2}$
'srgb' SrgbGamma() $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ $e = l^{1/2.4} * 1.055 - 0.055$

See the source code for extensibility.

def resize( array: Array, /, shape: Iterable[int], *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, precision: DTypeLike = None, dtype: DTypeLike = None, dim_order: Iterable[int] | None = None, num_threads: Union[int, Literal['auto']] = 'auto') -> Array:
2601def resize(
2602    array: _Array,
2603    /,
2604    shape: Iterable[int],
2605    *,
2606    gridtype: str | Gridtype | None = None,
2607    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
2608    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
2609    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
2610    cval: _ArrayLike = 0.0,
2611    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
2612    prefilter: str | Filter | Iterable[str | Filter] | None = None,
2613    gamma: str | Gamma | None = None,
2614    src_gamma: str | Gamma | None = None,
2615    dst_gamma: str | Gamma | None = None,
2616    scale: float | Iterable[float] = 1.0,
2617    translate: float | Iterable[float] = 0.0,
2618    precision: _DTypeLike = None,
2619    dtype: _DTypeLike = None,
2620    dim_order: Iterable[int] | None = None,
2621    num_threads: int | Literal['auto'] = 'auto',
2622) -> _Array:
2623  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.
2624
2625  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
2626  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
2627  `array.shape[len(shape):]`.
2628
2629  Some examples:
2630
2631  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
2632    produces a new image of scalar values.
2633  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
2634    produces a new image of RGB values.
2635  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
2636    `len(shape) == 3` produces a new 3D grid of Jacobians.
2637
2638  This function also allows scaling and translation from the source domain to the output domain
2639  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.
2640
2641  Args:
2642    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
2643      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
2644      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
2645      at least 2 for a `'primal'` grid.
2646    shape: The number of grid samples in each coordinate dimension of the output array.  The source
2647      `array` must have at least as many dimensions as `len(shape)`.
2648    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
2649      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
2650      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
2651    src_gridtype: Placement of the samples in the source domain grid for each dimension.
2652      Parameters `gridtype` and `src_gridtype` cannot both be set.
2653    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
2654      Parameters `gridtype` and `dst_gridtype` cannot both be set.
2655    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
2656      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
2657      for upsampling and `'clamp'` for downsampling.
2658    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
2659      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
2660    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
2661      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
2662    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
2663      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
2664      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
2665      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
2666      because it avoids ringing artifacts.
2667    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
2668      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
2669      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
2670      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
2671      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
2672      range [0.0, 1.0].
2673    src_gamma: Component transfer function used to "decode" `array` samples.
2674      Parameters `gamma` and `src_gamma` cannot both be set.
2675    dst_gamma: Component transfer function used to "encode" the output samples.
2676      Parameters `gamma` and `dst_gamma` cannot both be set.
2677    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
2678      the destination domain.
2679    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
2680      the destination domain.
2681    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
2682      on `array.dtype` and `dtype`.
2683    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
2684      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
2685      to the uint range.
2686    dim_order: Override the automatically selected order in which the grid dimensions are resized.
2687      Must contain a permutation of `range(len(shape))`.
2688    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
2689      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.
2690
2691  Returns:
2692    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
2693      and data type `dtype`.
2694
2695  **Example of image upsampling:**
2696
2697  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
2698  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.
2699
2700  <center>
2701  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
2702  </center>
2703
2704  **Example of image downsampling:**
2705
2706  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
2707  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
2708  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
2709  >>> downsampled = resize(array, (24, 48))
2710
2711  <center>
2712  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
2713  </center>
2714
2715  **Unit test:**
2716
2717  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
2718  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
2719  """
2720  if isinstance(array, (tuple, list)):
2721    array = np.asarray(array)
2722  arraylib = _arr_arraylib(array)
2723  array_dtype = _arr_dtype(array)
2724  if not np.issubdtype(array_dtype, np.number):
2725    raise ValueError(f'Type {array.dtype} is not numeric.')
2726  shape2 = tuple(shape)
2727  array_ndim = len(array.shape)
2728  if not 0 < len(shape2) <= array_ndim:
2729    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
2730  src_shape = array.shape[: len(shape2)]
2731  src_gridtype2, dst_gridtype2 = _get_gridtypes(
2732      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
2733  )
2734  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
2735  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
2736  prefilter = filter if prefilter is None else prefilter
2737  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
2738  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
2739  dtype = array_dtype if dtype is None else np.dtype(dtype)
2740  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
2741  scale2 = np.broadcast_to(np.array(scale), len(shape2))
2742  translate2 = np.broadcast_to(np.array(translate), len(shape2))
2743  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
2744  del (src_gamma, dst_gamma, scale, translate)
2745  precision = _get_precision(precision, [array_dtype, dtype], [])
2746  weight_precision = _real_precision(precision)
2747
2748  is_noop = (
2749      all(src == dst for src, dst in zip(src_shape, shape2))
2750      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
2751      and all(f.interpolating for f in prefilter2)
2752      and np.all(scale2 == 1.0)
2753      and np.all(translate2 == 0.0)
2754      and src_gamma2 == dst_gamma2
2755  )
2756  if is_noop:
2757    return array
2758
2759  if dim_order is None:
2760    dim_order = _arr_best_dims_order_for_resize(array, shape2)
2761  else:
2762    dim_order = tuple(dim_order)
2763    if sorted(dim_order) != list(range(len(shape2))):
2764      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')
2765
2766  array = src_gamma2.decode(array, precision)
2767  cval = _arr_numpy(src_gamma2.decode(cval, precision))
2768
2769  can_use_fast_box_downsampling = (
2770      using_numba
2771      and arraylib == 'numpy'
2772      and len(shape2) == 2
2773      and array_ndim in (2, 3)
2774      and all(src > dst for src, dst in zip(src_shape, shape2))
2775      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
2776      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
2777      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
2778      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
2779      and np.all(scale2 == 1.0)
2780      and np.all(translate2 == 0.0)
2781  )
2782  if can_use_fast_box_downsampling:
2783    assert isinstance(array, np.ndarray)  # Help mypy.
2784    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
2785    array = dst_gamma2.encode(array, dtype)
2786    return array
2787
2788  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
2789  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
2790  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
2791  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
2792  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
2793  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.
2794
2795  for dim in dim_order:
2796    skip_resize_on_this_dim = (
2797        shape2[dim] == array.shape[dim]
2798        and scale2[dim] == 1.0
2799        and translate2[dim] == 0.0
2800        and filter2[dim].interpolating
2801    )
2802    if skip_resize_on_this_dim:
2803      continue
2804
2805    def get_is_minification() -> bool:
2806      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
2807      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
2808      return dst_in_samples / src_in_samples * scale2[dim] < 1.0
2809
2810    is_minification = get_is_minification()
2811    boundary_dim = boundary2[dim]
2812    if boundary_dim == 'auto':
2813      boundary_dim = 'clamp' if is_minification else 'reflect'
2814    boundary_dim = _get_boundary(boundary_dim)
2815    resize_matrix, cval_weight = _create_resize_matrix(
2816        array.shape[dim],
2817        shape2[dim],
2818        src_gridtype=src_gridtype2[dim],
2819        dst_gridtype=dst_gridtype2[dim],
2820        boundary=boundary_dim,
2821        filter=filter2[dim],
2822        prefilter=prefilter2[dim],
2823        scale=scale2[dim],
2824        translate=translate2[dim],
2825        dtype=weight_precision,
2826        arraylib=arraylib,
2827    )
2828
2829    array_dim: _Array = _arr_moveaxis(array, dim, 0)
2830    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
2831    array_flat = _arr_possibly_make_contiguous(array_flat)
2832    if not is_minification and filter2[dim].requires_digital_filter:
2833      array_flat = _apply_digital_filter_1d(
2834          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
2835      )
2836
2837    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
2838    if cval_weight is not None:
2839      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
2840      if np.issubdtype(array_dtype, np.complexfloating):
2841        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
2842      array_flat += cval_weight[:, None] * cval_flat
2843
2844    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
2845      array_flat = _apply_digital_filter_1d(
2846          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
2847      )
2848    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
2849    array = _arr_moveaxis(array_dim, 0, dim)
2850
2851  array = dst_gamma2.encode(array, dtype)
2852  return array

Resample array (a grid of sample values) onto a grid with resolution shape.

The source array is any object recognized by ARRAYLIBS. It is interpreted as a grid with len(shape) domain coordinate dimensions, where each grid sample value has shape array.shape[len(shape):].

Some examples:

  • A grayscale image has array.shape = height, width and resizing it with len(shape) == 2 produces a new image of scalar values.
  • An RGB image has array.shape = height, width, 3 and resizing it with len(shape) == 2 produces a new image of RGB values.
  • A 3D grid of 3x3 Jacobians has array.shape = Z, Y, X, 3, 3 and resizing it with len(shape) == 3 produces a new 3D grid of Jacobians.

This function also allows scaling and translation from the source domain to the output domain through the parameters scale and translate. For more general transforms, see resample.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. Its first len(shape) dimensions are the domain coordinate dimensions. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto array.shape[len(shape):]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during upsampling (i.e., magnification).
  • prefilter: The prefilter kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter. The default 'lanczos3' is good for natural images. For vector graphics images, 'trapezoid' is better because it avoids ringing artifacts.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • scale: Scaling factor applied to each dimension of the source domain when it is mapped onto the destination domain.
  • translate: Offset applied to each dimension of the scaled source domain when it is mapped onto the destination domain.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • dim_order: Override the automatically selected order in which the grid dimensions are resized. Must contain a permutation of range(len(shape)).
  • num_threads: Used to determine multithread parallelism if array is from numpy. If set to 'auto', it is selected automatically. Otherwise, it must be a positive integer.
Returns:

An array of the same class as the source array, with shape shape + array.shape[len(shape):] and data type dtype.

Example of image upsampling:

>>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
>>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

Example of image downsampling:

>>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
>>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
>>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
>>> downsampled = resize(array, (24, 48))

Unit test:

>>> result = resize([1.0, 4.0, 5.0], shape=(4,))
>>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
def jaxjit_resize(array: Array, /, *args: Any, **kwargs: Any) -> Array:
2906def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
2907  """Compute `resize` but with resize function jitted using Jax."""
2908  return _create_jaxjit_resize()(array, *args, **kwargs)  # pylint: disable=not-callable

Compute resize but with resize function jitted using Jax.

def uniform_resize( array: Array, /, shape: Iterable[int], *, object_fit: Literal['contain', 'cover'] = 'contain', gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'natural', scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, **kwargs: Any) -> Array:
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resize `array` to resolution `shape` while keeping its aspect ratio (uniform scaling).

  This is a wrapper around `resize`: it derives per-dimension `scale` and `translate` values so
  that every dimension is scaled by the same factor and the source stays centered in the output.
  The behavior mirrors CSS `object-fit`.  Output points falling outside the source domain are
  governed by `boundary`, whose default here is `'natural'` rather than `'auto'`.

  Args:
    array: Regular grid of source sample values.  A `tuple` or `list` is converted to a numpy
      array.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`, and `shape` must be nonempty.
    object_fit: Like CSS `object-fit`.  With `'contain'`, the smallest per-dimension scale is used
      so that `array` fits within `shape`.  With `'cover'`, the largest per-dimension scale is used
      so that `array` fully covers `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified (it must keep its default value).
    translate: Parameter may not be specified (it must keep its default value).
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, or if `len(shape)` is zero or exceeds the
      number of dimensions of `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')

  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  dst_shape = tuple(shape)
  num_dims = len(dst_shape)
  if not 0 < num_dims <= len(array.shape):
    raise ValueError(f'Shape {array.shape} cannot be resized to {dst_shape}.')

  src_gridtypes, dst_gridtypes = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, num_dims, num_dims
  )

  # Scale that an unconstrained (non-uniform) resize would apply along each dimension.
  per_dim_scales = []
  for dim, dst_size in enumerate(dst_shape):
    dst_extent = dst_gridtypes[dim].size_in_samples(dst_size)
    src_extent = src_gridtypes[dim].size_in_samples(array.shape[dim])
    per_dim_scales.append(dst_extent / src_extent)

  # 'contain' shrinks to the tightest dimension; 'cover' grows to the loosest one.
  select_scale = {'contain': min, 'cover': max}[object_fit]
  uniform_scale = select_scale(per_dim_scales)

  # Express the uniform scale relative to the per-dimension scales and center the source.
  relative_scale = uniform_scale / np.array(per_dim_scales)
  centering_offset = (1.0 - relative_scale) / 2
  return resize(
      array,
      dst_shape,
      boundary=boundary,
      scale=relative_scale,
      translate=centering_offset,
      **kwargs,
  )

Resample array onto a grid with resolution shape but with uniform scaling.

Calls function resize with scale and translate set such that the aspect ratio of array is preserved. The effect is similar to CSS object-fit: contain. The parameter boundary (whose default is changed to 'natural') determines the values assigned outside the source domain.

Arguments:
  • array: Regular grid of source sample values.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • object_fit: Like CSS object-fit. If 'contain', array is resized uniformly to fit within shape. If 'cover', array is resized to fully cover shape.
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The default is 'natural', which assigns cval to output points that map outside the source unit domain.
  • scale: Parameter may not be specified.
  • translate: Parameter may not be specified.
  • **kwargs: Additional parameters for resize function (including cval).
Returns:

An array with shape shape + array.shape[len(shape):].

>>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
array([[0., 1., 1., 0.],
       [0., 1., 1., 0.]])
>>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
       [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])
>>> a = np.arange(6.0).reshape(2, 3)
>>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
array([[0.5, 1.5],
       [3.5, 4.5]])
def resample( array: Array, /, coords: ArrayLike, *, gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual', boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0.0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, jacobian: Optional[ArrayLike] = None, precision: DTypeLike = None, dtype: DTypeLike = None, max_block_size: int = 40000, debug: bool = False) -> Array:
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not of numeric type, if `coords` is not of floating type, if
      `coords` has more coordinate dimensions than `array` has dimensions, or if the adaptive
      `'trapezoid'` filter is requested.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]  # Promote a scalar to a 1D grid containing a single sample.
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    # For a 1D source grid, interpret a 1D `coords` as a list of scalar coordinates.
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  # Drop the raw parameters so that only their normalized `*2` forms are used below.
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # The gamma decoding and digital prefiltering are applied once on the whole source array; they
  # are skipped in the recursive per-block calls (identified by _MAX_BLOCK_SIZE_RECURSING).
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      # The source array is already decoded and prefiltered, hence src_gamma='identity'.
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=jacobian,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # Starting in numpy 2.0, np.einsum() outputs np.float64 even with all np.float32 inputs;
  # GPT: "aligns np.einsum with other functions where intermediate calculations use higher
  # precision (np.float64) regardless of input type when floating-point arithmetic is involved."
  # we could explicitly add the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The list `weight` has one entry per *source grid* dimension, so the product must range over
    # `grid_ndim`.  (Iterating over `resampled_ndim` is only correct when the two counts coincide;
    # otherwise it skips grid dimensions or indexes past the end of `weight`.)
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array

Interpolate array (a grid of samples) at specified unit-domain coordinates coords.

The last dimension of coords contains unit-domain coordinates at which to interpolate the domain grid samples in array.

The number of coordinates (coords.shape[-1]) determines how to interpret array: its first coords.shape[-1] dimensions define the grid, and the remaining dimensions describe each grid sample (e.g., scalar, vector, tensor).

Concretely, the grid has shape array.shape[:coords.shape[-1]] and each grid sample has shape array.shape[coords.shape[-1]:].

Examples include:

  • Resample a grayscale image with array.shape = height, width onto a new grayscale image with new.shape = height2, width2 by using coords.shape = height2, width2, 2.

  • Resample an RGB image with array.shape = height, width, 3 onto a new RGB image with new.shape = height2, width2, 3 by using coords.shape = height2, width2, 2.

  • Sample an RGB image at num 2D points along a line segment by using coords.shape = num, 2.

  • Sample an RGB image at a single 2D point by using coords.shape = (2,).

  • Sample a 3D grid of 3x3 Jacobians with array.shape = nz, ny, nx, 3, 3 along a 2D plane by using coords.shape = height, width, 3.

  • Map a grayscale image through a color map by using array.shape = 256, 3 and coords.shape = height, width.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The coordinate dimensions appear first, and each grid sample may have an arbitrary shape. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • coords: Grid of points at which to resample array. The point coordinates are in the last dimension of coords. The domain associated with the source grid is a unit hypercube, i.e. with a range [0, 1] on each coordinate dimension. The output grid has shape coords.shape[:-1] and each of its grid samples has shape array.shape[coords.shape[-1]:].
  • gridtype: Placement of the samples in the source domain grid for each dimension, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual'.
  • boundary: The reconstruction boundary rule for each dimension in coords.shape[-1], specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto the shape array.shape[coords.shape[-1]:]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in coords.shape[-1], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in coords.shape[:-1], specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • jacobian: Optional array, which must be broadcastable onto the shape coords.shape[:-1] + (coords.shape[-1], coords.shape[-1]), storing for each point in the output grid the Jacobian matrix of the map from the unit output domain to the unit source domain. If omitted, it is estimated by computing finite differences on coords.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype, coords.dtype, and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • max_block_size: If nonzero, maximum number of grid points in coords before the resampling evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
  • debug: Show internal information.
Returns:

A new sample grid of shape coords.shape[:-1], represented as an array of shape coords.shape[:-1] + array.shape[coords.shape[-1]:], of the same array library type as the source array.

Example of resample operation:

For reference, the identity resampling for a scalar-valued grid with the default grid-type 'dual' is:

>>> array = np.random.default_rng(0).random((5, 7, 3))
>>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
>>> new_array = resample(array, coords)
>>> assert np.allclose(new_array, array)

It is more efficient to use the function resize for the special case where the coords are obtained as simple scaling and translation of a new regular grid over the source domain:

>>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
>>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
>>> coords = (coords - translate) / scale
>>> resampled = resample(array, coords)
>>> resized = resize(array, new_shape, scale=scale, translate=translate)
>>> assert np.allclose(resampled, resized)
def resample_affine( array: Array, /, shape: Iterable[int], matrix: ArrayLike, *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, precision: DTypeLike = None, dtype: DTypeLike = None, **kwargs: Any) -> Array:
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a destination grid of given `shape` through a linear or affine map.

  Each destination grid point is mapped into the source unit domain either linearly,
    `source_point = matrix @ destination_point`,
  or affinely, when `matrix` has one extra column holding an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.
  The resulting coordinates are then passed to `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`
      (a `tuple` or `list` is first converted to a numpy array).  The array must have numeric
      type.  Its first `matrix.shape[0]` dimensions are the grid dimensions; the remaining
      dimensions belong to each sample value and are all linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may differ from that of the source grid.
    matrix: 2D array mapping unit-domain destination points (`len(shape)` dimensions) into
      unit-domain source points (`matrix.shape[0]` dimensions).  It must have either `len(shape)`
      columns (linear) or `len(shape) + 1` columns, in which case the last column is the affine
      offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each of the `matrix.shape[0]` source dimensions,
      specified as either a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each of the `len(shape)` destination dimensions, specified
      as either a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or has a number
      of columns other than `len(shape)` or `len(shape) + 1`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  is_affine = num_columns == dst_ndim + 1
  if not (is_affine or num_columns == dst_ndim):
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  # Normalize the per-dimension gridtypes and filters.
  src_gridtypes, dst_gridtypes = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  reconstruction_filters = [
      _get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)
  ]
  effective_prefilter = filter if prefilter is None else prefilter
  prefilters = [
      _get_filter(f) for f in np.broadcast_to(np.array(effective_prefilter), dst_ndim)
  ]

  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain position of every destination sample, one 1D array per destination dimension.
  axis_positions = [
      dst_gridtypes[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part of the map, then move the source-coordinate axis last.
  if is_affine:
    linear_matrix, offset = matrix[:, :-1], matrix[:, -1]
  else:
    linear_matrix, offset = matrix, None
  coords = np.moveaxis(np.tensordot(linear_matrix, dst_position, 1), 0, -1)
  if offset is not None:
    coords += offset

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtypes,
      filter=reconstruction_filters,
      prefilter=prefilters,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

Resample a source array using an affinely transformed grid of given shape.

The matrix transformation can be linear, source_point = matrix @ destination_point, or it can be affine where the last matrix column is an offset vector, source_point = matrix @ (destination_point, 1.0).

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The number of grid dimensions is determined from matrix.shape[0]; the remaining dimensions are for each sample value and are all linearly interpolated.
  • shape: Dimensions of the desired destination grid. The number of destination grid dimensions may be different from that of the source grid.
  • matrix: 2D array for a linear or affine transform from unit-domain destination points (in a space with len(shape) dimensions) into unit-domain source points (in a space with matrix.shape[0] dimensions). If the matrix has len(shape) + 1 columns, the last column is the affine offset (i.e., translation).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • filter: The reconstruction kernel for each dimension in matrix.shape[0], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in len(shape), specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • **kwargs: Additional parameters for resample function.
Returns:

An array of the same class as the source array, representing a grid with specified shape, where each grid value is resampled from array. Thus the shape of the returned array is shape + array.shape[matrix.shape[0]:].

def rotation_about_center_in_2d( src_shape: ArrayLike, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0) -> np.ndarray:
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 homogeneous matrix mapping the destination unit domain into the source one.

  The matrix is the product (applied right to left) of: a translation by `-0.5` on both axes, a
  per-axis scaling by `scale * new_shape / min(new_shape)`, a rotation by `angle`, a per-axis
  scaling by `min(src_shape) / src_shape`, and a translation by `+0.5`.  The two scalings account
  for the possibly non-square domain shapes.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 array whose last row is `[0, 0, 1]`.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag((*factors, 1.0))

  cos, sin = np.cos(angle), np.sin(angle)
  rotation = np.array([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])
  center = np.array([0.5, 0.5])

  # Listed in the order they appear in the product (leftmost first); folded left to right.
  factors = [
      homogeneous_translation(center),
      homogeneous_scaling(min(src_shape) / src_shape),
      rotation,
      homogeneous_scaling(scale * new_shape / min(new_shape)),
      homogeneous_translation(-center),
  ]
  matrix = functools.reduce(np.matmul, factors)
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

Return the 3x3 matrix mapping destination into a source unit domain.

The returned matrix accounts for the possibly non-square domain shapes.

Arguments:
  • src_shape: Resolution (ny, nx) of the source domain grid.
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the destination domain grid; it defaults to src_shape.
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
def rotate_image_about_center( image: np.ndarray, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0, num_rotations: int = 1, **kwargs: Any) -> np.ndarray:
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The result of applying `resample_affine` `num_rotations` times; its first two dimensions
    are `new_shape`.  If `num_rotations` is not positive, `image` is returned unchanged.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # The matrix depends on the aspect ratio of the current source grid.  After the first pass
    # the source has shape `new_shape`, so recompute the matrix each time rather than reusing
    # the one derived from the original image shape.
    matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image

Return a copy of image rotated about its center.

Arguments:
  • image: Source grid samples; the first two dimensions are spatial (ny, nx).
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the output grid; it defaults to image.shape[:2].
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
  • num_rotations: Number of rotations (each by angle). Successive resamplings are useful in analyzing the filtering quality.
  • **kwargs: Additional parameters for resample_affine.
def pil_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0) -> np.ndarray:
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Floating-point array with 1 to 3 dimensions.  A 1D array is resized as a single-row
      image; a 3D array is resized one channel (last axis) at a time.
    shape: Output resolution: one value for a 1D array, else `(ny, nx)`.
    filter: One of 'impulse', 'box', 'triangle', 'hamming1', 'cubic', 'lanczos3'.
    boundary: Must be 'natural'.
    cval: Ignored.

  Returns:
    The resized numpy array, with the same dtype as `array`.

  Raises:
    ValueError: If `boundary` is not 'natural' or `filter` is not supported.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat the 1D signal as an image with a single row.
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]

  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
  resampling = PIL.Image.Resampling
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  pil_size = shape[::-1]  # PIL expects (width, height).

  def resize_plane(plane: _NDArray) -> _NDArray:
    resized = PIL.Image.fromarray(plane).resize(pil_size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])

Invoke PIL.Image.resize using the same parameters as resize.

def cv_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'clamp', cval: float = 0.0) -> np.ndarray:
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  Args:
    array: Array with 1 to 3 dimensions.  A 1D array is resized as a single-row image.
    shape: Output resolution: one value for a 1D array, else `(ny, nx)`.
    filter: One of 'impulse', 'triangle', 'trapezoid', 'sharpcubic', 'lanczos4'.
    boundary: Must be 'clamp'.
    cval: Ignored.

  Returns:
    The resized array, with the same number of dimensions as `array`.

  Raises:
    ValueError: If `boundary` is not 'clamp' or `filter` is not supported.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat the 1D signal as an image with a single row.
    return cv_resize(array[None], (1, *shape), filter=filter)[0]

  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  resized = cv.resize(array, shape[::-1], interpolation=filters[filter])  # dsize is (nx, ny).

  lost_channel_axis = array.ndim == 3 and resized.ndim == 2
  if not lost_channel_axis:
    return resized
  assert array.shape[2] == 1
  return resized[..., None]  # Add back the last dimension dropped by cv.resize().

Invoke cv.resize using the same parameters as resize.

def scipy_ndimage_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> np.ndarray:
3605def scipy_ndimage_resize(
3606    array: _ArrayLike,
3607    /,
3608    shape: Iterable[int],
3609    *,
3610    filter: str,
3611    boundary: str = 'reflect',
3612    cval: float = 0.0,
3613    scale: float | Iterable[float] = 1.0,
3614    translate: float | Iterable[float] = 0.0,
3615) -> _NDArray:
3616  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
3617  array = np.asarray(array)
3618  shape = tuple(shape)
3619  assert 1 <= len(shape) <= array.ndim
3620  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
3621  if filter not in filters:
3622    raise ValueError(f'{filter=} not in {filters=}.')
3623  order = filters[filter]
3624  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
3625  if boundary not in boundaries:
3626    raise ValueError(f'{boundary=} not in {boundaries=}.')
3627  mode = boundaries[boundary]
3628  shape_all = shape + array.shape[len(shape) :]
3629  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
3630  coords[..., : len(shape)] = (
3631      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
3632  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
3633  coords = np.moveaxis(coords, -1, 0)
3634  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)

Invoke scipy.ndimage.map_coordinates using the same parameters as resize.

def skimage_transform_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `skimage.transform.resize` using the same parameters as `resize`.

  Args:
    array: Source samples; the first `len(shape)` dimensions are resized and the remaining
      dimensions keep their size.
    shape: Output resolution of the leading `len(shape)` dimensions.
    filter: One of 'box', 'triangle', or 'cardinal2' .. 'cardinal5'; it selects the spline order
      (0, 1, or 2..5).
    boundary: One of 'reflect', 'wrap', 'clamp', 'border'; it selects the skimage `mode`.
    cval: Constant value forwarded to `skimage.transform.resize`.

  Returns:
    The result of `skimage.transform.resize` for output shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim

  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')

  output_shape = shape + array.shape[len(shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  resize_options = dict(order=filters[filter], mode=boundaries[boundary], cval=cval, clip=False)
  return skimage.transform.resize(array, output_shape, **resize_options)  # type: ignore[no-untyped-call]

Invoke skimage.transform.resize using the same parameters as resize.

def tf_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0, antialias: bool = True) -> TensorflowTensor:
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Invoke `tf.image.resize` using the same parameters as `resize`.

  Args:
    array: Array with 1 to 3 dimensions.  A 1D array is resized as a single-row image and a 2D
      array as a single-channel image.
    shape: Output resolution: one value for a 1D array, else `(ny, nx)`.
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'.
    cval: Ignored.
    antialias: Forwarded to `tf.image.resize`.

  Returns:
    The resized tensor, with the same number of dimensions as `array`.

  Raises:
    ValueError: If `filter` is not supported or `boundary` is not 'natural'.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  del array
  ndim = len(tensor.shape)
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)

  if ndim == 1:
    # Add a leading row axis, resize as 2D, then drop the row axis.
    as_row = tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)
    return as_row[0]
  if ndim == 2:
    # Add a trailing channel axis, resize as 3D, then drop the channel axis.
    with_channel = tf_image_resize(tensor[..., None], shape, filter=filter, antialias=antialias)
    return with_channel[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)

Invoke tf.image.resize using the same parameters as resize.

def torch_nn_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'clamp', cval: float = 0.0, antialias: bool = False) -> TorchTensor:
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`.

  Args:
    array: Array with 1 to 3 dimensions.  A 3D array is interpreted as `(ny, nx, channels)`.
    shape: Output resolution: one value for a 1D array, else `(ny, nx)`.
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'.
    cval: Ignored.
    antialias: Forwarded to `torch.nn.functional.interpolate`.

  Returns:
    The resized tensor, with the same number of dimensions as `array`.

  Raises:
    ValueError: If `filter` is not supported or `boundary` is not 'clamp'.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  tensor = torch.as_tensor(array)
  del array
  assert 1 <= tensor.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if tensor.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def interpolate(batch: _TorchTensor, size: tuple[int, ...]) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batch, size, mode=mode, antialias=antialias)

  # `interpolate` is always given a 4D tensor: the input is padded with leading singleton axes
  # (or has its channel axis moved forward), and the added axes are stripped from the result.
  if tensor.ndim == 1:
    return interpolate(tensor[None, None, None], (1, *shape))[0, 0, 0]
  if tensor.ndim == 2:
    return interpolate(tensor[None, None], shape)[0, 0]
  channels_first = tensor.moveaxis(2, 0)[None]
  return interpolate(channels_first, shape)[0].moveaxis(0, 2)

Invoke torch.nn.functional.interpolate using the same parameters as resize.

def jax_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> JaxArray:
3764def jax_image_resize(
3765    array: _ArrayLike,
3766    /,
3767    shape: Iterable[int],
3768    *,
3769    filter: str,
3770    boundary: str = 'natural',
3771    cval: float = 0.0,
3772    scale: float | Iterable[float] = 1.0,
3773    translate: float | Iterable[float] = 0.0,
3774) -> _JaxArray:
3775  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`."""
3776  import jax.image
3777  import jax.numpy as jnp
3778
3779  filters = 'triangle cubic lanczos3 lanczos5'.split()
3780  if filter not in filters:
3781    raise ValueError(f'{filter=} not in {filters=}.')
3782  if boundary != 'natural':
3783    raise ValueError(f"{boundary=} must equal 'natural'.")
3784  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
3785  # To be consistent, the parameter `cval` must be zero.
3786  if scale != 1.0 and cval != 0.0:
3787    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
3788  if translate != 0.0 and cval != 0.0:
3789    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
3790  array2 = jnp.asarray(array)
3791  del array
3792  shape = tuple(shape)
3793  assert len(shape) <= array2.ndim
3794  completed_shape = shape + (1,) * (array2.ndim - len(shape))
3795  spatial_dims = list(range(len(shape)))
3796  scale2 = np.broadcast_to(np.array(scale), len(shape))
3797  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
3798  translate2 = np.broadcast_to(np.array(translate), len(shape))
3799  translate2 = translate2 * np.array(shape)
3800  return jax.image.scale_and_translate(
3801      array2, completed_shape, spatial_dims, scale2, translate2, filter
3802  )

Invoke jax.image.scale_and_translate using the same parameters as resize.