resampler

resampler: fast differentiable resizing and warping of arbitrary grids.

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   2
   3.. include:: ../README.md
   4"""
   5
   6__docformat__ = 'google'
   7__version__ = '0.8.9'
   8__version_info__ = tuple(int(num) for num in __version__.split('.'))
   9
  10from collections.abc import Callable, Iterable, Sequence
  11import abc
  12import dataclasses
  13import functools
  14import importlib
  15import itertools
  16import math
  17import os
  18import sys
  19import types
  20import typing
  21from typing import Any, Generic, Literal, Union
  22
  23import numpy as np
  24import numpy.typing
  25import scipy.interpolate
  26import scipy.linalg
  27import scipy.ndimage
  28import scipy.sparse
  29import scipy.sparse.linalg
  30
  31
  32def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  33  """Return function decorated with no-operation; invocable with or without args."""
  34  if len(args) != 1 or not callable(args[0]) or kwargs:
  35    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  36  func: Callable[..., Any] = args[0]
  37  return func
  38
  39
  40try:
  41  import numba
  42except ModuleNotFoundError:
  43  numba = sys.modules['numba'] = types.ModuleType('numba')
  44  numba.njit = _noop_decorator
  45using_numba = hasattr(numba, 'jit')
  46
  47if TYPE_CHECKING:
  48  import jax.numpy
  49  import tensorflow as tf
  50  import torch
  51
  52  _DType = np.dtype[Any]  # (Requires Python 3.9 or TYPE_CHECKING.)
  53  _NDArray = numpy.NDArray[Any]
  54  _DTypeLike = numpy.DTypeLike
  55  _ArrayLike = numpy.ArrayLike
  56  _TensorflowTensor: TypeAlias = tf.Tensor
  57  _TorchTensor: TypeAlias = torch.Tensor
  58  _JaxArray: TypeAlias = jax.numpy.ndarray
  59
  60else:
  61  # Typically, create named types for use in the `pdoc` documentation.
  62  # But here, these are superseded by the declarations in __init__.pyi!
  63  _DType = Any
  64  _NDArray = Any
  65  _DTypeLike = Any
  66  _ArrayLike = Any
  67  _TensorflowTensor = Any
  68  _TorchTensor = Any
  69  _JaxArray = Any
  70
  71_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  72_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  73
  74
  75def _check_eq(a: Any, b: Any, /) -> None:
  76  """If the two values or arrays are not equal, raise an exception with a useful message."""
  77  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  78  if not are_equal:
  79    raise AssertionError(f'{a!r} == {b!r}')
  80
  81
  82def _real_precision(dtype: _DTypeLike, /) -> _DType:
  83  """Return the type of the real part of a complex number."""
  84  return np.array([], dtype).real.dtype
  85
  86
  87def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  88  """Return a complex type to represent a non-complex type."""
  89  return np.result_type(dtype, np.complex64)
  90
  91
  92def _get_precision(
  93    precision: _DTypeLike, dtypes: list[_DType], weight_dtypes: list[_DType], /
  94) -> _DType:
  95  """Return dtype based on desired precision or on data and weight types."""
  96  precision2 = np.dtype(
  97      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
  98  )
  99  if not np.issubdtype(precision2, np.inexact):
 100    raise ValueError(f'Precision {precision2} is not floating or complex.')
 101  check_complex = [precision2, *dtypes]
 102  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 103  if len(set(is_complex)) != 1:
 104    s_types = ','.join(str(dtype) for dtype in check_complex)
 105    raise ValueError(f'Types {s_types} must be all real or all complex.')
 106  return precision2
 107
 108
 109def _sinc(x: _ArrayLike, /) -> _NDArray:
 110  """Return the value `np.sinc(x)` but improved to:
 111  (1) ignore underflow that occurs at 0.0 for np.float32, and
 112  (2) output exact zero for integer input values.
 113
 114  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 115  array([0., 0., 0., 1.], dtype=float32)
 116
 117  >>> _sinc(np.array([-3, -2, -1, 0]))
 118  array([0., 0., 0., 1.])
 119
 120  >>> _sinc(0)
 121  1.0
 122  """
 123  x = np.asarray(x)
 124  x_is_scalar = x.ndim == 0
 125  with np.errstate(under='ignore'):
 126    result = np.sinc(np.atleast_1d(x))
 127    result[x == np.floor(x)] = 0.0
 128    result[x == 0] = 1.0
 129    return result.item() if x_is_scalar else result
 130
 131
 132def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 133  """Return True if the sparse matrix is symmetric."""
 134  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 135  return norm <= tol
 136
 137
 138def _cache_sampled_1d_function(
 139    xmin: float,
 140    xmax: float,
 141    *,
 142    num_samples: int = 3_600,
 143    enable: bool = True,
 144) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 145  """Function decorator to linearly interpolate cached function values."""
 146  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 147
 148  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 149    if not enable:
 150      return func
 151
 152    dx = (xmax - xmin) / num_samples
 153    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 154    samples_func = func(x)
 155    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 156
 157    @functools.wraps(func)
 158    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 159      x = np.asarray(x)
 160      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 161      index = index_float.astype(np.int64)
 162      frac = np.subtract(index_float, index, dtype=np.float32)
 163      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 164
 165    return interpolate_using_cached_samples
 166
 167  return wrap_it
 168
 169
 170class _DownsampleIn2dUsingBoxFilter:
 171  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 172
 173  def __init__(self) -> None:
 174    # Downsampling function for params (dtype, block_height, block_width, ch).
 175    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 176
 177  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 178    assert using_numba
 179    assert array.ndim in (2, 3), array.ndim
 180    _check_eq(len(shape), 2)
 181    dtype = array.dtype
 182    a = array[..., None] if array.ndim == 2 else array
 183    height, width, ch = a.shape
 184    new_height, new_width = shape
 185    if height % new_height != 0 or width % new_width != 0:
 186      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 187    block_height, block_width = height // new_height, width // new_width
 188
 189    def func(array: _NDArray) -> _NDArray:
 190      new_height = array.shape[0] // block_height
 191      new_width = array.shape[1] // block_width
 192      result = np.empty((new_height, new_width, ch), dtype)
 193      totals = np.empty(ch, dtype)
 194      factor = dtype.type(1.0 / (block_height * block_width))
 195      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 196        for x in range(new_width):
 197          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 198          if ch == 1:  # All the branches involve compile-time constants.
 199            total = dtype.type(0.0)
 200            for yy in range(block_height):
 201              for xx in range(block_width):
 202                total += array[y * block_height + yy, x * block_width + xx, 0]
 203            result[y, x, 0] = total * factor
 204          elif ch == 3:
 205            total0 = total1 = total2 = dtype.type(0.0)
 206            for yy in range(block_height):
 207              for xx in range(block_width):
 208                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 209                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 210                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 211            result[y, x, 0] = total0 * factor
 212            result[y, x, 1] = total1 * factor
 213            result[y, x, 2] = total2 * factor
 214          elif block_height * block_width >= 9:
 215            for c in range(ch):
 216              totals[c] = 0.0
 217            for yy in range(block_height):
 218              for xx in range(block_width):
 219                for c in range(ch):
 220                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 221            for c in range(ch):
 222              result[y, x, c] = totals[c] * factor
 223          else:
 224            for c in range(ch):
 225              total = dtype.type(0.0)
 226              for yy in range(block_height):
 227                for xx in range(block_width):
 228                  total += array[y * block_height + yy, x * block_width + xx, c]
 229              result[y, x, c] = total * factor
 230      return result
 231
 232    signature = dtype, block_height, block_width, ch
 233    jitted_function = self._jitted_function.get(signature)
 234    if not jitted_function:
 235      if 0:
 236        print(f'Creating numba jit-wrapper for {signature}.')
 237      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 238      self._jitted_function[signature] = jitted_function
 239
 240    try:
 241      result = jitted_function(a)
 242    except RuntimeError:
 243      message = (
 244          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 245          ' try deleting that directory.'
 246      )
 247      print(message, file=sys.stdout, flush=True)
 248      print(message, file=sys.stderr, flush=True)
 249      raise
 250
 251    return result[..., 0] if array.ndim == 2 else result
 252
 253
 254_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 255
 256
@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_serial_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  Overwrites `dst` with the product of a CSR sparse matrix, given by (`indptr`, `indices`, `data`),
  and the dense 2D array `src`.  Row `i` of `dst` is the sum, over `jj` in
  `range(indptr[i], indptr[i + 1])`, of `data[jj] * src[indices[jj]]`.

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.

  Args:
    indptr: CSR row pointers; `len(indptr) == dst.shape[0] + 1`.
    indices: CSR column indices of the nonzero entries (i.e., row indices into `src`).
    data: CSR values of the nonzero entries.
    src: Dense 2D input array.
    dst: Dense 2D output array, fully overwritten; `dst.shape[1] == src.shape[1]`.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): `data[0]` and `src[0]` are read only to infer the accumulator's shape and dtype.
  # This presumably assumes that `data` and `src` are non-empty; confirm that callers never pass an
  # empty sparse matrix or a zero-row `src`.
  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  for i in range(dst.shape[0]):
    acc[:] = 0  # One accumulator buffer is reused; reset it for each output row.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 278
 279
# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
#                setitem=False, stencil=False, fusion=False)
@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_parallel_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  Overwrites `dst` with the product of a CSR sparse matrix, given by (`indptr`, `indices`, `data`),
  and the dense 2D array `src`.  The output rows are distributed over threads with `numba.prange`.

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
  The introduction of parallel omp threads provides another 2-4x speedup.
  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.

  Args:
    indptr: CSR row pointers; `len(indptr) == dst.shape[0] + 1`.
    indices: CSR column indices of the nonzero entries (i.e., row indices into `src`).
    data: CSR values of the nonzero entries.
    src: Dense 2D input array.
    dst: Dense 2D output array, fully overwritten; `dst.shape[1] == src.shape[1]`.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): as in the serial version, this presumably assumes `data` and `src` are non-empty;
  # confirm with the callers.
  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  # Default is static scheduling, which is fine.
  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
    # Each output row gets its own accumulator, so that parallel iterations share no scratch state.
    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 307
 308
 309@dataclasses.dataclass
 310class _Arraylib(abc.ABC, Generic[_Array]):
 311  """Abstract base class for abstraction of array libraries."""
 312
 313  arraylib: str
 314  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""
 315
 316  array: _Array
 317
 318  @staticmethod
 319  @abc.abstractmethod
 320  def recognize(array: Any) -> bool:
 321    """Return True if `array` is recognized by this _Arraylib."""
 322
 323  @abc.abstractmethod
 324  def numpy(self) -> _NDArray:
 325    """Return a `numpy` version of `self.array`."""
 326
 327  @abc.abstractmethod
 328  def dtype(self) -> _DType:
 329    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""
 330
 331  @abc.abstractmethod
 332  def astype(self, dtype: _DTypeLike) -> _Array:
 333    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""
 334
 335  def reshape(self, shape: tuple[int, ...]) -> _Array:
 336    """Return the equivalent of `self.array.reshape(shape)`."""
 337    return self.array.reshape(shape)
 338
 339  def possibly_make_contiguous(self) -> _Array:
 340    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous."""
 341    return self.array
 342
 343  @abc.abstractmethod
 344  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _Array:
 345    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""
 346
 347  @abc.abstractmethod
 348  def square(self) -> _Array:
 349    """Return the equivalent of `np.square(self.array)`."""
 350
 351  @abc.abstractmethod
 352  def sqrt(self) -> _Array:
 353    """Return the equivalent of `np.sqrt(self.array)`."""
 354
 355  def getitem(self, indices: Any) -> _Array:
 356    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
 357    return self.array[indices]
 358
 359  @abc.abstractmethod
 360  def where(self, if_true: Any, if_false: Any) -> _Array:
 361    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""
 362
 363  @abc.abstractmethod
 364  def transpose(self, axes: Sequence[int]) -> _Array:
 365    """Return the equivalent of `np.transpose(self.array, axes)`."""
 366
 367  @abc.abstractmethod
 368  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 369    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""
 370
 371  @abc.abstractmethod
 372  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
 373    """Return the multiplication of the `sparse` matrix and `self.array`."""
 374
 375  @staticmethod
 376  @abc.abstractmethod
 377  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
 378    """Return the equivalent of `np.concatenate(arrays, axis)`."""
 379
 380  @staticmethod
 381  @abc.abstractmethod
 382  def einsum(subscripts: str, *operands: _Array) -> _Array:
 383    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 384
 385  @staticmethod
 386  @abc.abstractmethod
 387  def make_sparse_matrix(
 388      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 389  ) -> _Array:
 390    """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 391    However, the indices must be ordered and unique."""
 392
 393
 394class _NumpyArraylib(_Arraylib[_NDArray]):
 395  """Numpy implementation of the array abstraction."""
 396
 397  # pylint: disable=missing-function-docstring
 398
 399  def __init__(self, array: _NDArray) -> None:
 400    super().__init__(arraylib='numpy', array=np.asarray(array))
 401
 402  @staticmethod
 403  def recognize(array: Any) -> bool:
 404    return isinstance(array, (np.ndarray, np.number))
 405
 406  def numpy(self) -> _NDArray:
 407    return self.array
 408
 409  def dtype(self) -> _DType:
 410    dtype: _DType = self.array.dtype
 411    return dtype
 412
 413  def astype(self, dtype: _DTypeLike) -> _NDArray:
 414    return self.array.astype(dtype, copy=False)
 415
 416  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _NDArray:
 417    return self.array.clip(low, high, dtype=dtype)
 418
 419  def square(self) -> _NDArray:
 420    return np.square(self.array)
 421
 422  def sqrt(self) -> _NDArray:
 423    return np.sqrt(self.array)
 424
 425  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 426    condition = self.array
 427    return np.where(condition, if_true, if_false)
 428
 429  def transpose(self, axes: Sequence[int]) -> _NDArray:
 430    return np.transpose(self.array, tuple(axes))
 431
 432  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 433    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 434    # and (2) timings show preference to resizing dimensions with larger strides first.
 435    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 436    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 437    # np.einsum() does not support the sparse multiplications that we require here.)
 438    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 439    strides: Sequence[int] = self.array.strides
 440    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 441
 442    def priority(dim: int) -> float:
 443      scaling = dst_shape[dim] / src_shape[dim]
 444      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 445
 446    return sorted(range(len(src_shape)), key=priority)
 447
 448  def premult_with_sparse(
 449      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 450  ) -> _NDArray:
 451    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 452    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 453    if using_numba:
 454      num_threads2 = min(6, os.cpu_count()) if num_threads == 'auto' else num_threads
 455      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 456      dtype = np.result_type(sparse.dtype, src.dtype)
 457      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 458      num_scalar_multiplies = len(sparse.data) * src.shape[1]
 459      is_small_size = num_scalar_multiplies < 200_000
 460      if is_small_size or num_threads2 == 1:
 461        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 462      else:
 463        numba.set_num_threads(num_threads2)
 464        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 465      return dst
 466
 467    # Note that sicpy.sparse does not use multithreading.  The "@" operation
 468    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 469    # scipy.sparse._sparsetools.csr_matvecs() in
 470    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/csr.h
 471    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 472    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/dense.h
 473    return sparse @ self.array
 474
 475  @staticmethod
 476  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 477    return np.concatenate(arrays, axis)
 478
 479  @staticmethod
 480  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 481    return np.einsum(subscripts, *operands, optimize=True)
 482
 483  @staticmethod
 484  def make_sparse_matrix(
 485      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 486  ) -> _NDArray:
 487    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 488
 489
 490class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
 491  """Tensorflow implementation of the array abstraction."""
 492
 493  def __init__(self, array: _NDArray) -> None:
 494    import tensorflow
 495
 496    self.tf = tensorflow
 497    super().__init__(arraylib='tensorflow', array=self.tf.convert_to_tensor(array))
 498
 499  @staticmethod
 500  def recognize(array: Any) -> bool:
 501    # Eager: tensorflow.python.framework.ops.Tensor
 502    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
 503    return type(array).__module__.startswith('tensorflow.')
 504
 505  def numpy(self) -> _NDArray:
 506    return self.array.numpy()
 507
 508  def dtype(self) -> _DType:
 509    return np.dtype(self.array.dtype.as_numpy_dtype)
 510
 511  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
 512    return self.tf.cast(self.array, dtype)
 513
 514  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
 515    return self.tf.reshape(self.array, shape)
 516
 517  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TensorflowTensor:
 518    array = self.array
 519    if dtype is not None:
 520      array = self.tf.cast(array, dtype)
 521    return self.tf.clip_by_value(array, low, high)
 522
 523  def square(self) -> _TensorflowTensor:
 524    return self.tf.square(self.array)
 525
 526  def sqrt(self) -> _TensorflowTensor:
 527    return self.tf.sqrt(self.array)
 528
 529  def getitem(self, indices: Any) -> _TensorflowTensor:
 530    if isinstance(indices, tuple):
 531      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
 532      if not basic:
 533        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
 534        assert all(isinstance(a, np.ndarray) for a in indices)
 535        assert all(a.ndim == indices[0].ndim for a in indices)
 536        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
 537        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
 538        return self.tf.gather_nd(self.array, indices_array)
 539    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
 540      indices = self.tf.cast(indices, np.int32)
 541    return self.tf.gather(self.array, indices)
 542
 543  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
 544    condition = self.array
 545    return self.tf.where(condition, if_true, if_false)
 546
 547  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
 548    return self.tf.transpose(self.array, tuple(axes))
 549
 550  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 551    # Note that a tensorflow.Tensor does not have strides.
 552    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
 553    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 554    dims = list(range(len(src_shape)))
 555    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 556      dims[:2] = [1, 0]
 557    return dims
 558
 559  def premult_with_sparse(
 560      self, sparse: 'tf.sparse.SparseTensor', num_threads: int | Literal['auto']
 561  ) -> _TensorflowTensor:
 562    import tensorflow as tf
 563
 564    del num_threads
 565    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 566      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
 567    return tf.sparse.sparse_dense_matmul(sparse, self.array)
 568
 569  @staticmethod
 570  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
 571    import tensorflow as tf
 572
 573    return tf.concat(arrays, axis)
 574
 575  @staticmethod
 576  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
 577    import tensorflow as tf
 578
 579    return tf.einsum(subscripts, *operands, optimize='greedy')
 580
 581  @staticmethod
 582  def make_sparse_matrix(
 583      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 584  ) -> _TensorflowTensor:
 585    import tensorflow as tf
 586
 587    indices = np.vstack((row_ind, col_ind)).T
 588    return tf.sparse.SparseTensor(indices, data, shape)
 589
 590
 591class _TorchArraylib(_Arraylib[_TorchTensor]):
 592  """Torch implementation of the array abstraction."""
 593
 594  # pylint: disable=missing-function-docstring
 595
 596  def __init__(self, array: _NDArray) -> None:
 597    import torch
 598
 599    self.torch = torch
 600    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))
 601
 602  @staticmethod
 603  def recognize(array: Any) -> bool:
 604    return type(array).__module__ == 'torch'
 605
 606  def numpy(self) -> _NDArray:
 607    return self.array.numpy()
 608
 609  def dtype(self) -> _DType:
 610    numpy_type = {
 611        self.torch.float32: np.float32,
 612        self.torch.float64: np.float64,
 613        self.torch.complex64: np.complex64,
 614        self.torch.complex128: np.complex128,
 615        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
 616        self.torch.int16: np.int16,
 617        self.torch.int32: np.int32,
 618        self.torch.int64: np.int64,
 619    }[self.array.dtype]
 620    return np.dtype(numpy_type)
 621
 622  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
 623    torch_type = {
 624        np.float32: self.torch.float32,
 625        np.float64: self.torch.float64,
 626        np.complex64: self.torch.complex64,
 627        np.complex128: self.torch.complex128,
 628        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
 629        np.int16: self.torch.int16,
 630        np.int32: self.torch.int32,
 631        np.int64: self.torch.int64,
 632    }[np.dtype(dtype).type]
 633    return self.array.type(torch_type)
 634
 635  def possibly_make_contiguous(self) -> _TorchTensor:
 636    return self.array.contiguous()
 637
 638  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TorchTensor:
 639    array = self.array
 640    array = _arr_astype(array, dtype) if dtype is not None else array
 641    return array.clip(low, high)
 642
 643  def square(self) -> _TorchTensor:
 644    return self.array.square()
 645
 646  def sqrt(self) -> _TorchTensor:
 647    return self.array.sqrt()
 648
 649  def getitem(self, indices: Any) -> _TorchTensor:
 650    if not isinstance(indices, tuple):
 651      indices = indices.type(self.torch.int64)
 652    return self.array[indices]  # pylint: disable=unsubscriptable-object
 653
 654  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
 655    condition = self.array
 656    return if_true.where(condition, if_false)
 657
 658  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
 659    return self.torch.permute(self.array, tuple(axes))
 660
 661  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 662    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
 663    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 664    strides: Sequence[int] = self.array.stride()
 665    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 666
 667    def priority(dim: int) -> float:
 668      scaling = dst_shape[dim] / src_shape[dim]
 669      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 670
 671    return sorted(range(len(src_shape)), key=priority)
 672
 673  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
 674    del num_threads
 675    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
 676      sparse = _arr_astype(sparse, _arr_dtype(self.array))
 677    return sparse @ self.array  # Calls torch.sparse.mm().
 678
 679  @staticmethod
 680  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
 681    import torch
 682
 683    return torch.cat(tuple(arrays), axis)
 684
 685  @staticmethod
 686  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
 687    import torch
 688
 689    operands = tuple(torch.as_tensor(operand) for operand in operands)
 690    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
 691      operands = tuple(
 692          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
 693      )
 694    return torch.einsum(subscripts, *operands)
 695
 696  @staticmethod
 697  def make_sparse_matrix(
 698      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 699  ) -> _TorchTensor:
 700    import torch
 701
 702    indices = np.vstack((row_ind, col_ind))
 703    return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
 704    # .coalesce() is unnecessary because indices/data are already merged.
 705
 706
 707class _JaxArraylib(_Arraylib[_JaxArray]):
 708  """Jax implementation of the array abstraction."""
 709
 710  def __init__(self, array: _NDArray) -> None:
 711    import jax.numpy
 712
 713    self.jnp = jax.numpy
 714    super().__init__(arraylib='jax', array=self.jnp.asarray(array))
 715
 716  @staticmethod
 717  def recognize(array: Any) -> bool:
 718    # e.g., jaxlib.xla_extension.DeviceArray, jax.interpreters.ad.JVPTracer
 719    return type(array).__module__.startswith(('jaxlib.', 'jax.'))
 720
 721  def numpy(self) -> _NDArray:
 722    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
 723    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
 724    # np.array(array) always returns a writable array but the copy may be more costly.
 725    # return self.array.to_py()
 726    return np.asarray(self.array)
 727
 728  def dtype(self) -> _DType:
 729    return np.dtype(self.array.dtype)
 730
 731  def astype(self, dtype: _DTypeLike) -> _JaxArray:
 732    return self.array.astype(dtype)  # (copy=False is unavailable)
 733
 734  def possibly_make_contiguous(self) -> _JaxArray:
 735    return self.array.copy()
 736
 737  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _JaxArray:
 738    array = self.array
 739    if dtype is not None:
 740      array = array.astype(dtype)  # (copy=False is unavailable)
 741    return self.jnp.clip(array, low, high)
 742
 743  def square(self) -> _JaxArray:
 744    return self.jnp.square(self.array)
 745
 746  def sqrt(self) -> _JaxArray:
 747    return self.jnp.sqrt(self.array)
 748
 749  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
 750    condition = self.array
 751    return self.jnp.where(condition, if_true, if_false)
 752
 753  def transpose(self, axes: Sequence[int]) -> _JaxArray:
 754    return self.jnp.transpose(self.array, tuple(axes))
 755
 756  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 757    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
 758    # https://github.com/google/jax/discussions/7544#discussioncomment-1197038.
 759    # We use a heuristic similar to `_TensorflowArraylib`.
 760    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 761    dims = list(range(len(src_shape)))
 762    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
 763      dims[:2] = [1, 0]
 764    return dims
 765
 766  def premult_with_sparse(
 767      self, sparse: 'jax.experimental.sparse.BCOO', num_threads: int | Literal['auto']
 768  ) -> _JaxArray:
 769    del num_threads
 770    return sparse @ self.array  # Calls jax.bcoo_multiply_dense().
 771
 772  @staticmethod
 773  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
 774    import jax.numpy as jnp
 775
 776    return jnp.concatenate(arrays, axis)
 777
 778  @staticmethod
 779  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
 780    import jax.numpy as jnp
 781
 782    return jnp.einsum(subscripts, *operands, optimize='greedy')
 783
 784  @staticmethod
 785  def make_sparse_matrix(
 786      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 787  ) -> _JaxArray:
 788    # https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
 789    import jax.experimental.sparse
 790
 791    indices = np.vstack((row_ind, col_ind)).T
 792    return jax.experimental.sparse.BCOO(
 793        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
 794    )
 795
 796
 797_CANDIDATE_ARRAYLIBS = {
 798    'numpy': _NumpyArraylib,
 799    'tensorflow': _TensorflowArraylib,
 800    'torch': _TorchArraylib,
 801    'jax': _JaxArraylib,
 802}
 803
 804
 805def is_available(arraylib: str) -> bool:
 806  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 807  return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.
 808
 809
 810_DICT_ARRAYLIBS = {
 811    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if is_available(arraylib)
 812}
 813
 814ARRAYLIBS = list(_DICT_ARRAYLIBS)
 815"""Array libraries supported automatically in the resize and resampling operations.
 816
 817- The library is selected automatically based on the type of the `array` function parameter.
 818
 819- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 820
 821- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 822"""
 823
 824
 825def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 826  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 827  for cls in _DICT_ARRAYLIBS.values():
 828    if cls.recognize(array):
 829      return cls(array)  # type: ignore[abstract]
 830  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 831
 832
 833def _arr_arraylib(array: _Array, /) -> str:
 834  """Return the name of the `Arraylib` representing `array`."""
 835  return _as_arr(array).arraylib
 836
 837
 838def _arr_numpy(array: _Array, /) -> _NDArray:
 839  """Return a `numpy` version of `array`."""
 840  return _as_arr(array).numpy()
 841
 842
 843def _arr_dtype(array: _Array, /) -> _DType:
 844  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 845  return _as_arr(array).dtype()
 846
 847
 848def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 849  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 850  return _as_arr(array).astype(dtype)
 851
 852
 853def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 854  """Return the equivalent of `array.reshape(shape)."""
 855  return _as_arr(array).reshape(shape)
 856
 857
 858def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 859  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 860  return _as_arr(array).possibly_make_contiguous()
 861
 862
 863def _arr_clip(array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike = None) -> _Array:
 864  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 865  return _as_arr(array).clip(low, high, dtype)
 866
 867
 868def _arr_square(array: _Array, /) -> _Array:
 869  """Return the equivalent of `np.square(array)`."""
 870  return _as_arr(array).square()
 871
 872
 873def _arr_sqrt(array: _Array, /) -> _Array:
 874  """Return the equivalent of `np.sqrt(array)`."""
 875  return _as_arr(array).sqrt()
 876
 877
 878def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 879  """Return the equivalent of `array[indices]`."""
 880  return _as_arr(array).getitem(indices)
 881
 882
 883def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 884  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 885  return _as_arr(condition).where(if_true, if_false)
 886
 887
 888def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 889  """Return the equivalent of `np.transpose(array, axes)`."""
 890  return _as_arr(array).transpose(axes)
 891
 892
 893def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 894  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 895  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 896
 897
 898def _arr_matmul_sparse_dense(
 899    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 900) -> _Array:
 901  """Return the multiplication of the `sparse` and `dense` matrices."""
 902  assert num_threads == 'auto' or num_threads >= 1
 903  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 904
 905
 906def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 907  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 908  arraylib = _arr_arraylib(arrays[0])
 909  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)  # type: ignore
 910
 911
 912def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 913  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 914  arraylib = _arr_arraylib(operands[0])
 915  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)  # type: ignore
 916
 917
 918def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 919  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 920  ndim = len(array.shape)
 921  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 922  axes = list(range(ndim))
 923  axes[axis1] = axis2
 924  axes[axis2] = axis1
 925  return _arr_transpose(array, axes)
 926
 927
 928def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 929  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 930  ndim = len(array.shape)
 931  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 932  axes = [n for n in range(ndim) if n != source]
 933  axes.insert(destination, source)
 934  return _arr_transpose(array, axes)
 935
 936
 937def _make_sparse_matrix(
 938    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 939) -> Any:
 940  """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 941  However, indices must be ordered and unique."""
 942  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 943
 944
 945def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 946  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 947  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
 948
 949
# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
# e.g. dim=1 in an RGB image.
 953#
 954# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 955# tensor.  However, tf.reshape() is always efficient.
 956
 957
 958def _block_shape_with_min_size(
 959    shape: tuple[int, ...], min_size: int, compact: bool = True
 960) -> tuple[int, ...]:
 961  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 962  if math.prod(shape) < min_size:
 963    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 964  if compact:
 965    root = int(math.ceil(min_size ** (1 / len(shape))))
 966    block_shape = np.minimum(shape, root)
 967    for dim in range(len(shape)):
 968      if block_shape[dim] == 2 and block_shape.prod() >= min_size * 2:
 969        block_shape[dim] = 1
 970    for dim in range(len(shape) - 1, -1, -1):
 971      if block_shape.prod() < min_size:
 972        block_shape[dim] = shape[dim]
 973  else:
 974    block_shape = np.ones_like(shape)
 975    for dim in range(len(shape) - 1, -1, -1):
 976      if block_shape.prod() < min_size:
 977        block_shape[dim] = min(shape[dim], math.ceil(min_size / block_shape.prod()))
 978  return tuple(block_shape)
 979
 980
 981def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 982  """Split `array` into `num_sections` along `axis`."""
 983  assert 0 <= axis < len(array.shape)
 984  assert 1 <= num_sections <= array.shape[axis]
 985
 986  if 0:
 987    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 988
 989  else:
 990    # Adapted from https://github.com/numpy/numpy/blob/main/numpy/lib/shape_base.py#L739-L792.
 991    num_total = array.shape[axis]
 992    num_each, num_extra = divmod(num_total, num_sections)
 993    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
 994    div_points = np.array(section_sizes).cumsum()
 995    split = []
 996    tmp = _arr_swapaxes(array, axis, 0)
 997    for i in range(num_sections):
 998      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
 999
1000  return split
1001
1002
def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
  """Recursively split `array` into nested lists of blocks no larger than `block_shape`.

  With the default `start_axis=0`, the list nesting depth equals `len(block_shape)`; dims of
  `array` beyond `len(block_shape)` are not split.  (See https://stackoverflow.com/a/50305924;
  if `block_shape` is known to exactly partition the array, see
  https://stackoverflow.com/a/16858283.)

  Raises:
    ValueError: If `block_shape` has more dims than `array`.
  """
  if len(block_shape) > len(array.shape):
    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
  if start_axis == len(block_shape):
    return array  # Leaf: all requested axes have been split.
  sections = _array_split(
      array, start_axis, math.ceil(array.shape[start_axis] / block_shape[start_axis])
  )
  return [_split_array_into_blocks(section, block_shape, start_axis + 1) for section in sections]
1015
1016
1017def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
1018  """Apply `func` to each block in the nested lists of `blocks`."""
1019  if isinstance(blocks, list):
1020    return [_map_function_over_blocks(block, func) for block in blocks]
1021  return func(blocks)
1022
1023
def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
  """Reassemble an array from the nested lists of array blocks in `blocks`.

  Nesting level k is concatenated along axis `axis + k`.  This is more general than `np.block()`
  because the blocks may have additional trailing dims.
  """
  if not isinstance(blocks, list):
    return blocks
  merged_children = [_merge_array_from_blocks(child, axis + 1) for child in blocks]
  return _arr_concatenate(merged_children, axis)
1031
1032
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  A grid-type determines where the samples of a grid lie within the unit domain.  In resampling
  operations, it may be given separately for the source domain (`src_gridtype`) and for the
  destination domain (`dst_gridtype`), and it may even vary per domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples that this grid-type supports."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of the domain measured in inter-sample spacings."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map sample indices in [0, size - 1] to domain coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Map domain coordinates in [0.0, 1.0] to a continuous location x, where x == 0.0 is the
    first grid sample and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by reflecting at the boundaries."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by periodic wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by reflection followed by clamping."""
1075
1076  @abc.abstractmethod
1077  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1078    """Map integer sample indices to interior ones using reflect-clamp."""
1079
1080
1081class DualGridtype(Gridtype):
1082  """Samples are at the center of cells in a uniform partition of the domain.
1083
1084  For a unit-domain dimension with N samples, each sample 0 <= i < N has position (i + 0.5) / N,
1085  e.g., [0.125, 0.375, 0.625, 0.875] for N = 4.
1086  """
1087
1088  def __init__(self) -> None:
1089    super().__init__(name='dual')
1090
1091  def min_size(self) -> int:
1092    return 1
1093
1094  def size_in_samples(self, size: int, /) -> int:
1095    return size
1096
1097  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
1098    return (index + 0.5) / size
1099
1100  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
1101    return point * size - 0.5
1102
1103  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
1104    index = np.mod(index, size * 2)
1105    return np.where(index < size, index, 2 * size - 1 - index)
1106
1107  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
1108    return np.mod(index, size)
1109
1110  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1111    return np.minimum(np.where(index < 0, -1 - index, index), size - 1)
1112
1113
1114class PrimalGridtype(Gridtype):
1115  """Samples are at the vertices of cells in a uniform partition of the domain.
1116
1117  For a unit-domain dimension with N samples, each sample 0 <= i < N has position i / (N - 1),
1118  e.g., [0, 1/3, 2/3, 1] for N = 4.
1119  """
1120
1121  def __init__(self) -> None:
1122    super().__init__(name='primal')
1123
1124  def min_size(self) -> int:
1125    return 2
1126
1127  def size_in_samples(self, size: int, /) -> int:
1128    return size - 1
1129
1130  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
1131    return index / (size - 1)
1132
1133  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
1134    return point * (size - 1)
1135
1136  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
1137    index = np.mod(index, size * 2 - 2)
1138    return np.where(index < size, index, 2 * size - 2 - index)
1139
1140  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
1141    return np.mod(index, size - 1)
1142
1143  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
1144    return np.minimum(np.abs(index), size - 1)
1145
1146
# Predefined grid types, keyed by the shortcut names accepted wherever a `Gridtype` is expected.
_DICT_GRIDTYPES = {
    'dual': DualGridtype(),
    'primal': PrimalGridtype(),
}

GRIDTYPES = list(_DICT_GRIDTYPES)
r"""Shortcut names for the two predefined grid types (specified per dimension):

| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
| --- |:---:|:---:|
| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual](https://github.com/hhoppe/resampler/raw/main/media/dual_grid_small.png) | ![Primal](https://github.com/hhoppe/resampler/raw/main/media/primal_grid_small.png) |
| Nesting of samples across resolutions | The sample positions do *not* nest. | The *even* samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.
"""
1165
1166
def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
  """Return `gridtype` as a `Gridtype`, looking it up by name in `GRIDTYPES` if it is a string.

  Raises:
    KeyError: If a name is given that is not a predefined grid type.
  """
  if isinstance(gridtype, Gridtype):
    return gridtype
  return _DICT_GRIDTYPES[gridtype]
1170
1171
def _get_gridtypes(
    gridtype: str | Gridtype | None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    src_ndim: int,
    dst_ndim: int,
) -> tuple[list[Gridtype], list[Gridtype]]:
  """Return per-dim source and destination grid types given all parameters.

  Args:
    gridtype: Grid-type shared by source and destination; mutually exclusive with
      `src_gridtype` and `dst_gridtype`.  If all three are `None`, it defaults to `'dual'`.
    src_gridtype: A single source grid-type, or an iterable with one entry per source dim.
    dst_gridtype: A single destination grid-type, or an iterable with one entry per
      destination dim.
    src_ndim: Number of source domain dimensions.
    dst_ndim: Number of destination domain dimensions.

  Returns:
    Lists of `Gridtype` instances with lengths `src_ndim` and `dst_ndim`.

  Raises:
    ValueError: If `gridtype` is combined with `src_gridtype` or `dst_gridtype`.
  """
  if gridtype is None and src_gridtype is None and dst_gridtype is None:
    gridtype = 'dual'
  if gridtype is not None:
    if src_gridtype is not None:
      raise ValueError('Cannot have both gridtype and src_gridtype.')
    if dst_gridtype is not None:
      raise ValueError('Cannot have both gridtype and dst_gridtype.')
    src_gridtype = dst_gridtype = gridtype

  def per_dim(spec: Any, ndim: int) -> list[Gridtype]:
    # Materialize non-string iterables (e.g. generators), which `np.array()` would otherwise
    # wrap as one opaque object rather than as a sequence of per-dim grid types.
    if isinstance(spec, Iterable) and not isinstance(spec, (str, np.ndarray)):
      spec = list(spec)
    return [_get_gridtype(g) for g in np.broadcast_to(np.array(spec), ndim)]

  return per_dim(src_gridtype, src_ndim), per_dim(dst_gridtype, dst_ndim)
1191
1192
@dataclasses.dataclass(frozen=True)
class RemapCoordinates(abc.ABC):
  """Abstract base class for transforming the specified coordinates before the reconstruction
  kernels are evaluated."""

  @abc.abstractmethod
  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the remapped version of the coordinates `point`."""
1201
1202
class NoRemapCoordinates(RemapCoordinates):
  """Identity remapping: the coordinates are passed through unchanged."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    unchanged = point  # The same object is returned; no copy is made.
    return unchanged
1208
1209
class MirrorRemapCoordinates(RemapCoordinates):
  """Reflect the coordinates across the domain boundaries so that they lie within the unit
  interval.  The resulting function is continuous but not smooth across the boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    folded = np.mod(point, 2.0)  # Reduce to one period of the mirrored pattern.
    return np.where(folded >= 1.0, 2.0 - folded, folded)
1217
1218
class TileRemapCoordinates(RemapCoordinates):
  """Map the coordinates into the unit interval with a "modulo 1.0" operation.  The resulting
  function is generally discontinuous across the domain boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    wrapped = np.mod(point, 1.0)
    return wrapped
1225
1226
@dataclasses.dataclass(frozen=True)
class ExtendSamples(abc.ABC):
  """Abstract base class for rewriting references to grid samples outside the unit domain as
  affine combinations of interior sample(s) and possibly the constant value (`cval`)."""

  uses_cval: bool = False
  """Whether some exterior samples are expressed in terms of `cval`, i.e., whether the computed
  weights may be non-affine."""

  @abc.abstractmethod
  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Find the entries of `index` lying outside the interval [0, size) (references to exterior
    samples) and rewrite them (and possibly their weights) to reference only interior samples.
    Return `new_index, new_weight`."""
1243
1244
class ReflectExtendSamples(ExtendSamples):
  """Locate the interior sample by reflecting the index across the domain boundaries."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.reflect(index, size), weight  # Weights are unchanged.
1253
1254
class WrapExtendSamples(ExtendSamples):
  """Repeat the interior samples periodically.  On a `'primal'` grid, the last sample is
  ignored because the first sample takes its place."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.wrap(index, size), weight  # Weights are unchanged.
1264
1265
class ClampExtendSamples(ExtendSamples):
  """Replace each exterior sample by the nearest interior sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    last = size - 1
    return index.clip(0, last), weight
1274
1275
class ReflectClampExtendSamples(ExtendSamples):
  """Extend the grid samples from [0, 1] into [-1, 0] by reflection, then give every grid sample
  outside [-1, 1] the value of the nearest such sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    return gridtype.reflect_clamp(index, size), weight  # Weights are unchanged.
1285
1286
class BorderExtendSamples(ExtendSamples):
  """Give every exterior sample the constant value (`cval`)."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Zero the weight of each exterior reference (so `cval` takes its place) and point it at a
    # valid in-range index.  The input arrays are modified in place.
    def discard(is_exterior: _NDArray, substitute_index: int) -> None:
      weight[is_exterior] = 0.0
      index[is_exterior] = substitute_index

    discard(index < 0, 0)
    discard(index >= size, size - 1)
    return index, weight
1303
1304
class ValidExtendSamples(ExtendSamples):
  """Give every interior sample weight 1 and every exterior sample weight 0, and divide the
  weighted reconstruction by the reconstructed weight."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Drop exterior references (zero weight, valid in-range index); arrays are modified in place.
    def discard(is_exterior: _NDArray, substitute_index: int) -> None:
      weight[is_exterior] = 0.0
      index[is_exterior] = substitute_index

    discard(index < 0, 0)
    discard(index >= size, size - 1)
    # Renormalize the remaining weights of each row to sum to 1; all-zero rows are left as is.
    row_total = weight.sum(axis=-1)[..., None]
    np.divide(weight, row_total, out=weight, where=row_total != 0.0)
    return index, weight
1325
1326
class LinearExtendSamples(ExtendSamples):
  """Linearly extrapolate beyond boundary samples.

  An exterior index x < 0 contributes (1 - x) * s[0] + x * s[1]; an index beyond the last sample
  is extrapolated symmetrically from samples `size - 1` and `size - 2`.  With fewer than 2
  samples, boundary reflection is used instead.  Note that the input `index` and `weight` arrays
  are modified in place before the enlarged arrays are returned.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    if size < 2:
      # Too few samples to define a line; fall back to reflection.
      index = gridtype.reflect(index, size)
      return index, weight
    # For each boundary, define new columns in index and weight arrays to represent the last and
    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
    # duplicate index entries.
    low = index < 0
    high = index >= size
    # `w` = original weight columns plus 4 extra columns, for samples 0 and 1 (low side) and
    # samples size - 2 and size - 1 (high side).
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
    # Low side: the value at x < 0 is (1 - x) * s[0] + x * s[1].
    x = index
    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
    # High side: with x = (size - 1) - index < 0, the value is x * s[size - 2] +
    # (1 - x) * s[size - 1].
    x = (size - 1) - index
    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
    # The exterior contributions now live in the extra columns; neutralize the original entries.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-4] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-4] = index
    # Let matrix (including zero values) be banded.
    # (Extra columns with zero weight reuse the row's first index instead of a distant sample.)
    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
    index = new_index
    return index, weight
1360
1361
class QuadraticExtendSamples(ExtendSamples):
  """Quadratically extrapolate beyond boundary samples.

  An exterior index is expressed through the quadratic interpolating the 3 samples nearest the
  corresponding boundary.  With fewer than 3 samples, boundary reflection is used instead.  Note
  that the input `index` and `weight` arrays are modified in place before the enlarged arrays are
  returned.
  """

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
    # but it seems just quadratic.
    if size < 3:
      # Too few samples to define a quadratic; fall back to reflection.
      index = gridtype.reflect(index, size)
      return index, weight
    low = index < 0
    high = index >= size
    # `w` = original weight columns plus 6 extra columns, for samples 0, 1, 2 (low side) and
    # samples size - 3, size - 2, size - 1 (high side).
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
    # Low side: Lagrange basis weights of samples 0, 1, 2 evaluated at position x < 0.
    x = index
    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
    # High side: the same basis with x = (size - 1) - index, applied to samples size - 3,
    # size - 2, size - 1 (in that column order).
    x = (size - 1) - index
    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
    # The exterior contributions now live in the extra columns; neutralize the original entries.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-6] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-6] = index
    # Let matrix (including zero values) be banded.
    # (Extra columns with zero weight reuse the row's first index instead of a distant sample.)
    new_index[..., -6:] = np.where(
        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
    )
    index = new_index
    return index, weight
1398
1399
@dataclasses.dataclass(frozen=True)
class OverrideExteriorValue:
  """Abstract base class to set the value outside some domain extent to a
  constant value (`cval`)."""

  boundary_antialiasing: bool = True
  """Antialias the pixel values adjacent to the boundary of the extent."""

  uses_cval: bool = False
  """Modify some weights to introduce references to the `cval` constant value."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For all `point` outside some extent, modify the weight to be zero."""

  def override_using_signed_distance(
      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
  ) -> None:
    """Scale down (in place) the sample weights of points whose `signed_distance` is positive,
    i.e. outside the extent, so that they effectively take the constant value `cval`.

    With `boundary_antialiasing` (and at least 2 points along every axis), the weights are scaled
    by an opacity ramp one sample wide centered on the boundary; otherwise, weights of outside
    points are simply zeroed.
    """
    if np.all(signed_distance <= 0.0):
      return  # Every point lies inside the extent.
    if not (self.boundary_antialiasing and min(point.shape) >= 2):
      weight[signed_distance > 0.0, :] = 0.0
      return
    # For discontinuous coordinate mappings, we may need to somehow ignore
    # the large finite differences computed across the map discontinuities.
    gradient_norm = np.linalg.norm(np.atleast_2d(np.gradient(point)), axis=0)
    distance_in_samples = signed_distance / (gradient_norm + 1e-20)
    # Opacity is in linear space, which is correct if Gamma is set.
    opacity = (0.5 - distance_in_samples).clip(0.0, 1.0)
    weight *= opacity[..., None]
1434
1435
class NoOverrideExteriorValue(OverrideExteriorValue):
  """Leave the function value untouched everywhere."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    del weight, point  # Nothing to override.
1441
1442
class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
  """Replace the values outside the unit interval [0, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Distance beyond the nearer of the boundaries 0.0 and 1.0 (negative inside the interval).
    distance_outside = abs(point - 0.5) - 0.5
    self.override_using_signed_distance(weight, point, distance_outside)
1452
1453
class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
  """Replace the values outside the interval [-1, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Distance beyond the nearer of the boundaries -1.0 and 1.0 (negative inside the interval).
    distance_outside = abs(point) - 1.0
    self.override_using_signed_distance(weight, point, distance_outside)
1463
1464
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules, which define the reconstruction over the source domain near and
  beyond its boundaries.  The rules may be specified separately for each domain dimension."""

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Transform the specified coordinates before the reconstruction kernels are evaluated."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Express the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Replace the value outside some extent by a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if the weights may be non-affine, i.e., may involve the constant value (`cval`)."""
    extend_uses_cval = self.extend_samples.uses_cval
    return extend_uses_cval if extend_uses_cval else self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return the coordinates transformed by `coord_remap`, before kernel evaluation."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Rewrite exterior samples as combinations of interior samples, then apply the exterior
    value override; return the new `index, weight`."""
    new_index, new_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(new_weight, point)
    return new_index, new_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """Set (in place) the weights of points outside an extent to zero, so they take `cval`."""
    self.override_value(weight, point)
1505
1506
# Registry of predefined boundary rules.  The name given to each `Boundary` matches its key.
_DICT_BOUNDARIES = {
    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
    'tile': Boundary(
        'tile', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
    ),
    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
    'border': Boundary('border', extend_samples=BorderExtendSamples()),
    'natural': Boundary(
        'natural',
        extend_samples=ValidExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear_constant': Boundary(
        'linear_constant',
        extend_samples=LinearExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'quadratic_constant': Boundary(
        'quadratic_constant',
        extend_samples=QuadraticExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
    'constant': Boundary(
        'constant',
        extend_samples=ReflectExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
}
1535    ),
1536    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
1537    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
1538}
1539
1540BOUNDARIES = list(_DICT_BOUNDARIES)
1541"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):
1542
1543| name                   | a.k.a. / comments |
1544|------------------------|-------------------|
1545| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
1546| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
1547| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
1548| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
1549| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
1550| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
1551| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
1552| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
1553| `'linear'`             | extrapolate from 2 last samples |
1554| `'quadratic'`          | extrapolate from 3 last samples |
1555| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
1556| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |
1557
1558These boundary rules may be specified per dimension.  See the source code for extensibility
1559using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.
1560
1561**Boundary rules illustrated in 1D:**
1562
1563<center>
1564<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_1D.png" width="100%"/>
1565</center>
1566
1567**Boundary rules illustrated in 2D:**
1568
1569<center>
1570<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_2D.png" width="100%"/>
1571</center>
1572"""
1573
1574_OFTUSED_BOUNDARIES = (
1575    'reflect wrap tile clamp border natural linear_constant quadratic_constant'.split()
1576)
1577"""A useful subset of `BOUNDARIES` for visualization in figures."""
1578
1579
1580def _get_boundary(boundary: str | Boundary, /) -> Boundary:
1581  """Return a `Boundary`, which can be specified as a name in `BOUNDARIES`."""
1582  return boundary if isinstance(boundary, Boundary) else _DICT_BOUNDARIES[boundary]
1583
1584
1585@dataclasses.dataclass(frozen=True)
1586class Filter(abc.ABC):
1587  """Abstract base class for filter kernel functions.
1588
1589  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
1590  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
1591  where N = 2 * radius.)
1592
1593  Portions of this code are adapted from the C++ library in
1594  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp
1595
1596  See also https://hhoppe.com/proj/filtering/.
1597  """
1598
1599  name: str
1600  """Filter kernel name."""
1601
1602  radius: float
1603  """Max absolute value of x for which self(x) is nonzero."""
1604
1605  interpolating: bool = True
1606  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""
1607
1608  continuous: bool = True
1609  """True if the kernel function has $C^0$ continuity."""
1610
1611  partition_of_unity: bool = True
1612  """True if the convolution of the kernel with a Dirac comb reproduces the
1613  unity function."""
1614
1615  unit_integral: bool = True
1616  """True if the integral of the kernel function is 1."""
1617
1618  requires_digital_filter: bool = False
1619  """True if the filter needs a pre/post digital filter for interpolation."""
1620
1621  @abc.abstractmethod
1622  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1623    """Return evaluation of filter kernel at locations x."""
1624
1625
1626class ImpulseFilter(Filter):
1627  """See https://en.wikipedia.org/wiki/Dirac_delta_function."""
1628
1629  def __init__(self) -> None:
1630    super().__init__(name='impulse', radius=1e-20, continuous=False, partition_of_unity=False)
1631
1632  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1633    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')
1634
1635
1636class BoxFilter(Filter):
1637  """See https://en.wikipedia.org/wiki/Box_function.
1638
1639  The kernel function has value 1.0 over the half-open interval [-.5, .5).
1640  """
1641
1642  def __init__(self) -> None:
1643    super().__init__(name='box', radius=0.5, continuous=False)
1644
1645  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1646    use_asymmetric = True
1647    if use_asymmetric:
1648      x = np.asarray(x)
1649      return np.where((-0.5 <= x) & (x < 0.5), 1.0, 0.0)
1650    x = np.abs(x)
1651    return np.where(x < 0.5, 1.0, np.where(x == 0.5, 0.5, 0.0))
1652
1653
1654class TrapezoidFilter(Filter):
1655  """Filter for antialiased "area-based" filtering.
1656
1657  Args:
1658    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
1659      The special case `radius = None` is a placeholder that indicates that the filter will be
1660      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
1661      antialiasing in both minification and magnification.
1662
1663  This filter is similar to the BoxFilter but with linearly sloped sides.  It has value 1.0
1664  in the interval abs(x) <= 1.0 - radius and decreases linearly to value 0.0 in the interval
1665  1.0 - radius <= abs(x) <= radius, always with value 0.5 at x = 0.5.
1666  """
1667
1668  def __init__(self, *, radius: float | None = None) -> None:
1669    if radius is None:
1670      super().__init__(name='trapezoid', radius=0.0)
1671      return
1672    if not 0.5 < radius <= 1.0:
1673      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
1674    super().__init__(name=f'trapezoid_{radius}', radius=radius)
1675
1676  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1677    x = np.abs(x)
1678    assert 0.5 < self.radius <= 1.0
1679    return ((0.5 + 0.25 / (self.radius - 0.5)) - (0.5 / (self.radius - 0.5)) * x).clip(0.0, 1.0)
1680
1681
1682class TriangleFilter(Filter):
1683  """See https://en.wikipedia.org/wiki/Triangle_function.
1684
1685  Also known as the hat or tent function.  It is used for piecewise-linear
1686  (or bilinear, or trilinear, ...) interpolation.
1687  """
1688
1689  def __init__(self) -> None:
1690    super().__init__(name='triangle', radius=1.0)
1691
1692  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1693    return (1.0 - np.abs(x)).clip(0.0, 1.0)
1694
1695
1696class CubicFilter(Filter):
1697  """Family of cubic filters parameterized by two scalar parameters.
1698
1699  Args:
1700    b: first scalar parameter.
1701    c: second scalar parameter.
1702
1703  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
1704  https://doi.org/10.1145/378456.378514.
1705
1706  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
1707  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
1708
1709  - The filter has quadratic precision iff b + 2 * c == 1.
1710  - The filter is interpolating iff b == 0.
1711  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
1712  - (b=1/3, c=1/3) is the Mitchell filter;
1713  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
1714  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
1715  """
1716
1717  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
1718    name = f'cubic_b{b}_c{c}' if name is None else name
1719    interpolating = b == 0
1720    super().__init__(name=name, radius=2.0, interpolating=interpolating)
1721    self.b, self.c = b, c
1722
1723  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1724    x = np.abs(x)
1725    b, c = self.b, self.c
1726    f3, f2, f0 = 2 - 9 / 6 * b - c, -3 + 2 * b + c, 1 - 1 / 3 * b
1727    g3, g2, g1, g0 = -b / 6 - c, b + 5 * c, -2 * b - 8 * c, 8 / 6 * b + 4 * c
1728    # (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
1729    # twice as slow; see also https://stackoverflow.com/questions/24065904)
1730    v01 = ((f3 * x + f2) * x) * x + f0
1731    v12 = ((g3 * x + g2) * x + g1) * x + g0
1732    return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1733
1734
1735class CatmullRomFilter(CubicFilter):
1736  """Cubic filter with cubic precision.  Also known as Keys filter.
1737
1738  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
1739  design, 1974]
1740  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)
1741
1742  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
1743  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
1744  https://ieeexplore.ieee.org/document/1163711/.
1745  """
1746
1747  def __init__(self) -> None:
1748    super().__init__(b=0, c=0.5, name='cubic')
1749
1750
1751class MitchellFilter(CubicFilter):
1752  """See https://doi.org/10.1145/378456.378514.
1753
1754  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
1755  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
1756  """
1757
1758  def __init__(self) -> None:
1759    super().__init__(b=1 / 3, c=1 / 3, name='mitchell')
1760
1761
1762class SharpCubicFilter(CubicFilter):
1763  """Cubic filter that is sharper than Catmull-Rom filter.
1764
1765  Used by some tools including OpenCV and Photoshop.
1766
1767  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
1768  https://entropymine.com/resamplescope/notes/photoshop/.
1769  """
1770
1771  def __init__(self) -> None:
1772    super().__init__(b=0, c=0.75, name='sharpcubic')
1773
1774
1775class LanczosFilter(Filter):
1776  """High-quality filter: sinc function modulated by a sinc window.
1777
1778  Args:
1779    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
1780    sampled: If True, use a discretized approximation for improved speed.
1781
1782  See https://en.wikipedia.org/wiki/Lanczos_kernel.
1783  """
1784
1785  def __init__(self, *, radius: int, sampled: bool = True) -> None:
1786    super().__init__(
1787        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
1788    )
1789
1790    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1791    def _eval(x: _ArrayLike) -> _NDArray:
1792      x = np.abs(x)
1793      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
1794      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
1795      window = _sinc(x / radius)  # Zero-phase function w_0(x).
1796      return np.where(x < radius, _sinc(x) * window, 0.0)
1797
1798    self._function = _eval
1799
1800  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1801    return self._function(x)
1802
1803
1804class GeneralizedHammingFilter(Filter):
1805  """Sinc function modulated by a Hamming window.
1806
1807  Args:
1808    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
1809    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.
1810
1811  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
1812  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.
1813
1814  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which close to but different from `a0=0.54`.
1815
1816  See also np.hamming() and np.hanning().
1817  """
1818
1819  def __init__(self, *, radius: int, a0: float) -> None:
1820    super().__init__(
1821        name=f'hamming_{radius}',
1822        radius=radius,
1823        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
1824        unit_integral=False,  # 1.00188
1825    )
1826    assert 0.0 < a0 < 1.0
1827    self.a0 = a0
1828
1829  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1830    x = np.abs(x)
1831    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
1832    # With n = x + N/2, we get the zero-phase function w_0(x):
1833    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
1834    return np.where(x < self.radius, _sinc(x) * window, 0.0)
1835
1836
1837class KaiserFilter(Filter):
1838  """Sinc function modulated by a Kaiser-Bessel window.
1839
1840  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
1841  [Karras et al. 20201.  Alias-free generative adversarial networks.
1842  https://arxiv.org/pdf/2106.12423.pdf].
1843
1844  See also np.kaiser().
1845
1846  Args:
1847    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
1848      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
1849      in [Karras et al. 2021] --- this effects the precise shape of the window function.
1850    beta: Determines the trade-off between main-lobe width and side-lobe level.
1851    sampled: If True, use a discretized approximation for improved speed.
1852  """
1853
1854  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
1855    assert beta >= 0.0
1856    super().__init__(
1857        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
1858    )
1859
1860    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
1861    def _eval(x: _ArrayLike) -> _NDArray:
1862      x = np.abs(x)
1863      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
1864      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)
1865
1866    self._function = _eval
1867
1868  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1869    return self._function(x)
1870
1871
1872class BsplineFilter(Filter):
1873  """B-spline of a non-negative degree.
1874
1875  Args:
1876    degree: The polynomial degree of the B-spline segments.
1877      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
1878      With `degree=1`, it is identical to `TriangleFilter`.
1879      With `degree >= 2`, it is no longer interpolating.
1880
1881  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
1882  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
1883  """
1884
1885  def __init__(self, *, degree: int) -> None:
1886    if degree < 0:
1887      raise ValueError(f'Bspline of degree {degree} is invalid.')
1888    radius = (degree + 1) / 2
1889    interpolating = degree <= 1
1890    super().__init__(name=f'bspline{degree}', radius=radius, interpolating=interpolating)
1891    t = list(range(degree + 2))
1892    self._bspline = scipy.interpolate.BSpline.basis_element(t)
1893
1894  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1895    x = np.abs(x)
1896    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)
1897
1898
1899class CardinalBsplineFilter(Filter):
1900  """Interpolating B-spline, achieved with aid of digital pre or post filter.
1901
1902  Args:
1903    degree: The polynomial degree of the B-spline segments.
1904    sampled: If True, use a discretized approximation for improved speed.
1905
1906  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
1907  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
1908  1991].
1909  """
1910
1911  def __init__(self, *, degree: int, sampled: bool = True) -> None:
1912    self.degree = degree
1913    if degree < 0:
1914      raise ValueError(f'Bspline of degree {degree} is invalid.')
1915    radius = (degree + 1) / 2
1916    super().__init__(
1917        name=f'cardinal{degree}',
1918        radius=radius,
1919        requires_digital_filter=degree >= 2,
1920        continuous=degree >= 1,
1921    )
1922    t = list(range(degree + 2))
1923    bspline = scipy.interpolate.BSpline.basis_element(t)
1924
1925    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1926    def _eval(x: _ArrayLike) -> _NDArray:
1927      x = np.abs(x)
1928      return np.where(x < radius, bspline(x + radius), 0.0)
1929
1930    self._function = _eval
1931
1932  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1933    return self._function(x)
1934
1935
1936class OmomsFilter(Filter):
1937  """OMOMS interpolating filter, with aid of digital pre or post filter.
1938
1939  Args:
1940    degree: The polynomial degree of the filter segments.
1941
1942  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
1943  interpolation of minimal support, 2001].
1944  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
1945  """
1946
1947  def __init__(self, *, degree: int) -> None:
1948    if degree not in (3, 5):
1949      raise ValueError(f'Degree {degree} not supported.')
1950    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
1951    self.degree = degree
1952
1953  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1954    x = np.abs(x)
1955    match self.degree:
1956      case 3:
1957        v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
1958        v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
1959        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1960      case 5:
1961        v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
1962        v12 = (
1963            (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
1964        ) * x + 839 / 2640
1965        v23 = (
1966            (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
1967        ) * x + 5707 / 2640
1968        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
1969      case _:
1970        raise ValueError(self.degree)
1971
1972
1973class GaussianFilter(Filter):
1974  r"""See https://en.wikipedia.org/wiki/Gaussian_function.
1975
1976  Args:
1977    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
1978      creates a kernel that is as-close-as-possible to a partition of unity.
1979  """
1980
1981  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
1982  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
1983  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
1984  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
1985  resampling tasks, 1990] for kernels with a support of 3 pixels.
1986  https://cadxfem.org/inf/ResamplingFilters.pdf
1987  """
1988
1989  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
1990    super().__init__(
1991        name=f'gaussian_{standard_deviation:.3f}',
1992        radius=np.ceil(8.0 * standard_deviation),  # Sufficiently large.
1993        interpolating=False,
1994        partition_of_unity=False,
1995    )
1996    self.standard_deviation = standard_deviation
1997
1998  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1999    x = np.abs(x)
2000    sdv = self.standard_deviation
2001    v0r = np.exp(np.square(x / sdv) / -2.0) / (np.sqrt(math.tau) * sdv)
2002    return np.where(x < self.radius, v0r, 0.0)
2003
2004
2005class NarrowBoxFilter(Filter):
2006  """Compact footprint, used for visualization of grid sample location.
2007
2008  Args:
2009    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
2010      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
2011  """
2012
2013  def __init__(self, *, radius: float = 0.199) -> None:
2014    super().__init__(
2015        name='narrowbox',
2016        radius=radius,
2017        continuous=False,
2018        unit_integral=False,
2019        partition_of_unity=False,
2020    )
2021
2022  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2023    radius = self.radius
2024    magnitude = 1.0
2025    x = np.asarray(x)
2026    return np.where((-radius <= x) & (x < radius), magnitude, 0.0)
2027
2028
2029_DEFAULT_FILTER = 'lanczos3'
2030
2031_DICT_FILTERS = {
2032    'impulse': ImpulseFilter(),
2033    'box': BoxFilter(),
2034    'trapezoid': TrapezoidFilter(),
2035    'triangle': TriangleFilter(),
2036    'cubic': CatmullRomFilter(),
2037    'sharpcubic': SharpCubicFilter(),
2038    'lanczos3': LanczosFilter(radius=3),
2039    'lanczos5': LanczosFilter(radius=5),
2040    'lanczos10': LanczosFilter(radius=10),
2041    'cardinal3': CardinalBsplineFilter(degree=3),
2042    'cardinal5': CardinalBsplineFilter(degree=5),
2043    'omoms3': OmomsFilter(degree=3),
2044    'omoms5': OmomsFilter(degree=5),
2045    'hamming3': GeneralizedHammingFilter(radius=3, a0=25 / 46),
2046    'kaiser3': KaiserFilter(radius=3.0, beta=7.12),
2047    'gaussian': GaussianFilter(),
2048    'bspline3': BsplineFilter(degree=3),
2049    'mitchell': MitchellFilter(),
2050    'narrowbox': NarrowBoxFilter(),
2051    # Not in FILTERS:
2052    'hamming1': GeneralizedHammingFilter(radius=1, a0=0.54),
2053    'hann3': GeneralizedHammingFilter(radius=3, a0=0.5),
2054    'lanczos4': LanczosFilter(radius=4),
2055}
2056
2057FILTERS = list(itertools.takewhile(lambda x: x != 'hamming1', _DICT_FILTERS))
2058r"""Shortcut names for some predefined filter kernels (specified per dimension).
2059The names expand to:
2060
2061| name           | `Filter`                      | a.k.a. / comments |
2062|----------------|-------------------------------|-------------------|
2063| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
2064| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
2065| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
2066| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
2067| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
2068| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
2069| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
2070| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
2071| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
2072| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
2073| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
2074| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-3, 3], *GF* |
2075| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-5, 5], *GF* |
2076| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
2077| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
2078| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
2079| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
2080| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
2081| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |
2082
2083The comment label *GF* denotes a [generalized filter](https://hhoppe.com/proj/filtering/), formed
2084as the composition of a finitely supported kernel and a discrete inverse convolution.
2085
2086**Some example filter kernels:**
2087
2088<center>
2089<img src="https://github.com/hhoppe/resampler/raw/main/media/filter_summary.png" width="100%"/>
2090</center>
2091
2092<br/>A more extensive set of filters is presented [here](#plots_of_filters) in the
2093[notebook](https://colab.research.google.com/github/hhoppe/resampler/blob/main/resampler_notebook.ipynb),
2094together with visualizations and analyses of the filter properties.
2095See the source code for extensibility.
2096"""
2097
2098
2099def _get_filter(filter: str | Filter, /) -> Filter:
2100  """Return a `Filter`, which can be specified as a name string key in `FILTERS`."""
2101  return filter if isinstance(filter, Filter) else _DICT_FILTERS[filter]
2102
2103
2104def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
2105  """Scale uint to the range [0.0, 1.0], and clip float to [0.0, 1.0]."""
2106  array_dtype = _arr_dtype(array)
2107  dtype = np.dtype(dtype)
2108  assert np.issubdtype(dtype, np.floating)
2109  match array_dtype.type:
2110    case np.uint8 | np.uint16 | np.uint32:
2111      if _arr_arraylib(array) == 'numpy':
2112        assert isinstance(array, np.ndarray)  # Help mypy.
2113        return np.multiply(array, 1 / np.iinfo(array_dtype).max, dtype=dtype)
2114      return _arr_astype(array, dtype) / np.iinfo(array_dtype).max
2115    case _:
2116      assert np.issubdtype(array_dtype, np.floating)
2117      return _arr_clip(array, 0.0, 1.0, dtype)
2118
2119
2120def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
2121  """Convert a float in range [0.0, 1.0] to uint or float type."""
2122  assert np.issubdtype(_arr_dtype(array), np.floating)
2123  dtype = np.dtype(dtype)
2124  match dtype.type:
2125    case np.uint8 | np.uint16:
2126      return cast(_Array, _arr_astype(array * np.float32(np.iinfo(dtype).max) + 0.5, dtype))
2127    case np.uint32:
2128      return cast(_Array, _arr_astype(array * np.float64(np.iinfo(dtype).max) + 0.5, dtype))
2129    case _:
2130      assert np.issubdtype(dtype, np.floating)
2131      return _arr_astype(array, dtype)
2132
2133
2134@dataclasses.dataclass(frozen=True)
2135class Gamma(abc.ABC):
2136  """Abstract base class for transfer functions on sample values.
2137
2138  Image/video content is often stored using a color component transfer function.
2139  See https://en.wikipedia.org/wiki/Gamma_correction.
2140
2141  Converts between integer types and [0.0, 1.0] internal value range.
2142  """
2143
2144  name: str
2145  """Name of component transfer function."""
2146
2147  @abc.abstractmethod
2148  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2149    """Decode source sample values into floating-point, possibly nonlinearly.
2150
2151    Uint source values are mapped to the range [0.0, 1.0].
2152    """
2153
2154  @abc.abstractmethod
2155  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2156    """Encode float signal into destination samples, possibly nonlinearly.
2157
2158    Uint destination values are mapped from the range [0.0, 1.0].
2159
2160    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
2161    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
2162    """
2163
2164
2165class IdentityGamma(Gamma):
2166  """Identity component transfer function."""
2167
2168  def __init__(self) -> None:
2169    super().__init__('identity')
2170
2171  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2172    dtype = np.dtype(dtype)
2173    assert np.issubdtype(dtype, np.inexact)
2174    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
2175      return _to_float_01(array, dtype)
2176    return _arr_astype(array, dtype)
2177
2178  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2179    dtype = np.dtype(dtype)
2180    assert np.issubdtype(dtype, np.number)
2181    if np.issubdtype(dtype, np.unsignedinteger):
2182      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
2183    if np.issubdtype(dtype, np.integer):
2184      return _arr_astype(array + 0.5, dtype)
2185    return _arr_astype(array, dtype)
2186
2187
2188class PowerGamma(Gamma):
2189  """Gamma correction using a power function."""
2190
2191  def __init__(self, power: float) -> None:
2192    super().__init__(name=f'power_{power}')
2193    self.power = power
2194
2195  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2196    dtype = np.dtype(dtype)
2197    assert np.issubdtype(dtype, np.floating)
2198    if _arr_dtype(array) == np.uint8 and self.power != 2:
2199      arraylib = _arr_arraylib(array)
2200      decode_table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
2201      return _arr_getitem(decode_table, array)
2202
2203    array = _to_float_01(array, dtype)
2204    return _arr_square(array) if self.power == 2 else array**self.power
2205
2206  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2207    array = _arr_clip(array, 0.0, 1.0)
2208    array = _arr_sqrt(array) if self.power == 2 else array ** (1.0 / self.power)
2209    return _from_float(array, dtype)
2210
2211
2212class SrgbGamma(Gamma):
2213  """Gamma correction using sRGB; see https://en.wikipedia.org/wiki/SRGB."""
2214
2215  def __init__(self) -> None:
2216    super().__init__(name='srgb')
2217
2218  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2219    dtype = np.dtype(dtype)
2220    assert np.issubdtype(dtype, np.floating)
2221    if _arr_dtype(array) == np.uint8:
2222      arraylib = _arr_arraylib(array)
2223      decode_table = _make_array(self.decode(np.arange(256, dtype=dtype) / 255), arraylib)
2224      return _arr_getitem(decode_table, array)
2225
2226    x = _to_float_01(array, dtype)
2227    return _arr_where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92)
2228
2229  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2230    x = _arr_clip(array, 0.0, 1.0)
2231    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
2232    # pytype: disable=wrong-arg-types
2233    x = _arr_where(x > 0.0031308, x ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17), x * 12.92)
2234    # pytype: enable=wrong-arg-types
2235    return _from_float(x, dtype)
2236
2237
2238_DICT_GAMMAS = {
2239    'identity': IdentityGamma(),
2240    'power2': PowerGamma(2.0),
2241    'power22': PowerGamma(2.2),
2242    'srgb': SrgbGamma(),
2243}
2244
2245GAMMAS = list(_DICT_GAMMAS)
2246r"""Shortcut names for some predefined gamma-correction schemes:
2247
2248| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
2249|---|---|:---:|:---:|
2250| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
2251| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
2252| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
2253| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |
2254
2255See the source code for extensibility.
2256"""
2257
2258
2259def _get_gamma(gamma: str | Gamma, /) -> Gamma:
2260  """Return a `Gamma`, which can be specified as a name in `GAMMAS`."""
2261  return gamma if isinstance(gamma, Gamma) else _DICT_GAMMAS[gamma]
2262
2263
2264def _get_src_dst_gamma(
2265    gamma: str | Gamma | None,
2266    src_gamma: str | Gamma | None,
2267    dst_gamma: str | Gamma | None,
2268    src_dtype: _DType,
2269    dst_dtype: _DType,
2270) -> tuple[Gamma, Gamma]:
2271  if gamma is None and src_gamma is None and dst_gamma is None:
2272    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
2273    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
2274    if src_uint and dst_uint:
2275      # The default might ideally be 'srgb' but that conversion is costlier.
2276      gamma = 'power2'
2277    elif not src_uint and not dst_uint:
2278      gamma = 'identity'
2279    else:
2280      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
2281  if gamma is not None:
2282    if src_gamma is not None:
2283      raise ValueError('Cannot specify both gamma and src_gamma.')
2284    if dst_gamma is not None:
2285      raise ValueError('Cannot specify both gamma and dst_gamma.')
2286    src_gamma = dst_gamma = gamma
2287  assert src_gamma and dst_gamma
2288  src_gamma = _get_gamma(src_gamma)
2289  dst_gamma = _get_gamma(dst_gamma)
2290  return src_gamma, dst_gamma
2291
2292
def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Compute affine weights for 1D resampling from `src_size` to `dst_size`.

  Compute a sparse matrix in which each row expresses a destination sample value as a combination
  of source sample values depending on the boundary rule.  If the combination is non-affine,
  the remainder (returned as `cval_weight`) is the contribution of the special constant value
  (cval) defined outside the domain.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries; it must be a floating-point type.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix of shape `(dst_size, src_size)` whose rows express output sample values
      as affine combinations of the source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.  It is `None` unless `boundary.uses_cval` is true or the
      selected kernel is `'narrowbox'`.

  Raises:
    ValueError: If `src_size` or `dst_size` is below the minimum size of its gridtype.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  prefilter = filter if prefilter is None else prefilter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  # Ratio of destination to source domain extent (measured in samples), including `scale`.
  scaling = dst_gridtype.size_in_samples(dst_size) / src_gridtype.size_in_samples(src_size) * scale
  is_minification = scaling < 1.0
  # From here on, `filter` is the kernel actually used: the prefilter when minifying.
  filter = prefilter if is_minification else filter
  if filter.name == 'trapezoid':
    # The trapezoid kernel's radius depends on the scaling factor (in either direction).
    radius = 0.5 + 0.5 * min(scaling, 1.0 / scaling)
    filter = TrapezoidFilter(radius=radius)
  radius = filter.radius
  # Number of source taps per destination sample; minification widens the support by 1/scaling.
  num_samples = int(np.ceil(radius * 2 / scaling) if is_minification else np.ceil(radius * 2))

  dst_index = np.arange(dst_size, dtype=dtype)
  # Destination sample locations in unit domain [0, 1].
  dst_position = dst_gridtype.point_from_index(dst_index, dst_size)

  # Sample positions mapped back to source unit domain [0, 1], by inverting the affine map
  # `dst = src * scale + translate`.
  src_position = (dst_position - translate) / scale
  src_position = boundary.preprocess_coordinates(src_position)

  # Continuous (fractional) index of each destination sample within the source grid.
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  # Index of the first tap, chosen so that the `num_samples` taps straddle `src_float_index`
  # (for an odd count, the middle tap is the nearest source sample).
  src_first_index = (
      np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
      - (num_samples - 1) // 2
  )

  sample_index = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + sample_index  # (dst_size, num_samples)

  def get_weight_matrix() -> _NDArray:
    """Return the (not yet normalized) kernel weight of each tap in `src_index`."""
    if filter.name == 'impulse':
      return np.ones(src_index.shape, dtype)
    if is_minification:
      # Evaluate the kernel in destination units, i.e., stretched by 1/scaling.
      x = (src_float_index[:, None] - src_index.astype(dtype)) * scaling
      return filter(x) * scaling
    # Either same size or magnification.
    x = src_float_index[:, None] - src_index.astype(dtype)
    return filter(x)

  weight = get_weight_matrix().astype(dtype, copy=False)

  # Normalize each row to sum to one.  'narrowbox' is excluded so that its row deficit remains
  # as a cval contribution (see `uses_cval` below).
  if filter.name != 'narrowbox' and (is_minification or not filter.partition_of_unity):
    weight = weight / weight.sum(axis=-1)[..., None]

  # The boundary rule may rewrite the tap indices and weights (presumably to handle taps that fall
  # outside [0, src_size)); several taps may thus end up on the same source index.
  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)
  shape = dst_size, src_size

  def prepare_sparse_resize_matrix() -> tuple[_NDArray, _NDArray, _NDArray]:
    """Return `(data, row_ind, col_ind)` with zero weights dropped and duplicate entries summed."""
    # Linearize each (row, column) entry as `row * src_size + column`.
    linearized = (src_index + np.indices(src_index.shape)[0] * src_size).ravel()
    values = weight.ravel()
    # Remove the zero weights.
    nonzero = values != 0.0
    linearized, values = linearized[nonzero], values[nonzero]
    # Sort and merge the duplicate indices.
    unique, unique_inverse = np.unique(linearized, return_inverse=True)
    # A 0/1 matrix mapping each entry to its unique linearized index; its product with `values`
    # sums the values that share an index.
    data2 = np.ones(len(linearized), np.float32)
    row_ind2 = unique_inverse
    col_ind2 = np.arange(len(linearized))
    shape2 = len(unique), len(linearized)
    csr = scipy.sparse.csr_matrix((data2, (row_ind2, col_ind2)), shape=shape2)
    data = csr * values  # Merged values.
    row_ind, col_ind = unique // src_size, unique % src_size  # Merged indices.
    return data, row_ind, col_ind

  data, row_ind, col_ind = prepare_sparse_resize_matrix()
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, shape, arraylib)

  uses_cval = boundary.uses_cval or filter.name == 'narrowbox'
  cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib) if uses_cval else None

  return resize_matrix, cval_weight
2410
2411
def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution to the specified dimension of the array.

  Find the array coefficients such that convolution with the (continuous) filter (given
  gridtype and boundary) interpolates the original array values.

  The computation is always performed by `_apply_digital_filter_1d_numpy`.  For `'tensorflow'`,
  `'torch'`, and `'jax'` arrays, that numpy routine is wrapped in a custom-gradient construct whose
  backward pass calls the same routine with `compute_backward=True`.

  Args:
    array: Array (from any library in `ARRAYLIBS`) whose dimension `axis` is filtered.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule used when forming the convolution.
    cval: Constant value used by boundary rules that reference it.
    filter: Continuous filter; it must satisfy `filter.requires_digital_filter`.
    axis: The dimension of `array` to which the inverse convolution is applied.

  Returns:
    An array from the same library as `array`, containing the computed coefficients.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  if arraylib == 'tensorflow':
    import tensorflow as tf

    # Forward pass: solve the system; backward pass: solve the transposed system.
    def forward(x: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(x, gridtype, boundary, cval, filter, axis, False)

    def backward(grad_output: _NDArray) -> _NDArray:
      return _apply_digital_filter_1d_numpy(
          grad_output, gridtype, boundary, cval, filter, axis, True
      )

    @tf.custom_gradient  # type: ignore[misc]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # Although `forward` accesses parameters gridtype, boundary, etc., it is not stateful
      # because the function is redefined on each invocation of _apply_digital_filter_1d.
      y = tf.numpy_function(forward, [x], x.dtype, stateful=False)

      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        return tf.numpy_function(backward, [grad_output], x.dtype, stateful=False)

      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    # NOTE(review): `.numpy()` below presumably requires tensors that reside on the CPU.
    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Differentiable wrapper for _apply_digital_filter_1d."""

      @staticmethod
      # pylint: disable-next=arguments-differ
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        a = _apply_digital_filter_1d_numpy(
            x.detach().numpy(), gridtype, boundary, cval, filter, axis, False
        )
        return torch.as_tensor(a)

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        # The gradient is obtained by solving the transposed system.
        a = _apply_digital_filter_1d_numpy(
            grad_output.detach().numpy(), gridtype, boundary, cval, filter, axis, True
        )
        return torch.as_tensor(a)

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp
    # It seems rather difficult to implement this digital filter (inverse convolution) in Jax.
    # https://jax.readthedocs.io/en/latest/jax.scipy.html sadly omits scipy.signal.filtfilt().
    # To include a (non-traceable) numpy function in Jax requires jax.experimental.host_callback
    # and/or defining a new jax.core.Primitive (which allows differentiability).  See
    # https://github.com/google/jax/issues/1142#issuecomment-544286585
    # https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb  :-(
    # https://github.com/google/jax/issues/5934

    @jax.custom_gradient  # type: ignore[misc]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # This function is not jax-traceable due to the presence of to_py(), so jit and grad fail.
      x_py = np.asarray(x)  # to_py() deprecated.
      a = _apply_digital_filter_1d_numpy(x_py, gridtype, boundary, cval, filter, axis, False)
      y = jnp.asarray(a)

      def grad(grad_output: _JaxArray) -> _JaxArray:
        grad_output_py = np.asarray(grad_output)  # to_py() deprecated.
        a = _apply_digital_filter_1d_numpy(
            grad_output_py, gridtype, boundary, cval, filter, axis, True
        )
        return jnp.asarray(a)

      return y, grad

    return jax_inverse_convolution(array)

  # Plain numpy: no gradient machinery is needed.
  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return _apply_digital_filter_1d_numpy(array, gridtype, boundary, cval, filter, axis, False)
2514
2515
def _apply_digital_filter_1d_numpy(
    array: _NDArray,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    axis: int,
    compute_backward: bool,
    /,
) -> _NDArray:
  """Version of `_apply_digital_filter_1d` specialized to numpy arrays.

  Along `axis`, solve the linear system `M @ result = array`, where row `i` of the matrix `M`
  holds the continuous `filter` sampled at integer offsets around sample `i` (after application
  of the boundary rule).  If `compute_backward` is true (gradient computation), the transposed
  system is solved instead and the `cval` contribution is not subtracted.
  """
  assert np.issubdtype(array.dtype, np.inexact)
  cval = np.asarray(cval).astype(array.dtype, copy=False)

  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter.
  # The dict maps (boundary, gridtype) to the equivalent `scipy.ndimage` boundary mode.
  mode = {
      ('reflect', 'dual'): 'reflect',
      ('reflect', 'primal'): 'mirror',
      ('wrap', 'dual'): 'grid-wrap',
      ('wrap', 'primal'): 'wrap',
  }.get((boundary.name, gridtype.name))
  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
  use_split_filter1d = filter_is_compatible and mode
  if use_split_filter1d:
    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
    assert filter.degree >= 2
    # compute_backward=True is same: matrix is symmetric and cval is unused.
    return scipy.ndimage.spline_filter1d(
        array, axis=axis, order=filter.degree, mode=mode, output=array.dtype
    )

  # General path: assemble the (size x size) convolution matrix explicitly and solve it.
  array_dim = np.moveaxis(array, axis, 0)
  # Half-width of the discrete kernel: the filter is sampled at the integer offsets -l..l.
  l = original_l = math.ceil(filter.radius) - 1
  x = np.arange(-l, l + 1, dtype=array.real.dtype)
  values = filter(x)
  size = array_dim.shape[0]
  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
  weight = np.full((size, len(values)), values)
  src_position = np.broadcast_to(0.5, len(values))
  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
  if gridtype.name == 'primal' and boundary.name == 'wrap':
    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
    # matrix non-singular.
    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
  # Currently, matrix is always banded unless boundary.name == 'wrap'.

  data = weight.reshape(-1).astype(array.dtype, copy=False)
  row_ind = np.arange(size).repeat(src_index.shape[1])
  col_ind = src_index.reshape(-1)
  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
  if compute_backward:
    # The gradient of `result = M^-1 @ array` with respect to `array` applies M^-T.
    matrix = matrix.T

  if boundary.uses_cval and not compute_backward:
    # Each row's missing weight multiplies `cval`; move that known term to the right-hand side.
    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval

  if is_banded:
    # DIA storage keeps one row per diagonal in increasing offset order; reversing it yields the
    # band layout (upper diagonals first) expected by scipy's banded solvers.
    matrix = matrix.todia()
    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
    l, u = -matrix.offsets[0], matrix.offsets[-1]
    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
    if _is_symmetric(matrix):
      array_dim = scipy.linalg.solveh_banded(matrix.data[-1 : l - 1 : -1], array_dim, **options)
    else:
      array_dim = scipy.linalg.solve_banded((l, u), matrix.data[::-1], array_dim, **options)

  else:
    # Non-banded (e.g., wrap-around) system: sparse LU factorization without column reordering.
    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
    array_dim = lu.solve(array_dim.reshape(array_dim.shape[0], -1)).reshape(array_dim.shape)

  return np.moveaxis(array_dim, 0, axis)
2595
2596
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.  Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.

  Raises:
    ValueError: If `array` is not numeric, if `shape` is incompatible with `array`, if `dim_order`
      is not a permutation, or if the gamma parameters are inconsistent.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # The whole operation is an identity if every dimension keeps its size and gridtype, there is no
  # scale/translate, the kernel is interpolating, the gammas agree, and no dtype change is needed.
  # (With equal sizes and unit scale there is no minification, so `filter2` (not `prefilter2`) is
  # the kernel that would be used.)
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in filter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)
  cval = _arr_numpy(src_gamma2.decode(cval, precision))

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, typing.cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    # A dimension may be left untouched only if its size and gridtype are unchanged, there is no
    # scale/translate, and the (magnification) kernel is interpolating.
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and src_gridtype2[dim] == dst_gridtype2[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Bring `dim` to the front and flatten the remaining dimensions into columns.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array
2849
2850
# Reference to the `resize` function defined above; the wrappers below (`resize_in_arraylib`,
# `_resize_possibly_in_arraylib`, `_create_jaxjit_resize`) call it through this name, presumably so
# that they remain unaffected if the public name `resize` is later rebound.
_original_resize = resize
2852
2853
def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the specified array library from `ARRAYLIBS`.

  The numpy `array` is converted into `arraylib`, resized there (all other arguments are forwarded
  to `resize()`), and the result is converted back into a numpy array.
  """
  _check_eq(_arr_arraylib(array), 'numpy')
  converted = _make_array(array, arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
2858
2859
def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `numpy` library.

  All arguments after `array` are forwarded unchanged to `resize()`.
  """
  library = 'numpy'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2863
2864
def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `tensorflow` library.

  All arguments after `array` are forwarded unchanged to `resize()`.
  """
  library = 'tensorflow'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2868
2869
def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `torch` library.

  All arguments after `array` are forwarded unchanged to `resize()`.
  """
  library = 'torch'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2873
2874
def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Evaluate the `resize()` operation using the `jax` library.

  All arguments after `array` are forwarded unchanged to `resize()`.
  """
  library = 'jax'
  return resize_in_arraylib(array, *args, arraylib=library, **kwargs)
2878
2879
def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """If `array` is from numpy, evaluate `resize()` using the array library from `ARRAYLIBS`.

  A numpy `array` is converted into `arraylib`, resized there, and the result is converted back
  into numpy.  Any other `array` is resized directly in its own library (`arraylib` is ignored).
  All remaining arguments are forwarded to `resize()`.
  """
  if _arr_arraylib(array) == 'numpy':
    # `cast` is only reachable through the `typing` module import at the top of the file.
    converted = _make_array(typing.cast(_ArrayLike, array), arraylib)
    return _arr_numpy(_original_resize(converted, *args, **kwargs))
  return _original_resize(array, *args, **kwargs)
2889
2890
@functools.cache
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Return `resize` compiled with `jax.jit`, built lazily on first call and cached thereafter.

  The second positional argument (`shape`) and the keyword-only parameters that have defaults
  (i.e., the keys of `resize.__kwdefaults__`) are declared static for the compilation.
  """
  import jax

  static_names = list(_original_resize.__kwdefaults__)
  compiled: Any = jax.jit(_original_resize, static_argnums=(1,), static_argnames=static_names)
  return compiled
2900
2901
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize`, but through a (lazily created, cached) `jax.jit`-compiled version of it."""
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable
2905
2906
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved and the content is centered.  The effect is similar to CSS `object-fit` (with value
  `contain` or `cover`).  The parameter `boundary` (whose default is changed to `'natural'`)
  determines the values assigned outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, or if `shape` is incompatible with `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension scale that would stretch the source domain exactly onto the destination domain.
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # A single uniform scale: the smallest fits inside ('contain'), the largest fills ('cover').
  scale0 = {'contain': min(raw_scales), 'cover': max(raw_scales)}[object_fit]
  scale2 = scale0 / np.array(raw_scales)
  # Center the scaled content within the unit destination domain.
  translate = (1.0 - scale2) / 2
  # The gridtypes must be forwarded so that `resize` uses the same grids that determined the
  # scale factors above (otherwise it would silently fall back to its default gridtype).
  return resize(
      array,
      shape,
      gridtype=gridtype,
      src_gridtype=src_gridtype,
      dst_gridtype=dst_gridtype,
      boundary=boundary,
      scale=scale2,
      translate=translate,
      **kwargs,
  )
2980
2981
# Sentinel passed as `max_block_size` when `resample` re-invokes itself on each partitioned block
# of `coords`.  In that recursive call the one-time preprocessing (source gamma decoding and digital
# prefiltering of `array`) is skipped because it was already applied, and since the value is
# negative, no further partitioning into blocks occurs.
_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.
2983
2984
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.  (Currently `resample` does not
      detect downsampling, so `'auto'` always resolves to `'reflect'`.)
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  (Prefiltering is not
      yet implemented in `resample`, so this parameter is currently only validated.)
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
      (Currently it is only checked for broadcastability and is not otherwise used.)
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, `coords` is not floating, `coords` has more coordinate
      dimensions than `array`, `max_block_size` is negative, or a `'trapezoid'` filter is used.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  # Validate with an exception rather than `assert` so the check survives `python -O`.
  if max_block_size < 0 and max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    raise ValueError(f'{max_block_size=} must be nonnegative.')
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # One-time preprocessing of the source samples; skipped within the recursive per-block calls.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          # The Jacobian was already validated above and is not otherwise used yet.  A per-point
          # Jacobian has the full `resampled_shape` and would not broadcast onto a block's shape,
          # so it is not forwarded.  (TODO: split it into blocks once it is used.)
          jacobian=None,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # NOTE(review): reportedly, starting in numpy 2.0, np.einsum() may output np.float64 even with
  # all np.float32 inputs; if so, consider explicitly adding the parameter `dtype=precision`.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The total weight of the in-domain samples is the product (over the *grid* dimensions, one
    # entry per element of `weight`) of the per-dimension weight sums; the remainder goes to cval.
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array
3296
3297
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  The `matrix` transformation can be linear,
    `source_point = matrix @ destination_point`,
  or it can be affine where the last matrix column is an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or has a number
      of columns other than `len(shape)` or `len(shape) + 1`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  has_offset_column = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not has_offset_column and num_columns != dst_ndim:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  del src_gridtype, dst_gridtype, filter, prefilter
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain position of every destination sample, one 1D array per destination axis.
  axis_positions = [
      dst_gridtype2[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part (contracting over the destination axes), then the optional offset.
  linear_part = matrix[:, :-1] if has_offset_column else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_position, 1), 0, -1)
  if has_offset_column:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )
3410
3411
def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Emulate `resize` by calling the more general `resample_affine`; intended as a debug tool.

  Args:
    array: Source grid samples; a tuple or list is first converted with `np.asarray`.
    shape: Output resolution of the leading `len(shape)` dimensions.
    scale: Per-dimension scale factor, broadcast onto `len(shape)` dimensions.
    translate: Per-dimension translation, broadcast onto `len(shape)` dimensions.
    filter: Reconstruction filter(s), as in `resize`.
    fallback: If True, delegate to `_original_resize` whenever the operation involves
      minification or a `'trapezoid'` filter, neither of which `resample` supports properly.
    **kwargs: Additional parameters forwarded to the chosen implementation.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  ndim = len(shape)
  scale = np.broadcast_to(scale, ndim)
  translate = np.broadcast_to(translate, ndim)
  # TODO: let resample() do prefiltering for proper downsampling.
  shrinks_grid = np.any(np.array(shape) < array.shape[:ndim])
  has_minification = shrinks_grid or np.any(scale < 1.0)
  filters = [_get_filter(f) for f in np.broadcast_to(np.array(filter), ndim)]
  uses_trapezoid = any(f.name == 'trapezoid' for f in filters)
  if fallback and (has_minification or uses_trapezoid):
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)
  # A destination point `d` corresponds to the source point `(d - translate) / scale`, i.e. a
  # diagonal linear part `1 / scale` plus the offset column `-translate / scale`.
  matrix = np.column_stack([np.diag(1.0 / scale), -translate / scale])
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)
3438
3439
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping destination into a source unit domain.

  The returned matrix accounts for the possibly non-square domain shapes.  It is composed (in
  homogeneous coordinates) of: recentering the destination about the origin, scaling it to an
  aspect-ratio-corrected frame, rotating, undoing the source aspect-ratio correction, and moving
  the origin back to the domain center.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 homogeneous matrix whose last row is `[0, 0, 1]`.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.identity(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag([*factors, 1.0])

  def homogeneous_rotation(theta: float) -> _NDArray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])

  center = np.full(2, 0.5)
  to_source_center = homogeneous_translation(center)
  unstretch_source = homogeneous_scaling(min(src_shape) / src_shape)
  rotate = homogeneous_rotation(angle)
  stretch_destination = homogeneous_scaling(scale * new_shape / min(new_shape))
  from_destination_center = homogeneous_translation(-center)
  matrix = to_source_center @ unstretch_source @ rotate @ stretch_destination @ from_destination_center
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix
3486
3487
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The resampled image, whose first two dimensions have resolution `new_shape`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # Recompute the matrix from the *current* image shape: after the first pass the source has
    # resolution `new_shape`, whose aspect ratio may differ from that of the original image.
    matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)  # Drop the [0, 0, 1] row.
  return image
3515
3516
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Floating-point samples with 1 to 3 dimensions; a 3D array is resized channel by channel.
    shape: Output resolution; it has 2 entries for 2D/3D arrays and 1 entry for 1D arrays.
    filter: One of the filter names in the local `filters` table below.
    boundary: Must be `'natural'`.
    cval: Ignored.

  Returns:
    The resized array, with the same dtype as `array`.

  Raises:
    ValueError: If `boundary` is not `'natural'` or `filter` is unsupported.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat a 1D signal as a single-row image.
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]
  # Older Pillow versions lack the `PIL.Image.Resampling` namespace and expose the filter constants
  # directly on `PIL.Image`.  Resolve this locally rather than monkeypatching the shared module.
  resampling = getattr(PIL.Image, 'Resampling', PIL.Image)
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  pil_size = shape[::-1]  # PIL expects (width, height).
  if array.ndim == 2:
    return np.array(PIL.Image.fromarray(array).resize(pil_size, resample=pil_resample), array.dtype)
  stack = []
  for channel in np.moveaxis(array, -1, 0):
    pil_image = PIL.Image.fromarray(channel).resize(pil_size, resample=pil_resample)
    stack.append(np.array(pil_image, array.dtype))
  return np.dstack(stack)
3561
3562
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  Args:
    array: Samples with 1 to 3 dimensions (a 3D array has channels in the last dimension).
    shape: Output resolution; it has 2 entries for 2D/3D arrays and 1 entry for 1D arrays.
    filter: One of the filter names in the local `filters` table below.
    boundary: Must be `'clamp'`.
    cval: Ignored.

  Returns:
    The resized array; a trailing singleton channel dimension is preserved.

  Raises:
    ValueError: If `boundary` is not `'clamp'` or `filter` is unsupported.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    # Treat a 1D signal as a single-row image.
    return cv_resize(array[None], (1, *shape), filter=filter)[0]
  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  dsize = (shape[1], shape[0])  # OpenCV expects (width, height).
  result = cv.resize(array, dsize, interpolation=filters[filter])
  lost_channel_axis = array.ndim == 3 and result.ndim == 2
  if not lost_channel_axis:
    return result
  assert array.shape[2] == 1
  return result[..., None]  # cv.resize() drops a singleton last dimension; restore it.
3599
3600
3601def scipy_ndimage_resize(
3602    array: _ArrayLike,
3603    /,
3604    shape: Iterable[int],
3605    *,
3606    filter: str,
3607    boundary: str = 'reflect',
3608    cval: float = 0.0,
3609    scale: float | Iterable[float] = 1.0,
3610    translate: float | Iterable[float] = 0.0,
3611) -> _NDArray:
3612  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
3613  array = np.asarray(array)
3614  shape = tuple(shape)
3615  assert 1 <= len(shape) <= array.ndim
3616  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
3617  if filter not in filters:
3618    raise ValueError(f'{filter=} not in {filters=}.')
3619  order = filters[filter]
3620  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
3621  if boundary not in boundaries:
3622    raise ValueError(f'{boundary=} not in {boundaries=}.')
3623  mode = boundaries[boundary]
3624  shape_all = shape + array.shape[len(shape) :]
3625  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
3626  coords[..., : len(shape)] = (
3627      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
3628  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
3629  coords = np.moveaxis(coords, -1, 0)
3630  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)
3631
3632
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` via `skimage.transform.resize`, accepting the same parameters as `resize`.

  Filters map to spline orders ('box' -> 0, 'triangle' -> 1, 'cardinalN' -> N for N in 2..5),
  and boundary rules map to `skimage` mode names.  Only the leading `len(shape)` axes are resized;
  trailing axes keep their size.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported.
  """
  import skimage.transform

  src = np.asarray(array)
  dst_shape = tuple(shape)
  assert 1 <= len(dst_shape) <= src.ndim
  filters = {'box': 0, 'triangle': 1}
  filters.update({f'cardinal{i}': i for i in range(2, 6)})
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  output_shape = dst_shape + src.shape[len(dst_shape) :]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      src, output_shape, order=filters[filter], mode=boundaries[boundary], cval=cval, clip=False
  )  # type: ignore[no-untyped-call]
3662
3663
# Maps `resampler` filter names to the `method` argument of `tf.image.resize`.
# (GaussianFilter(0.5) could map to 'gaussian', but its radius_4 > desired_radius_3.)
_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER: dict[str, str] = {
    'impulse': 'nearest',
    'trapezoid': 'area',
    'triangle': 'bilinear',
    'mitchell': 'mitchellcubic',
    'cubic': 'bicubic',
    'lanczos3': 'lanczos3',
    'lanczos5': 'lanczos5',
}
3674
3675
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Resize `array` with `tf.image.resize`, accepting the same parameters as `resize`.

  Args:
    array: Source grid with 1 to 3 dimensions; in the 3D case the last axis holds the channels.
    shape: Output sizes of the leading dimensions (1 entry for 1D input, else 2 entries).
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'.
    cval: Ignored; accepted only for signature compatibility with `resize`.
    antialias: Forwarded to `tf.image.resize`.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  rank = len(tensor.shape)
  del array
  assert 1 <= rank <= 3
  dst_shape = tuple(shape)
  _check_eq(len(dst_shape), 2 if rank >= 2 else 1)
  if rank == 1:
    # Treat the vector as a single-row image.
    return tf_image_resize(tensor[None], (1, *dst_shape), filter=filter, antialias=antialias)[0]
  if rank == 2:
    # tf.image.resize expects a trailing channel axis; add a singleton one and remove it after.
    return tf_image_resize(
        tensor[..., None], dst_shape, filter=filter, antialias=antialias
    )[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, dst_shape, method=method, antialias=antialias)
3708
3709
# Maps `resampler` filter names to the `mode` argument of `torch.nn.functional.interpolate`.
_TORCH_INTERPOLATE_MODE_FROM_FILTER: dict[str, str] = {
    # 'nearest-exact' is chosen because plain 'nearest' matches buggy OpenCV's INTER_NEAREST.
    'impulse': 'nearest-exact',
    'trapezoid': 'area',
    'triangle': 'bilinear',
    'sharpcubic': 'bicubic',
}
3716
3717
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Resize `array` with `torch.nn.functional.interpolate`, using the parameters of `resize`.

  Args:
    array: Source grid with 1 to 3 dimensions; in the 3D case the last axis holds the channels.
    shape: Output sizes of the leading dimensions (1 entry for 1D input, else 2 entries).
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'.
    cval: Ignored; accepted only for signature compatibility with `resize`.
    antialias: Forwarded to `interpolate`.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  tensor = torch.as_tensor(array)
  del array
  assert 1 <= tensor.ndim <= 3
  size = tuple(shape)
  _check_eq(len(size), 2 if tensor.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  # `interpolate` operates on (batch, channel, *spatial) tensors: add the leading axes, and move a
  # trailing channel axis to the front.  A vector is resized as a single-row image.
  if tensor.ndim == 1:
    batched, size = tensor[None, None, None], (1, *size)
  elif tensor.ndim == 2:
    batched = tensor[None, None]
  else:
    batched = tensor.moveaxis(2, 0)[None]

  # For upsampling, BILINEAR antialias is same PSNR and slower,
  #  and BICUBIC antialias is worse PSNR and faster.
  # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
  # Default align_corners=None corresponds to False which is what we desire.
  resized = torch.nn.functional.interpolate(batched, size, mode=mode, antialias=antialias)[0]

  # Undo the axis bookkeeping performed above.
  if tensor.ndim == 1:
    return resized[0, 0]
  if tensor.ndim == 2:
    return resized[0]
  return resized.moveaxis(0, 2)
3758
3759
3760def jax_image_resize(
3761    array: _ArrayLike,
3762    /,
3763    shape: Iterable[int],
3764    *,
3765    filter: str,
3766    boundary: str = 'natural',
3767    cval: float = 0.0,
3768    scale: float | Iterable[float] = 1.0,
3769    translate: float | Iterable[float] = 0.0,
3770) -> _JaxArray:
3771  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`."""
3772  import jax.image
3773  import jax.numpy as jnp
3774
3775  filters = 'triangle cubic lanczos3 lanczos5'.split()
3776  if filter not in filters:
3777    raise ValueError(f'{filter=} not in {filters=}.')
3778  if boundary != 'natural':
3779    raise ValueError(f"{boundary=} must equal 'natural'.")
3780  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
3781  # To be consistent, the parameter `cval` must be zero.
3782  if scale != 1.0 and cval != 0.0:
3783    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
3784  if translate != 0.0 and cval != 0.0:
3785    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
3786  array2 = jnp.asarray(array)
3787  del array
3788  shape = tuple(shape)
3789  assert len(shape) <= array2.ndim
3790  completed_shape = shape + (1,) * (array2.ndim - len(shape))
3791  spatial_dims = list(range(len(shape)))
3792  scale2 = np.broadcast_to(np.array(scale), len(shape))
3793  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
3794  translate2 = np.broadcast_to(np.array(translate), len(shape))
3795  translate2 = translate2 * np.array(shape)
3796  return jax.image.scale_and_translate(
3797      array2, completed_shape, spatial_dims, scale2, translate2, filter
3798  )
3799
3800
# Adapters that invoke each library's resizer using the `resize` parameter conventions, keyed by
# a 'library.function' label whose first dotted component identifies the library.
_CANDIDATE_RESIZERS: dict[str, Callable[..., Any]] = {
    'resampler.resize': resize,
    'PIL.Image.resize': pil_image_resize,
    'cv.resize': cv_resize,
    'scipy.ndimage.map_coordinates': scipy_ndimage_resize,
    'skimage.transform.resize': skimage_transform_resize,
    'tf.image.resize': tf_image_resize,
    'torch.nn.functional.interpolate': torch_nn_resize,
    'jax.image.scale_and_translate': jax_image_resize,
}
3811
3812
def _resizer_is_available(library_function: str) -> bool:
  """Return whether the resizer is available as an installed package.

  Args:
    library_function: A label such as 'PIL.Image.resize' or 'tf.image.resize'; only its first
      dotted component (the library) is examined.

  Returns:
    True if the library's top-level module can be located (without importing it).
  """
  # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
  import importlib.util

  top_name = library_function.split('.', 1)[0]
  # Map conventional import aliases to actual top-level module names.  The Pillow distribution
  # installs the module `PIL`, so that name is looked up unchanged.
  module = {'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
  return importlib.util.find_spec(module) is not None
3818
3819
# The subset of `_CANDIDATE_RESIZERS` whose library is installed, in the same order.
_RESIZERS: dict[str, Callable[..., Any]] = {
    name: _CANDIDATE_RESIZERS[name]
    for name in _CANDIDATE_RESIZERS
    if _resizer_is_available(name)
}
3825
3826
def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`.

  The generic categories 'box_like', 'cubic_like', and 'high_quality' are translated into a
  concrete filter name, using a per-resizer override when one exists and otherwise a common
  fallback.  Any other `filter` value is returned unchanged.
  """
  if filter == 'box_like':
    fallback = 'box'
    overrides = {
        cv_resize: 'trapezoid',
        skimage_transform_resize: 'box',
        tf_image_resize: 'trapezoid',
        torch_nn_resize: 'trapezoid',
    }
  elif filter == 'cubic_like':
    fallback = 'cubic'
    overrides = {
        cv_resize: 'sharpcubic',
        scipy_ndimage_resize: 'cardinal3',
        skimage_transform_resize: 'cardinal3',
        torch_nn_resize: 'sharpcubic',
    }
  elif filter == 'high_quality':
    fallback = 'lanczos5'
    overrides = {
        pil_image_resize: 'lanczos3',
        cv_resize: 'lanczos4',
        scipy_ndimage_resize: 'cardinal5',
        skimage_transform_resize: 'cardinal5',
        torch_nn_resize: 'sharpcubic',
    }
  else:
    return filter
  return overrides.get(resizer, fallback)
3854
3855
3856# For Emacs:
3857# Local Variables:
3858# fill-column: 100
3859# End:
using_numba = False
def is_available(arraylib: str) -> bool:
  """Return whether the array library (e.g. 'tensorflow') is available as an installed package.

  Args:
    arraylib: Top-level module name of the library.

  Returns:
    True if the module can be located by the import system.
  """
  # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
  import importlib.util

  return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.

Return whether the array library (e.g. 'tensorflow') is available as an installed package.

ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the array function parameter.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.

@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  A grid-type defines where the samples of a grid lie within the unit domain [0.0, 1.0] and how
  out-of-range integer sample indices are mapped back to interior ones.  In resampling
  operations, it may be specified separately for the source domain (`src_gridtype`) and for the
  destination domain (`dst_gridtype`), and also per domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples that this grid-type supports."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of the domain measured in units of inter-sample spacing."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Convert sample indices in [0, size - 1] into coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Convert coordinates in [0.0, 1.0] into a continuous location x, where x == 0.0 is the
    first grid sample and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by reflecting about the boundaries."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones by periodic wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using the reflect-clamp rule."""

Abstract base class for grid-types such as 'dual' and 'primal'.

In resampling operations, the grid-type may be specified separately as src_gridtype for the source domain and dst_gridtype for the destination domain. Moreover, the grid-type may be specified per domain dimension.

Examples:

resize(source, shape, gridtype='primal') # Sets both src and dst to be 'primal' grids.

resize(source, shape, src_gridtype=['dual', 'primal'], dst_gridtype='dual') # Source is 'dual' in dim0 and 'primal' in dim1.

GRIDTYPES: list[str] = ['dual', 'primal']

Shortcut names for the two predefined grid types (specified per dimension):

| gridtype | `'dual'`<br>`DualGridtype()`<br>(default) | `'primal'`<br>`PrimalGridtype()` |
|---|---|---|
| Sample positions in 2D<br>and in 1D at different resolutions | Dual | Primal |
| Nesting of samples across resolutions | The sample positions do not nest. | The even samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules.

  These define the reconstruction over the source domain near and beyond the domain boundaries,
  composed of three parts: a coordinate remapping, an extension of the grid samples beyond the
  domain, and an override of exterior values with `cval`.  The rules may be specified separately
  for each domain dimension.
  """

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """Whether the weights may be non-affine, i.e., may involve the constant value `cval`."""
    return self.extend_samples.uses_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return `point` remapped by `coord_remap`, prior to evaluating the filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples.

    The sample extension produces new `(index, weight)` arrays; the exterior-value override then
    modifies those weights in place before they are returned.
    """
    extended_index, extended_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(extended_weight, point)
    return extended_index, extended_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, set `weight` to zero (in place) so that `cval` is assigned."""
    self.override_value(weight, point)

Domain boundary rules. These define the reconstruction over the source domain near and beyond the domain boundaries. The rules may be specified separately for each domain dimension.

BOUNDARIES: list[str] = ['reflect', 'wrap', 'tile', 'clamp', 'border', 'natural', 'linear_constant', 'quadratic_constant', 'reflect_clamp', 'constant', 'linear', 'quadratic']

Shortcut names for some predefined boundary rules (as defined by _DICT_BOUNDARIES):

| name | a.k.a. / comments |
|---|---|
| `'reflect'` | reflected, symm, symmetric, mirror, grid-mirror |
| `'wrap'` | periodic, repeat, grid-wrap |
| `'tile'` | like `'reflect'` within unit domain, then tile discontinuously |
| `'clamp'` | clamped, nearest, edge, clamp-to-edge, repeat last sample |
| `'border'` | grid-constant, use `cval` for samples outside unit domain |
| `'natural'` | renormalize using only interior samples, use `cval` outside domain |
| `'reflect_clamp'` | mirror-clamp-to-edge |
| `'constant'` | like `'reflect'` but replace by `cval` outside unit domain |
| `'linear'` | extrapolate from the last 2 samples |
| `'quadratic'` | extrapolate from the last 3 samples |
| `'linear_constant'` | like `'linear'` but replace by `cval` outside unit domain |
| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |

These boundary rules may be specified per dimension. See the source code for extensibility using the classes RemapCoordinates, ExtendSamples, and OverrideExteriorValue.

Boundary rules illustrated in 1D:

Boundary rules illustrated in 2D:

@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Every kernel is taken to be zero-phase, i.e., symmetric over the support interval
  [-radius, radius].  (Some sites instead define kernels over the interval [0, N], where
  N = 2 * radius.)

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the filter kernel at the locations `x`."""

Abstract base class for filter kernel functions.

Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support interval [-radius, radius]. (Some sites instead define kernels over the interval [0, N] where N = 2 * radius.)

Portions of this code are adapted from the C++ library in https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

See also https://hhoppe.com/proj/filtering/.

class ImpulseFilter(Filter):
  """See https://en.wikipedia.org/wiki/Dirac_delta_function.

  The support radius is set to a vanishingly small value (1e-20), and the kernel cannot be
  evaluated pointwise.
  """

  def __init__(self) -> None:
    super().__init__(name='impulse', partition_of_unity=False, continuous=False, radius=1e-20)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    del x  # The impulse has no finite pointwise value.
    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BoxFilter(Filter):
  """See https://en.wikipedia.org/wiki/Box_function.

  The kernel function has value 1.0 over the half-open interval [-.5, .5).
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    # The half-open support makes unit-spaced shifted copies sum to exactly 1 everywhere.
    # (A symmetric alternative would instead assign the value 0.5 at abs(x) == 0.5.)
    x = np.asarray(x)
    inside = (x >= -0.5) & (x < 0.5)
    return np.where(inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  This filter is similar to the BoxFilter but with linearly sloped sides.  It has value 1.0
  in the interval abs(x) <= 1.0 - radius and decreases linearly to value 0.0 in the interval
  1.0 - radius <= abs(x) <= radius, always with value 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is None:
      # Placeholder with radius 0.0; `__call__` refuses to evaluate it.
      super().__init__(name='trapezoid', radius=0.0)
    elif 0.5 < radius <= 1.0:
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Linear ramp passing through (0.5, 0.5), clipped to [0.0, 1.0].
    offset = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (offset - slope * x).clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TriangleFilter(Filter):
  """See https://en.wikipedia.org/wiki/Triangle_function.

  Also called the hat or tent function; it yields piecewise-linear (bilinear, trilinear, ...)
  interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    return np.clip(1.0 - distance, 0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CubicFilter(Filter):
  """Family of cubic filters parameterized by two scalar parameters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: Optional filter name; defaults to f'cubic_b{b}_c{c}'.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b, self.c = b, c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    b, c = self.b, self.c
    # Polynomial coefficients for the inner piece, abs(x) < 1 (it has no linear term).
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Polynomial coefficients for the outer piece, 1 <= abs(x) < 2.
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # Horner evaluation (np.polynomial.polynomial.polyval is almost twice as slow; see also
    # https://stackoverflow.com/questions/24065904).
    inner = ((f3 * x + f2) * x) * x + f0
    outer = ((g3 * x + g2) * x + g1) * x + g0
    beyond_inner = np.where(x < 2.0, outer, 0.0)
    return np.where(x < 1.0, inner, beyond_inner)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CatmullRomFilter(CubicFilter):
  """Interpolating cubic filter (b=0, c=0.5) with cubic precision; also known as Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(name='cubic', b=0, c=0.5)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class MitchellFilter(CubicFilter):
  """Cubic filter with b = c = 1/3; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class SharpCubicFilter(CubicFilter):
  """Interpolating cubic filter (b=0, c=0.75) that is sharper than the Catmull-Rom filter.

  Used by some tools including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(name='sharpcubic', b=0, c=0.75)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class LanczosFilter(Filter):
  """High-quality filter: sinc function modulated by a sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      dist = np.abs(x)
      # With window[n] = sinc(2*n/N - 1) for 0 <= n <= N, and n = x + N/2, the zero-phase
      # window is w_0(x) = sinc(x / radius) for -N/2 <= x <= N/2.
      lanczos_window = _sinc(dist / radius)
      return np.where(dist < radius, _sinc(dist) * lanczos_window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  Raises:
    ValueError: If `a0` is outside the open interval (0.0, 1.0).

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    # Validate with an exception rather than `assert` so that the check survives `python -O`,
    # consistent with the ValueError raised by TrapezoidFilter and BsplineFilter.
    if not 0.0 < a0 < 1.0:
      raise ValueError(f'Parameter a0={a0} is outside the range (0.0, 1.0).')
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * x)
    return np.where(x < self.radius, _sinc(x) * window, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level; must be
      non-negative.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `beta` is negative (or NaN).
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    # Validate with an exception rather than `assert` so that the check survives `python -O`.
    if not beta >= 0.0:
      raise ValueError(f'Kaiser beta={beta} must be non-negative.')
    super().__init__(
        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    @_cache_sampled_1d_function(xmin=-math.ceil(radius), xmax=math.ceil(radius), enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      window = np.i0(beta * np.sqrt((1.0 - np.square(x / radius)).clip(0.0, 1.0))) / np.i0(beta)
      # The small tolerance keeps the kernel nonzero at (numerically) x == radius.
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BsplineFilter(Filter):
1873class BsplineFilter(Filter):
1874  """B-spline of a non-negative degree.
1875
1876  Args:
1877    degree: The polynomial degree of the B-spline segments.
1878      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
1879      With `degree=1`, it is identical to `TriangleFilter`.
1880      With `degree >= 2`, it is no longer interpolating.
1881
1882  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
1883  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
1884  """
1885
1886  def __init__(self, *, degree: int) -> None:
1887    if degree < 0:
1888      raise ValueError(f'Bspline of degree {degree} is invalid.')
1889    radius = (degree + 1) / 2
1890    interpolating = degree <= 1
1891    super().__init__(name=f'bspline{degree}', radius=radius, interpolating=interpolating)
1892    t = list(range(degree + 2))
1893    self._bspline = scipy.interpolate.BSpline.basis_element(t)
1894
1895  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1896    x = np.abs(x)
1897    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CardinalBsplineFilter(Filter):
1900class CardinalBsplineFilter(Filter):
1901  """Interpolating B-spline, achieved with aid of digital pre or post filter.
1902
1903  Args:
1904    degree: The polynomial degree of the B-spline segments.
1905    sampled: If True, use a discretized approximation for improved speed.
1906
1907  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
1908  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
1909  1991].
1910  """
1911
1912  def __init__(self, *, degree: int, sampled: bool = True) -> None:
1913    self.degree = degree
1914    if degree < 0:
1915      raise ValueError(f'Bspline of degree {degree} is invalid.')
1916    radius = (degree + 1) / 2
1917    super().__init__(
1918        name=f'cardinal{degree}',
1919        radius=radius,
1920        requires_digital_filter=degree >= 2,
1921        continuous=degree >= 1,
1922    )
1923    t = list(range(degree + 2))
1924    bspline = scipy.interpolate.BSpline.basis_element(t)
1925
1926    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
1927    def _eval(x: _ArrayLike) -> _NDArray:
1928      x = np.abs(x)
1929      return np.where(x < radius, bspline(x + radius), 0.0)
1930
1931    self._function = _eval
1932
1933  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1934    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class OmomsFilter(Filter):
1937class OmomsFilter(Filter):
1938  """OMOMS interpolating filter, with aid of digital pre or post filter.
1939
1940  Args:
1941    degree: The polynomial degree of the filter segments.
1942
1943  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
1944  interpolation of minimal support, 2001].
1945  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
1946  """
1947
1948  def __init__(self, *, degree: int) -> None:
1949    if degree not in (3, 5):
1950      raise ValueError(f'Degree {degree} not supported.')
1951    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
1952    self.degree = degree
1953
1954  def __call__(self, x: _ArrayLike, /) -> _NDArray:
1955    x = np.abs(x)
1956    match self.degree:
1957      case 3:
1958        v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
1959        v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
1960        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
1961      case 5:
1962        v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
1963        v12 = (
1964            (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
1965        ) * x + 839 / 2640
1966        v23 = (
1967            (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
1968        ) * x + 5707 / 2640
1969        return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
1970      case _:
1971        raise ValueError(self.degree)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GaussianFilter(Filter):
1974class GaussianFilter(Filter):
1975  r"""See https://en.wikipedia.org/wiki/Gaussian_function.
1976
1977  Args:
1978    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
1979      creates a kernel that is as-close-as-possible to a partition of unity.
1980  """
1981
1982  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
1983  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
1984  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
1985  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
1986  resampling tasks, 1990] for kernels with a support of 3 pixels.
1987  https://cadxfem.org/inf/ResamplingFilters.pdf
1988  """
1989
1990  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
1991    super().__init__(
1992        name=f'gaussian_{standard_deviation:.3f}',
1993        radius=np.ceil(8.0 * standard_deviation),  # Sufficiently large.
1994        interpolating=False,
1995        partition_of_unity=False,
1996    )
1997    self.standard_deviation = standard_deviation
1998
1999  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2000    x = np.abs(x)
2001    sdv = self.standard_deviation
2002    v0r = np.exp(np.square(x / sdv) / -2.0) / (np.sqrt(math.tau) * sdv)
2003    return np.where(x < self.radius, v0r, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class NarrowBoxFilter(Filter):
2006class NarrowBoxFilter(Filter):
2007  """Compact footprint, used for visualization of grid sample location.
2008
2009  Args:
2010    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
2011      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
2012  """
2013
2014  def __init__(self, *, radius: float = 0.199) -> None:
2015    super().__init__(
2016        name='narrowbox',
2017        radius=radius,
2018        continuous=False,
2019        unit_integral=False,
2020        partition_of_unity=False,
2021    )
2022
2023  def __call__(self, x: _ArrayLike, /) -> _NDArray:
2024    radius = self.radius
2025    magnitude = 1.0
2026    x = np.asarray(x)
2027    return np.where((-radius <= x) & (x < radius), magnitude, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

FILTERS: list[str] = ['impulse', 'box', 'trapezoid', 'triangle', 'cubic', 'sharpcubic', 'lanczos3', 'lanczos5', 'lanczos10', 'cardinal3', 'cardinal5', 'omoms3', 'omoms5', 'hamming3', 'kaiser3', 'gaussian', 'bspline3', 'mitchell', 'narrowbox']

Shortcut names for some predefined filter kernels (specified per dimension). The names expand to:

name Filter a.k.a. / comments
'impulse' ImpulseFilter() nearest
'box' BoxFilter() non-antialiased box, e.g. ImageMagick
'trapezoid' TrapezoidFilter() area antialiasing, e.g. cv.INTER_AREA
'triangle' TriangleFilter() linear (bilinear in 2D), spline order=1
'cubic' CatmullRomFilter() catmullrom, keys, bicubic
'sharpcubic' SharpCubicFilter() cv.INTER_CUBIC, torch 'bicubic'
'lanczos3' LanczosFilter(radius=3) support window [-3, 3]
'lanczos5' LanczosFilter(radius=5) [-5, 5]
'lanczos10' LanczosFilter(radius=10) [-10, 10]
'cardinal3' CardinalBsplineFilter(degree=3) spline interpolation, order=3, GF
'cardinal5' CardinalBsplineFilter(degree=5) spline interpolation, order=5, GF
'omoms3' OmomsFilter(degree=3) non-$C^1$, [-2, 2], GF
'omoms5' OmomsFilter(degree=5) non-$C^1$, [-3, 3], GF
'hamming3' GeneralizedHammingFilter(...) (radius=3, a0=25/46)
'kaiser3' KaiserFilter(radius=3.0, beta=7.12)
'gaussian' GaussianFilter() non-interpolating, default $\sigma=1.25/3$
'bspline3' BsplineFilter(degree=3) non-interpolating
'mitchell' MitchellFilter() mitchellcubic
'narrowbox' NarrowBoxFilter() for visualization of sample positions

The comment label GF denotes a generalized filter, formed as the composition of a finitely supported kernel and a discrete inverse convolution.

Some example filter kernels:


A more extensive set of filters is presented here in the notebook, together with visualizations and analyses of the filter properties. See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
2135@dataclasses.dataclass(frozen=True)
2136class Gamma(abc.ABC):
2137  """Abstract base class for transfer functions on sample values.
2138
2139  Image/video content is often stored using a color component transfer function.
2140  See https://en.wikipedia.org/wiki/Gamma_correction.
2141
2142  Converts between integer types and [0.0, 1.0] internal value range.
2143  """
2144
2145  name: str
2146  """Name of component transfer function."""
2147
2148  @abc.abstractmethod
2149  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2150    """Decode source sample values into floating-point, possibly nonlinearly.
2151
2152    Uint source values are mapped to the range [0.0, 1.0].
2153    """
2154
2155  @abc.abstractmethod
2156  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2157    """Encode float signal into destination samples, possibly nonlinearly.
2158
2159    Uint destination values are mapped from the range [0.0, 1.0].
2160
2161    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
2162    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
2163    """

Abstract base class for transfer functions on sample values.

Image/video content is often stored using a color component transfer function. See https://en.wikipedia.org/wiki/Gamma_correction.

Converts between integer types and [0.0, 1.0] internal value range.

class IdentityGamma(Gamma):
2166class IdentityGamma(Gamma):
2167  """Identity component transfer function."""
2168
2169  def __init__(self) -> None:
2170    super().__init__('identity')
2171
2172  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
2173    dtype = np.dtype(dtype)
2174    assert np.issubdtype(dtype, np.inexact)
2175    if np.issubdtype(_arr_dtype(array), np.unsignedinteger):
2176      return _to_float_01(array, dtype)
2177    return _arr_astype(array, dtype)
2178
2179  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
2180    dtype = np.dtype(dtype)
2181    assert np.issubdtype(dtype, np.number)
2182    if np.issubdtype(dtype, np.unsignedinteger):
2183      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
2184    if np.issubdtype(dtype, np.integer):
2185      return _arr_astype(array + 0.5, dtype)
2186    return _arr_astype(array, dtype)

Gamma(name: 'str')

GAMMAS: list[str] = ['identity', 'power2', 'power22', 'srgb']

Shortcut names for some predefined gamma-correction schemes:

name Gamma Decoding function
(linear space from stored value)
Encoding function
(stored value from linear space)
'identity' IdentityGamma() $l = e$ $e = l$
'power2' PowerGamma(2.0) $l = e^{2.0}$ $e = l^{1/2.0}$
'power22' PowerGamma(2.2) $l = e^{2.2}$ $e = l^{1/2.2}$
'srgb' SrgbGamma() $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ $e = l^{1/2.4} * 1.055 - 0.055$

See the source code for extensibility.

def resize( array: Array, /, shape: Iterable[int], *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, precision: DTypeLike = None, dtype: DTypeLike = None, dim_order: Iterable[int] | None = None, num_threads: Union[int, Literal['auto']] = 'auto') -> Array:
2598def resize(
2599    array: _Array,
2600    /,
2601    shape: Iterable[int],
2602    *,
2603    gridtype: str | Gridtype | None = None,
2604    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
2605    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
2606    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
2607    cval: _ArrayLike = 0.0,
2608    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
2609    prefilter: str | Filter | Iterable[str | Filter] | None = None,
2610    gamma: str | Gamma | None = None,
2611    src_gamma: str | Gamma | None = None,
2612    dst_gamma: str | Gamma | None = None,
2613    scale: float | Iterable[float] = 1.0,
2614    translate: float | Iterable[float] = 0.0,
2615    precision: _DTypeLike = None,
2616    dtype: _DTypeLike = None,
2617    dim_order: Iterable[int] | None = None,
2618    num_threads: int | Literal['auto'] = 'auto',
2619) -> _Array:
2620  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.
2621
2622  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
2623  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
2624  `array.shape[len(shape):]`.
2625
2626  Some examples:
2627
2628  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
2629    produces a new image of scalar values.
2630  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
2631    produces a new image of RGB values.
2632  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
2633    `len(shape) == 3` produces a new 3D grid of Jacobians.
2634
2635  This function also allows scaling and translation from the source domain to the output domain
2636  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.
2637
2638  Args:
2639    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
2640      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
2641      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
2642      at least 2 for a `'primal'` grid.
2643    shape: The number of grid samples in each coordinate dimension of the output array.  The source
2644      `array` must have at least as many dimensions as `len(shape)`.
2645    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
2646      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
2647      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
2648    src_gridtype: Placement of the samples in the source domain grid for each dimension.
2649      Parameters `gridtype` and `src_gridtype` cannot both be set.
2650    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
2651      Parameters `gridtype` and `dst_gridtype` cannot both be set.
2652    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
2653      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
2654      for upsampling and `'clamp'` for downsampling.
2655    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
2656      onto `array.shape[len(shape):]`.  It is subject to `src_gamma`.
2657    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
2658      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
2659    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
2660      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
2661      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
2662      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
2663      because it avoids ringing artifacts.
2664    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
2665      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
2666      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
2667      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
2668      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
2669      range [0.0, 1.0].
2670    src_gamma: Component transfer function used to "decode" `array` samples.
2671      Parameters `gamma` and `src_gamma` cannot both be set.
2672    dst_gamma: Component transfer function used to "encode" the output samples.
2673      Parameters `gamma` and `dst_gamma` cannot both be set.
2674    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
2675      the destination domain.
2676    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
2677      the destination domain.
2678    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
2679      on `array.dtype` and `dtype`.
2680    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
2681      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
2682      to the uint range.
2683    dim_order: Override the automatically selected order in which the grid dimensions are resized.
2684      Must contain a permutation of `range(len(shape))`.
2685    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
2686      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.
2687
2688  Returns:
2689    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
2690      and data type `dtype`.
2691
2692  **Example of image upsampling:**
2693
2694  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
2695  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.
2696
2697  <center>
2698  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
2699  </center>
2700
2701  **Example of image downsampling:**
2702
2703  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
2704  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
2705  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
2706  >>> downsampled = resize(array, (24, 48))
2707
2708  <center>
2709  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
2710  </center>
2711
2712  **Unit test:**
2713
2714  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
2715  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
2716  """
2717  if isinstance(array, (tuple, list)):
2718    array = np.asarray(array)
2719  arraylib = _arr_arraylib(array)
2720  array_dtype = _arr_dtype(array)
2721  if not np.issubdtype(array_dtype, np.number):
2722    raise ValueError(f'Type {array.dtype} is not numeric.')
2723  shape2 = tuple(shape)
2724  array_ndim = len(array.shape)
2725  if not 0 < len(shape2) <= array_ndim:
2726    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
2727  src_shape = array.shape[: len(shape2)]
2728  src_gridtype2, dst_gridtype2 = _get_gridtypes(
2729      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
2730  )
2731  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
2732  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
2733  prefilter = filter if prefilter is None else prefilter
2734  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
2735  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
2736  dtype = array_dtype if dtype is None else np.dtype(dtype)
2737  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
2738  scale2 = np.broadcast_to(np.array(scale), len(shape2))
2739  translate2 = np.broadcast_to(np.array(translate), len(shape2))
2740  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
2741  del (src_gamma, dst_gamma, scale, translate)
2742  precision = _get_precision(precision, [array_dtype, dtype], [])
2743  weight_precision = _real_precision(precision)
2744
2745  is_noop = (
2746      all(src == dst for src, dst in zip(src_shape, shape2))
2747      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
2748      and all(f.interpolating for f in prefilter2)
2749      and np.all(scale2 == 1.0)
2750      and np.all(translate2 == 0.0)
2751      and src_gamma2 == dst_gamma2
2752  )
2753  if is_noop:
2754    return array
2755
2756  if dim_order is None:
2757    dim_order = _arr_best_dims_order_for_resize(array, shape2)
2758  else:
2759    dim_order = tuple(dim_order)
2760    if sorted(dim_order) != list(range(len(shape2))):
2761      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')
2762
2763  array = src_gamma2.decode(array, precision)
2764  cval = _arr_numpy(src_gamma2.decode(cval, precision))
2765
2766  can_use_fast_box_downsampling = (
2767      using_numba
2768      and arraylib == 'numpy'
2769      and len(shape2) == 2
2770      and array_ndim in (2, 3)
2771      and all(src > dst for src, dst in zip(src_shape, shape2))
2772      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
2773      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
2774      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
2775      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
2776      and np.all(scale2 == 1.0)
2777      and np.all(translate2 == 0.0)
2778  )
2779  if can_use_fast_box_downsampling:
2780    assert isinstance(array, np.ndarray)  # Help mypy.
2781    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
2782    array = dst_gamma2.encode(array, dtype)
2783    return array
2784
2785  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
2786  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
2787  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
2788  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
2789  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
2790  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.
2791
2792  for dim in dim_order:
2793    skip_resize_on_this_dim = (
2794        shape2[dim] == array.shape[dim]
2795        and scale2[dim] == 1.0
2796        and translate2[dim] == 0.0
2797        and filter2[dim].interpolating
2798    )
2799    if skip_resize_on_this_dim:
2800      continue
2801
2802    def get_is_minification() -> bool:
2803      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
2804      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
2805      return dst_in_samples / src_in_samples * scale2[dim] < 1.0
2806
2807    is_minification = get_is_minification()
2808    boundary_dim = boundary2[dim]
2809    if boundary_dim == 'auto':
2810      boundary_dim = 'clamp' if is_minification else 'reflect'
2811    boundary_dim = _get_boundary(boundary_dim)
2812    resize_matrix, cval_weight = _create_resize_matrix(
2813        array.shape[dim],
2814        shape2[dim],
2815        src_gridtype=src_gridtype2[dim],
2816        dst_gridtype=dst_gridtype2[dim],
2817        boundary=boundary_dim,
2818        filter=filter2[dim],
2819        prefilter=prefilter2[dim],
2820        scale=scale2[dim],
2821        translate=translate2[dim],
2822        dtype=weight_precision,
2823        arraylib=arraylib,
2824    )
2825
2826    array_dim: _Array = _arr_moveaxis(array, dim, 0)
2827    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
2828    array_flat = _arr_possibly_make_contiguous(array_flat)
2829    if not is_minification and filter2[dim].requires_digital_filter:
2830      array_flat = _apply_digital_filter_1d(
2831          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
2832      )
2833
2834    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
2835    if cval_weight is not None:
2836      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
2837      if np.issubdtype(array_dtype, np.complexfloating):
2838        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
2839      array_flat += cval_weight[:, None] * cval_flat
2840
2841    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
2842      array_flat = _apply_digital_filter_1d(
2843          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
2844      )
2845    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
2846    array = _arr_moveaxis(array_dim, 0, dim)
2847
2848  array = dst_gamma2.encode(array, dtype)
2849  return array

Resample array (a grid of sample values) onto a grid with resolution shape.

The source array is any object recognized by ARRAYLIBS. It is interpreted as a grid with len(shape) domain coordinate dimensions, where each grid sample value has shape array.shape[len(shape):].

Some examples:

  • A grayscale image has array.shape = height, width and resizing it with len(shape) == 2 produces a new image of scalar values.
  • An RGB image has array.shape = height, width, 3 and resizing it with len(shape) == 2 produces a new image of RGB values.
  • A 3D grid of 3x3 Jacobians has array.shape = Z, Y, X, 3, 3 and resizing it with len(shape) == 3 produces a new 3D grid of Jacobians.

This function also allows scaling and translation from the source domain to the output domain through the parameters scale and translate. For more general transforms, see resample.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. Its first len(shape) dimensions are the domain coordinate dimensions. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto array.shape[len(shape):]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during upsampling (i.e., magnification).
  • prefilter: The prefilter kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter. The default 'lanczos3' is good for natural images. For vector graphics images, 'trapezoid' is better because it avoids ringing artifacts.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • scale: Scaling factor applied to each dimension of the source domain when it is mapped onto the destination domain.
  • translate: Offset applied to each dimension of the scaled source domain when it is mapped onto the destination domain.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • dim_order: Override the automatically selected order in which the grid dimensions are resized. Must contain a permutation of range(len(shape)).
  • num_threads: Used to determine multithread parallelism if array is from numpy. If set to 'auto', it is selected automatically. Otherwise, it must be a positive integer.
Returns:

An array of the same class as the source array, with shape shape + array.shape[len(shape):] and data type dtype.

Example of image upsampling:

>>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
>>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

Example of image downsampling:

>>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
>>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
>>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
>>> downsampled = resize(array, (24, 48))

Unit test:

>>> result = resize([1.0, 4.0, 5.0], shape=(4,))
>>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
def jaxjit_resize(array: Array, /, *args: Any, **kwargs: Any) -> Array:
2903def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
2904  """Compute `resize` but with resize function jitted using Jax."""
2905  return _create_jaxjit_resize()(array, *args, **kwargs)  # pylint: disable=not-callable

Compute resize but with resize function jitted using Jax.

def uniform_resize( array: Array, /, shape: Iterable[int], *, object_fit: Literal['contain', 'cover'] = 'contain', gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'natural', scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, **kwargs: Any) -> Array:
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a grid with resolution `shape` but with uniform scaling.

  Calls function `resize` with `scale` and `translate` set such that the aspect ratio of `array`
  is preserved.  With `object_fit='contain'` (the default), the effect is similar to CSS
  `object-fit: contain`; with `object_fit='cover'`, it is similar to CSS `object-fit: cover`.
  The parameter `boundary` (whose default is changed to `'natural'`) determines the values assigned
  outside the source domain.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, if `object_fit` is neither `'contain'`
      nor `'cover'`, or if `shape` is empty or has more dimensions than `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  # Validate up front so that an unknown value raises ValueError (like the other argument checks)
  # rather than a bare KeyError from a later lookup.
  if object_fit not in ('contain', 'cover'):
    raise ValueError(f"Parameter {object_fit=} must be 'contain' or 'cover'.")
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape), len(shape)
  )
  # Per-dimension factors that would stretch the source domain exactly onto the destination.
  raw_scales = [
      dst_gridtype2[dim].size_in_samples(shape[dim])
      / src_gridtype2[dim].size_in_samples(array.shape[dim])
      for dim in range(len(shape))
  ]
  # 'contain' keeps the smallest factor (the whole source stays visible); 'cover' keeps the
  # largest (the destination is fully covered).
  scale0 = min(raw_scales) if object_fit == 'contain' else max(raw_scales)
  scale2 = scale0 / np.array(raw_scales)
  # Center the uniformly scaled source domain within the unit destination domain.
  translate = (1.0 - scale2) / 2
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)

Resample array onto a grid with resolution shape but with uniform scaling.

Calls function resize with scale and translate set such that the aspect ratio of array is preserved. With object_fit='contain' (the default), the effect is similar to CSS object-fit: contain; with object_fit='cover', it is similar to CSS object-fit: cover. The parameter boundary (whose default is changed to 'natural') determines the values assigned outside the source domain.

Arguments:
  • array: Regular grid of source sample values.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • object_fit: Like CSS object-fit. If 'contain', array is resized uniformly to fit within shape. If 'cover', array is resized to fully cover shape.
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The default is 'natural', which assigns cval to output points that map outside the source unit domain.
  • scale: Parameter may not be specified.
  • translate: Parameter may not be specified.
  • **kwargs: Additional parameters for resize function (including cval).
Returns:

An array with shape shape + array.shape[len(shape):].

>>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
array([[0., 1., 1., 0.],
       [0., 1., 1., 0.]])
>>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
       [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])
>>> a = np.arange(6.0).reshape(2, 3)
>>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
array([[0.5, 1.5],
       [3.5, 4.5]])
def resample( array: Array, /, coords: ArrayLike, *, gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual', boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, jacobian: Optional[ArrayLike] = None, precision: DTypeLike = None, dtype: DTypeLike = None, max_block_size: int = 40000, debug: bool = False) -> Array:
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.  Because `resample` does not yet
      detect downsampling (minification), `'auto'` currently always resolves to `'reflect'`.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.  It is subject to `src_gamma`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is intended for downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  Prefiltering is not
      yet implemented in `resample`, so this parameter currently does not affect the result.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  It is intended for prefiltering (and, if omitted, to be estimated by finite
      differences on `coords`).  Because prefiltering is not yet implemented, it is currently only
      checked for broadcastability and does not affect the result.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, if `coords` is not floating-point, if `coords` has more
      coordinate dimensions than `array` has dimensions, or if a `'trapezoid'` filter is requested.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # Only the outermost call decodes the source samples and applies any digital prefilters; the
  # recursive per-block calls below receive the already-processed `array` and `cval`.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )
    cval = _arr_numpy(src_gamma2.decode(cval, precision))

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          # `jacobian` was broadcast to the full `resampled_shape`, which cannot be broadcast onto
          # a smaller block of `coords`.  It is not otherwise used (no prefiltering yet), so it is
          # not forwarded.  TODO: split it into blocks alongside `coords` once it is used.
          jacobian=None,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    # First source index of the `num_samples` taps centered around `src_float_index`.
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  # NOTE(review): an earlier observation was that starting in numpy 2.0, np.einsum() may output
  # np.float64 even when all inputs are np.float32 (TODO confirm).  Explicitly passing
  # `dtype=precision` would make the output precision independent of that behavior.
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The weight not accounted for by the gathered samples is assigned to `cval`:
    # 1 - prod_d(sum of weight[d]).  The product must range over all `grid_ndim` source
    # dimensions (one entry of `weight` each), not over the `resampled_ndim` output dimensions.
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array

Interpolate array (a grid of samples) at specified unit-domain coordinates coords.

The last dimension of coords contains unit-domain coordinates at which to interpolate the domain grid samples in array.

The number of coordinates (coords.shape[-1]) determines how to interpret array: its first coords.shape[-1] dimensions define the grid, and the remaining dimensions describe each grid sample (e.g., scalar, vector, tensor).

Concretely, the grid has shape array.shape[:coords.shape[-1]] and each grid sample has shape array.shape[coords.shape[-1]:].

Examples include:

  • Resample a grayscale image with array.shape = height, width onto a new grayscale image with new.shape = height2, width2 by using coords.shape = height2, width2, 2.

  • Resample an RGB image with array.shape = height, width, 3 onto a new RGB image with new.shape = height2, width2, 3 by using coords.shape = height2, width2, 2.

  • Sample an RGB image at num 2D points along a line segment by using coords.shape = num, 2.

  • Sample an RGB image at a single 2D point by using coords.shape = (2,).

  • Sample a 3D grid of 3x3 Jacobians with array.shape = nz, ny, nx, 3, 3 along a 2D plane by using coords.shape = height, width, 3.

  • Map a grayscale image through a color map by using array.shape = 256, 3 and coords.shape = height, width.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The coordinate dimensions appear first, and each grid sample may have an arbitrary shape. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • coords: Grid of points at which to resample array. The point coordinates are in the last dimension of coords. The domain associated with the source grid is a unit hypercube, i.e. with a range [0, 1] on each coordinate dimension. The output grid has shape coords.shape[:-1] and each of its grid samples has shape array.shape[coords.shape[-1]:].
  • gridtype: Placement of the samples in the source domain grid for each dimension, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual'.
  • boundary: The reconstruction boundary rule for each dimension in coords.shape[-1], specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling. Because resample does not yet detect downsampling (minification), 'auto' currently always resolves to 'reflect'.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto the shape array.shape[coords.shape[-1]:]. It is subject to src_gamma.
  • filter: The reconstruction kernel for each dimension in coords.shape[-1], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in coords.shape[:-1], specified as either a filter name in FILTERS or a Filter instance. It is intended for downsampling (i.e., minification). If None, it inherits the value of filter. Prefiltering is not yet implemented in resample, so this parameter currently does not affect the result.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • jacobian: Optional array, which must be broadcastable onto the shape coords.shape[:-1] + (coords.shape[-1], coords.shape[-1]), storing for each point in the output grid the Jacobian matrix of the map from the unit output domain to the unit source domain. It is intended for prefiltering (and, if omitted, to be estimated by computing finite differences on coords). Because resample does not yet implement prefiltering, the jacobian currently does not affect the result.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype, coords.dtype, and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • max_block_size: If nonzero, maximum number of grid points in coords before the resampling evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
  • debug: Show internal information.
Returns:

A new sample grid of shape coords.shape[:-1], represented as an array of shape coords.shape[:-1] + array.shape[coords.shape[-1]:], of the same array library type as the source array.

Example of resample operation (illustrated in https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png):

For reference, the identity resampling for a scalar-valued grid with the default grid-type 'dual' is:

>>> array = np.random.default_rng(0).random((5, 7, 3))
>>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
>>> new_array = resample(array, coords)
>>> assert np.allclose(new_array, array)

It is more efficient to use the function resize for the special case where the coords are obtained as simple scaling and translation of a new regular grid over the source domain:

>>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
>>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
>>> coords = (coords - translate) / scale
>>> resampled = resample(array, coords)
>>> resized = resize(array, new_shape, scale=scale, translate=translate)
>>> assert np.allclose(resampled, resized)
def resample_affine( array: Array, /, shape: Iterable[int], matrix: ArrayLike, *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, precision: DTypeLike = None, dtype: DTypeLike = None, **kwargs: Any) -> Array:
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample a source array using an affinely transformed grid of given shape.

  Each destination unit-domain point `p` (with `len(shape)` coordinates) is mapped to a source
  unit-domain point either linearly, `source_point = matrix @ p`, or affinely when the matrix has
  one extra column holding an offset vector, `source_point = matrix @ (p, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)

  # Validate the transform against both the source array and the destination shape.
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim = matrix.shape[0]
  is_affine = matrix.shape[1] == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not is_affine and matrix.shape[1] != dst_ndim:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  # Resolve per-dimension gridtypes and filters (source dims use `filter`, destination dims use
  # `prefilter`, which falls back to `filter`).
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  dtype = np.dtype(dtype) if dtype is not None else _arr_dtype(array)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples along each axis, expanded to a dense grid.
  dst_axes = [
      dst_gridtype2[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*dst_axes, indexing='ij')

  # Contract the linear part over the destination coordinate axis, then move the resulting source
  # coordinate axis last (the layout `resample` expects) and add the affine offset if present.
  linear_part = matrix[:, :-1] if is_affine else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_position, 1), 0, -1)
  if is_affine:
    coords += matrix[:, -1]

  # TODO: Based on the source grid shape, `shape`, the linear part of `matrix`, and `prefilter`,
  # determine a convolution prefilter and apply it to bandlimit `array`, using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtype2,
      filter=filter2,
      prefilter=prefilter2,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

Resample a source array using an affinely transformed grid of given shape.

The matrix transformation can be linear, source_point = matrix @ destination_point, or it can be affine where the last matrix column is an offset vector, source_point = matrix @ (destination_point, 1.0).

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The number of grid dimensions is determined from matrix.shape[0]; the remaining dimensions are for each sample value and are all linearly interpolated.
  • shape: Dimensions of the desired destination grid. The number of destination grid dimensions may be different from that of the source grid.
  • matrix: 2D array for a linear or affine transform from unit-domain destination points (in a space with len(shape) dimensions) into unit-domain source points (in a space with matrix.shape[0] dimensions). If the matrix has len(shape) + 1 columns, the last column is the affine offset (i.e., translation).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • filter: The reconstruction kernel for each dimension in matrix.shape[0], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in len(shape), specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • **kwargs: Additional parameters for resample function.
Returns:

An array of the same class as the source array, representing a grid with specified shape, where each grid value is resampled from array. Thus the shape of the returned array is shape + array.shape[matrix.shape[0]:].

def rotation_about_center_in_2d( src_shape: ArrayLike, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0) -> np.ndarray:
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return the 3x3 matrix mapping the destination unit domain into the source unit domain.

  The matrix acts on homogeneous `(y, x, 1)` coordinates and includes aspect-ratio scalings that
  account for possibly non-square domain shapes.  Its last row is `[0, 0, 1]`.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def translation(offset: _NDArray) -> _NDArray:
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def scaling(factors: _NDArray) -> _NDArray:
    return np.diag(tuple(factors) + (1.0,))

  cos, sin = np.cos(angle), np.sin(angle)
  rotation = np.array([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])
  center = np.array([0.5, 0.5])

  # Composed right to left: shift destination coordinates so the domain center is at the origin,
  # rescale by the destination aspect ratio (times `scale`), rotate, rescale by the inverse source
  # aspect ratio, and shift back to the source domain center.
  matrix = (
      translation(center)
      @ scaling(min(src_shape) / src_shape)
      @ rotation
      @ scaling(scale * new_shape / min(new_shape))
      @ translation(-center)
  )
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

Return the 3x3 matrix mapping the destination unit domain into the source unit domain.

The returned matrix accounts for the possibly non-square domain shapes.

Arguments:
  • src_shape: Resolution (ny, nx) of the source domain grid.
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the destination domain grid; it defaults to src_shape.
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
def rotate_image_about_center( image: np.ndarray, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0, num_rotations: int = 1, **kwargs: Any) -> np.ndarray:
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`, and each applying `scale`).  Successive
      resamplings are useful in analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The resampled image, whose first two dimensions are `new_shape`.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  # The first pass resamples from the original grid.
  first_matrix = rotation_about_center_in_2d(
      image.shape[:2], angle, new_shape=new_shape, scale=scale
  )
  # Every later pass resamples the previous output, whose spatial shape is `new_shape`, so it
  # needs a matrix computed for a `new_shape`-sized source.  (When `new_shape` equals the original
  # spatial shape, this is identical to `first_matrix`.)
  later_matrix = rotation_about_center_in_2d(new_shape, angle, new_shape=new_shape, scale=scale)
  for index in range(num_rotations):
    matrix = first_matrix if index == 0 else later_matrix
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image

Return a copy of image rotated about its center.

Arguments:
  • image: Source grid samples; the first two dimensions are spatial (ny, nx).
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the output grid; it defaults to image.shape[:2].
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
  • num_rotations: Number of rotations (each by angle). Successive resamplings are useful in analyzing the filtering quality.
  • **kwargs: Additional parameters for resample_affine.
def pil_image_resize(array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0) -> np.ndarray:
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  A 1D array is treated as a single-row image.  A 3D array is resized one channel (last axis) at a
  time and the results are stacked back together.  Only `boundary='natural'` is supported; `cval`
  is ignored.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return pil_image_resize(array[None], (1, *shape), filter=filter)[0]

  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
  resampling = PIL.Image.Resampling
  filters = {
      'impulse': resampling.NEAREST,
      'box': resampling.BOX,
      'triangle': resampling.BILINEAR,
      'hamming1': resampling.HAMMING,
      'cubic': resampling.BICUBIC,
      'lanczos3': resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  size = shape[::-1]  # PIL expects (width, height).

  def resize_plane(plane: _NDArray) -> _NDArray:
    resized = PIL.Image.fromarray(plane).resize(size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])

Invoke PIL.Image.resize using the same parameters as resize.

def cv_resize(array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'clamp', cval: float = 0.0) -> np.ndarray:
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  A 1D array is treated as a single-row image.  Only `boundary='clamp'` is supported; `cval` is
  ignored.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if array.ndim >= 2 else 1)
  if array.ndim == 1:
    return cv_resize(array[None], (1, *shape), filter=filter)[0]

  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  # OpenCV takes the destination size as (width, height).
  result = cv.resize(array, shape[::-1], interpolation=filters[filter])
  if array.ndim != 3 or result.ndim != 2:
    return result
  # `cv.resize` drops a trailing singleton channel dimension; restore it.
  assert array.shape[2] == 1
  return result[..., None]

Invoke cv.resize using the same parameters as resize.

def scipy_ndimage_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> np.ndarray:
3602def scipy_ndimage_resize(
3603    array: _ArrayLike,
3604    /,
3605    shape: Iterable[int],
3606    *,
3607    filter: str,
3608    boundary: str = 'reflect',
3609    cval: float = 0.0,
3610    scale: float | Iterable[float] = 1.0,
3611    translate: float | Iterable[float] = 0.0,
3612) -> _NDArray:
3613  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`."""
3614  array = np.asarray(array)
3615  shape = tuple(shape)
3616  assert 1 <= len(shape) <= array.ndim
3617  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
3618  if filter not in filters:
3619    raise ValueError(f'{filter=} not in {filters=}.')
3620  order = filters[filter]
3621  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
3622  if boundary not in boundaries:
3623    raise ValueError(f'{boundary=} not in {boundaries=}.')
3624  mode = boundaries[boundary]
3625  shape_all = shape + array.shape[len(shape) :]
3626  coords = np.moveaxis(np.indices(shape_all, array.dtype), 0, -1)
3627  coords[..., : len(shape)] = (
3628      (coords[..., : len(shape)] + 0.5) / shape - np.asarray(translate)
3629  ) / np.asarray(scale) * np.array(array.shape)[: len(shape)] - 0.5
3630  coords = np.moveaxis(coords, -1, 0)
3631  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)

Invoke scipy.ndimage.map_coordinates using the same parameters as resize.

def skimage_transform_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `skimage.transform.resize` using the same parameters as `resize`.

  The leading `len(shape)` dimensions are resized; trailing dimensions keep their size.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim

  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')

  output_shape = (*shape, *array.shape[len(shape) :])
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      array,
      output_shape,
      order=filters[filter],
      mode=boundaries[boundary],
      cval=cval,
      clip=False,
  )  # type: ignore[no-untyped-call]

Invoke skimage.transform.resize using the same parameters as resize.

def tf_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0, antialias: bool = True) -> TensorflowTensor:
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Invoke `tf.image.resize` using the same parameters as `resize`.

  1D inputs are lifted to a single-row image and 2D inputs gain a channel axis, so that
  `tf.image.resize` always receives a 3D `(ny, nx, channels)` tensor.  Only `boundary='natural'`
  is supported; `cval` is ignored.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  del array
  ndim = len(tensor.shape)
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)

  if ndim == 1:
    return tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)[0]
  if ndim == 2:
    return tf_image_resize(tensor[..., None], shape, filter=filter, antialias=antialias)[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)

Invoke tf.image.resize using the same parameters as resize.

def torch_nn_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'clamp', cval: float = 0.0, antialias: bool = False) -> TorchTensor:
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`.

  The input is reshaped to the `(batch, channels, ny, nx)` layout that `interpolate` expects
  (a 1D input becomes a single-row image; a 3D input has its last axis treated as channels) and
  the result is reshaped back.  Only `boundary='clamp'` is supported; `cval` is ignored.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  a = torch.as_tensor(array)
  del array
  assert 1 <= a.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if a.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  def interpolate_batch(batch: _TorchTensor, size: tuple[int, ...]) -> _TorchTensor:
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batch, size, mode=mode, antialias=antialias)

  if a.ndim == 1:
    return interpolate_batch(a[None, None, None], (1, *shape))[0, 0, 0]
  if a.ndim == 2:
    return interpolate_batch(a[None, None], shape)[0, 0]
  return interpolate_batch(a.moveaxis(2, 0)[None], shape)[0].moveaxis(0, 2)

Invoke torch.nn.functional.interpolate using the same parameters as resize.

def jax_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'natural', cval: float = 0.0, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> JaxArray:
3761def jax_image_resize(
3762    array: _ArrayLike,
3763    /,
3764    shape: Iterable[int],
3765    *,
3766    filter: str,
3767    boundary: str = 'natural',
3768    cval: float = 0.0,
3769    scale: float | Iterable[float] = 1.0,
3770    translate: float | Iterable[float] = 0.0,
3771) -> _JaxArray:
3772  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`."""
3773  import jax.image
3774  import jax.numpy as jnp
3775
3776  filters = 'triangle cubic lanczos3 lanczos5'.split()
3777  if filter not in filters:
3778    raise ValueError(f'{filter=} not in {filters=}.')
3779  if boundary != 'natural':
3780    raise ValueError(f"{boundary=} must equal 'natural'.")
3781  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
3782  # To be consistent, the parameter `cval` must be zero.
3783  if scale != 1.0 and cval != 0.0:
3784    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
3785  if translate != 0.0 and cval != 0.0:
3786    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
3787  array2 = jnp.asarray(array)
3788  del array
3789  shape = tuple(shape)
3790  assert len(shape) <= array2.ndim
3791  completed_shape = shape + (1,) * (array2.ndim - len(shape))
3792  spatial_dims = list(range(len(shape)))
3793  scale2 = np.broadcast_to(np.array(scale), len(shape))
3794  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
3795  translate2 = np.broadcast_to(np.array(translate), len(shape))
3796  translate2 = translate2 * np.array(shape)
3797  return jax.image.scale_and_translate(
3798      array2, completed_shape, spatial_dims, scale2, translate2, filter
3799  )

Invoke jax.image.scale_and_translate using the same parameters as resize.