resampler

resampler: fast differentiable resizing and warping of arbitrary grids.

   1"""resampler: fast differentiable resizing and warping of arbitrary grids.
   2
   3.. include:: ../README.md
   4"""
   5
   6from __future__ import annotations
   7
   8__docformat__ = 'google'
   9__version__ = '0.8.7'
  10__version_info__ = tuple(int(num) for num in __version__.split('.'))
  11
  12from collections.abc import Callable, Iterable, Sequence
  13import abc
  14import dataclasses
  15import functools
  16import importlib
  17import itertools
  18import math
  19import os
  20import sys
  21import types
  22import typing
  23from typing import Any, Generic, Literal, Union
  24
  25import numpy as np
  26import numpy.typing
  27import scipy.interpolate
  28import scipy.linalg
  29import scipy.ndimage
  30import scipy.sparse
  31import scipy.sparse.linalg
  32
  33
  34def _noop_decorator(*args: Any, **kwargs: Any) -> Any:
  35  """Return function decorated with no-operation; invokable with or without args."""
  36  if len(args) != 1 or not callable(args[0]) or kwargs:
  37    return _noop_decorator  # Decorator is invoked with arguments; ignore them.
  38  func: Callable[..., Any] = args[0]
  39  return func
  40
  41
try:
  import numba
except ModuleNotFoundError:
  # numba is optional.  Register a stub module (also in `sys.modules`) whose `njit` is a no-op
  # decorator, so that the `@numba.njit(...)` decorations in this file still work.
  numba = sys.modules['numba'] = types.ModuleType('numba')
  numba.njit = _noop_decorator
# The stub defines only `njit`, so the presence of `jit` distinguishes the real package.
using_numba = hasattr(numba, 'jit')
  48
  49if TYPE_CHECKING:
  50  import jax.numpy
  51  import tensorflow as tf
  52  import torch
  53
  54  _DType = np.dtype[Any]  # (Requires Python 3.9 or TYPE_CHECKING.)
  55  _NDArray = np.ndarray[Any, Any]
  56  _DTypeLike = numpy.DTypeLike
  57  _ArrayLike = numpy.ArrayLike
  58  _TensorflowTensor: TypeAlias = tf.Tensor
  59  _TorchTensor: TypeAlias = torch.Tensor
  60  _JaxArray: TypeAlias = jax.numpy.ndarray
  61
  62else:
  63  # Typically, create named types for use in the `pdoc` documentation.
  64  # But here, these are superseded by the declarations in __init__.pyi!
  65  _DType = Any
  66  _NDArray = Any
  67  _DTypeLike = Any
  68  _ArrayLike = Any
  69  _TensorflowTensor = Any
  70  _TorchTensor = Any
  71  _JaxArray = Any
  72
  73_Array = TypeVar('_Array', _NDArray, _TensorflowTensor, _TorchTensor, _JaxArray)
  74_AnyArray = Union[_NDArray, _TensorflowTensor, _TorchTensor, _JaxArray]
  75
  76
  77def _check_eq(a: Any, b: Any, /) -> None:
  78  """If the two values or arrays are not equal, raise an exception with a useful message."""
  79  are_equal = np.all(a == b) if isinstance(a, np.ndarray) else a == b
  80  if not are_equal:
  81    raise AssertionError(f'{a!r} == {b!r}')
  82
  83
  84def _real_precision(dtype: _DTypeLike, /) -> _DType:
  85  """Return the type of the real part of a complex number."""
  86  return np.array([], dtype).real.dtype
  87
  88
  89def _complex_precision(dtype: _DTypeLike, /) -> _DType:
  90  """Return a complex type to represent a non-complex type."""
  91  return np.result_type(dtype, np.complex64)
  92
  93
  94def _get_precision(
  95    precision: _DTypeLike, dtypes: list[_DType], weight_dtypes: list[_DType], /
  96) -> _DType:
  97  """Return dtype based on desired precision or on data and weight types."""
  98  precision2 = np.dtype(
  99      precision if precision is not None else np.result_type(np.float32, *dtypes, *weight_dtypes)
 100  )
 101  if not np.issubdtype(precision2, np.inexact):
 102    raise ValueError(f'Precision {precision2} is not floating or complex.')
 103  check_complex = [precision2, *dtypes]
 104  is_complex = [np.issubdtype(dtype, np.complexfloating) for dtype in check_complex]
 105  if len(set(is_complex)) != 1:
 106    s_types = ','.join(str(dtype) for dtype in check_complex)
 107    raise ValueError(f'Types {s_types} must be all real or all complex.')
 108  return precision2
 109
 110
 111def _sinc(x: _ArrayLike, /) -> _NDArray:
 112  """Return the value `np.sinc(x)` but improved to:
 113  (1) ignore underflow that occurs at 0.0 for np.float32, and
 114  (2) output exact zero for integer input values.
 115
 116  >>> _sinc(np.array([-3, -2, -1, 0], np.float32))
 117  array([0., 0., 0., 1.], dtype=float32)
 118
 119  >>> _sinc(np.array([-3, -2, -1, 0]))
 120  array([0., 0., 0., 1.])
 121
 122  >>> _sinc(0)
 123  1.0
 124  """
 125  x = np.asarray(x)
 126  x_is_scalar = x.ndim == 0
 127  with np.errstate(under='ignore'):
 128    result = np.sinc(np.atleast_1d(x))
 129    result[x == np.floor(x)] = 0.0
 130    result[x == 0] = 1.0
 131    return result.item() if x_is_scalar else result
 132
 133
 134def _is_symmetric(matrix: scipy.sparse.spmatrix, /, tol: float = 1e-10) -> bool:
 135  """Return True if the sparse matrix is symmetric."""
 136  norm: float = scipy.sparse.linalg.norm(matrix - matrix.T, np.inf)
 137  return norm <= tol
 138
 139
 140def _cache_sampled_1d_function(
 141    xmin: float,
 142    xmax: float,
 143    *,
 144    num_samples: int = 3_600,
 145    enable: bool = True,
 146) -> Callable[[Callable[[_ArrayLike], _NDArray]], Callable[[_ArrayLike], _NDArray]]:
 147  """Function decorator to linearly interpolate cached function values."""
 148  # Speed unchanged up to num_samples=12_000, then slow decrease until 100_000.
 149
 150  def wrap_it(func: Callable[[_ArrayLike], _NDArray]) -> Callable[[_ArrayLike], _NDArray]:
 151    if not enable:
 152      return func
 153
 154    dx = (xmax - xmin) / num_samples
 155    x = np.linspace(xmin, xmax + dx, num_samples + 2, dtype=np.float32)
 156    samples_func = func(x)
 157    assert np.all(samples_func[[0, -1, -2]] == 0.0)
 158
 159    @functools.wraps(func)
 160    def interpolate_using_cached_samples(x: _ArrayLike) -> _NDArray:
 161      x = np.asarray(x)
 162      index_float = np.clip((x - xmin) / dx, 0.0, num_samples)
 163      index = index_float.astype(np.int64)
 164      frac = np.subtract(index_float, index, dtype=np.float32)
 165      return (1 - frac) * samples_func[index] + frac * samples_func[index + 1]
 166
 167    return interpolate_using_cached_samples
 168
 169  return wrap_it
 170
 171
 172class _DownsampleIn2dUsingBoxFilter:
 173  """Fast 2D box-filter downsampling using cached numba-jitted functions."""
 174
 175  def __init__(self) -> None:
 176    # Downsampling function for params (dtype, block_height, block_width, ch).
 177    self._jitted_function: dict[tuple[_DType, int, int, int], Callable[[_NDArray], _NDArray]] = {}
 178
 179  def __call__(self, array: _NDArray, shape: tuple[int, int]) -> _NDArray:
 180    assert using_numba
 181    assert array.ndim in (2, 3), array.ndim
 182    _check_eq(len(shape), 2)
 183    dtype = array.dtype
 184    a = array[..., None] if array.ndim == 2 else array
 185    height, width, ch = a.shape
 186    new_height, new_width = shape
 187    if height % new_height != 0 or width % new_width != 0:
 188      raise ValueError(f'Shape {array.shape} not a multiple of {shape}.')
 189    block_height, block_width = height // new_height, width // new_width
 190
 191    def func(array: _NDArray) -> _NDArray:
 192      new_height = array.shape[0] // block_height
 193      new_width = array.shape[1] // block_width
 194      result = np.empty((new_height, new_width, ch), dtype)
 195      totals = np.empty(ch, dtype)
 196      factor = dtype.type(1.0 / (block_height * block_width))
 197      for y in numba.prange(new_height):  # pylint: disable=not-an-iterable
 198        for x in range(new_width):
 199          # Introducing "y2, x2 = y * block_height, x * block_width" is actually slower.
 200          if ch == 1:  # All the branches involve compile-time constants.
 201            total = dtype.type(0.0)
 202            for yy in range(block_height):
 203              for xx in range(block_width):
 204                total += array[y * block_height + yy, x * block_width + xx, 0]
 205            result[y, x, 0] = total * factor
 206          elif ch == 3:
 207            total0 = total1 = total2 = dtype.type(0.0)
 208            for yy in range(block_height):
 209              for xx in range(block_width):
 210                total0 += array[y * block_height + yy, x * block_width + xx, 0]
 211                total1 += array[y * block_height + yy, x * block_width + xx, 1]
 212                total2 += array[y * block_height + yy, x * block_width + xx, 2]
 213            result[y, x, 0] = total0 * factor
 214            result[y, x, 1] = total1 * factor
 215            result[y, x, 2] = total2 * factor
 216          elif block_height * block_width >= 9:
 217            for c in range(ch):
 218              totals[c] = 0.0
 219            for yy in range(block_height):
 220              for xx in range(block_width):
 221                for c in range(ch):
 222                  totals[c] += array[y * block_height + yy, x * block_width + xx, c]
 223            for c in range(ch):
 224              result[y, x, c] = totals[c] * factor
 225          else:
 226            for c in range(ch):
 227              total = dtype.type(0.0)
 228              for yy in range(block_height):
 229                for xx in range(block_width):
 230                  total += array[y * block_height + yy, x * block_width + xx, c]
 231              result[y, x, c] = total * factor
 232      return result
 233
 234    signature = dtype, block_height, block_width, ch
 235    jitted_function = self._jitted_function.get(signature)
 236    if not jitted_function:
 237      if 0:
 238        print(f'Creating numba jit-wrapper for {signature}.')
 239      jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
 240      self._jitted_function[signature] = jitted_function
 241
 242    try:
 243      result = jitted_function(a)
 244    except RuntimeError:
 245      message = (
 246          'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
 247          ' try deleting that directory.'
 248      )
 249      print(message, file=sys.stderr, flush=True)
 250      raise
 251
 252    return result[..., 0] if array.ndim == 2 else result
 253
 254
 255_downsample_in_2d_using_box_filter = _DownsampleIn2dUsingBoxFilter()
 256
 257
@numba.njit(nogil=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_serial_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.

  For each row `i` of `dst`, this accumulates `data[jj] * src[indices[jj]]` over the range
  `jj` in `[indptr[i], indptr[i + 1])` and stores the sum into `dst[i]`.  The output `dst` is
  written in place; nothing is returned.

  Args:
    indptr: 1D array of length `dst.shape[0] + 1` delimiting each row's range of entries.
    indices: 1D array giving the row of `src` used by each entry.
    data: 1D array giving the coefficient of each entry.
    src: 2D input array.
    dst: 2D output array with the same number of columns as `src`.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): reading data[0] and src[0] presumes both are non-empty -- confirm callers.
  acc = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  for i in range(dst.shape[0]):
    acc[:] = 0  # Reuse the single accumulator row for every output row.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 279
 280
 281# I tried using the "minimal" "parallel=" config but this did not result in any jit speedup:
 282#  parallel=dict(comprehension=False, prange=True, numpy=True, reduction=False,
 283#                setitem=False, stencil=False, fusion=False)
@numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[misc]
def _numba_parallel_csr_dense_mult(
    indptr: _NDArray,
    indices: _NDArray,
    data: _NDArray,
    src: _NDArray,
    dst: _NDArray,
) -> None:
  """Faster version of scipy.sparse._sparsetools.csr_matvecs().

  The single-threaded numba-jitted code is about 2x faster than the scipy C++.
  The introduction of parallel omp threads provides another 2-4x speedup.
  However, "parallel=True" leads to slow jitting (~3 s), so we cache the jitted code on disk.

  The computation matches `_numba_serial_csr_dense_mult`, except that the loop over the rows
  of `dst` uses `numba.prange`.  The output `dst` is written in place; nothing is returned.

  Args:
    indptr: 1D array of length `dst.shape[0] + 1` delimiting each row's range of entries.
    indices: 1D array giving the row of `src` used by each entry.
    data: 1D array giving the coefficient of each entry.
    src: 2D input array.
    dst: 2D output array with the same number of columns as `src`.
  """
  assert indptr.ndim == indices.ndim == data.ndim == 1 and src.ndim == dst.ndim == 2
  assert len(indptr) == dst.shape[0] + 1 and src.shape[1] == dst.shape[1]
  # NOTE(review): reading data[0] and src[0] presumes both are non-empty -- confirm callers.
  acc0 = data[0] * src[0]  # Dummy initialization value, to infer correct shape and dtype.
  # Default is static scheduling, which is fine.
  for i in numba.prange(dst.shape[0]):  # pylint: disable=not-an-iterable
    # The accumulator is created inside the loop body, unlike in the serial version.
    acc = np.zeros_like(acc0)  # Numba automatically hoists the allocation outside the loop.
    for jj in range(indptr[i], indptr[i + 1]):
      j = indices[jj]
      acc += data[jj] * src[j]
    dst[i] = acc
 308
 309
@dataclasses.dataclass
class _Arraylib(abc.ABC, Generic[_Array]):
  """Abstract base class for abstraction of array libraries.

  An instance wraps a single array (`self.array`) of one particular library.  Subclasses
  implement the abstract methods using that library; the few non-abstract methods provide
  defaults that rely on numpy-like methods or indexing of the wrapped array.
  """

  arraylib: str
  """Name of array library (e.g., `'numpy'`, `'tensorflow'`, `'torch'`, `'jax'`)."""

  # The wrapped array, of the library-specific type.
  array: _Array

  @staticmethod
  @abc.abstractmethod
  def recognize(array: Any) -> bool:
    """Return True if `array` is recognized by this _Arraylib."""

  @abc.abstractmethod
  def numpy(self) -> _NDArray:
    """Return a `numpy` version of `self.array`."""

  @abc.abstractmethod
  def dtype(self) -> _DType:
    """Return the equivalent of `self.array.dtype` as a `numpy` `dtype`."""

  @abc.abstractmethod
  def astype(self, dtype: _DTypeLike) -> _Array:
    """Return the equivalent of `self.array.astype(dtype, copy=False)` with `numpy` `dtype`."""

  def reshape(self, shape: tuple[int, ...]) -> _Array:
    """Return the equivalent of `self.array.reshape(shape)`."""
    return self.array.reshape(shape)

  def possibly_make_contiguous(self) -> _Array:
    """Return a contiguous copy of `self.array` or just `self.array` if already contiguous.

    This default implementation returns `self.array` unchanged.
    """
    return self.array

  @abc.abstractmethod
  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _Array:
    """Return the equivalent of `self.array.clip(low, high, dtype=dtype)` with `numpy` `dtype`."""

  @abc.abstractmethod
  def square(self) -> _Array:
    """Return the equivalent of `np.square(self.array)`."""

  @abc.abstractmethod
  def sqrt(self) -> _Array:
    """Return the equivalent of `np.sqrt(self.array)`."""

  def getitem(self, indices: Any) -> _Array:
    """Return the equivalent of `self.array[indices]` (a "gather" operation)."""
    return self.array[indices]

  @abc.abstractmethod
  def where(self, if_true: Any, if_false: Any) -> _Array:
    """Return the equivalent of `np.where(self.array, if_true, if_false)`."""

  @abc.abstractmethod
  def transpose(self, axes: Sequence[int]) -> _Array:
    """Return the equivalent of `np.transpose(self.array, axes)`."""

  @abc.abstractmethod
  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
    """Return the best order in which to process dims for resizing `self.array` to `dst_shape`."""

  @abc.abstractmethod
  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _Array:
    """Return the multiplication of the `sparse` matrix and `self.array`."""

  @staticmethod
  @abc.abstractmethod
  def concatenate(arrays: Sequence[_Array], axis: int) -> _Array:
    """Return the equivalent of `np.concatenate(arrays, axis)`."""

  @staticmethod
  @abc.abstractmethod
  def einsum(subscripts: str, *operands: _Array) -> _Array:
    """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""

  @staticmethod
  @abc.abstractmethod
  def make_sparse_matrix(
      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
  ) -> _Array:
    """Return the equivalent of `scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)`.
    However, the indices must be ordered and unique."""
 393
 394
 395class _NumpyArraylib(_Arraylib[_NDArray]):
 396  """Numpy implementation of the array abstraction."""
 397
 398  def __init__(self, array: _NDArray) -> None:
 399    super().__init__(arraylib='numpy', array=np.asarray(array))
 400
 401  @staticmethod
 402  def recognize(array: Any) -> bool:
 403    return isinstance(array, np.ndarray)
 404
 405  def numpy(self) -> _NDArray:
 406    return self.array
 407
 408  def dtype(self) -> _DType:
 409    dtype: _DType = self.array.dtype
 410    return dtype
 411
 412  def astype(self, dtype: _DTypeLike) -> _NDArray:
 413    return self.array.astype(dtype, copy=False)
 414
 415  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _NDArray:
 416    return self.array.clip(low, high, dtype=dtype)
 417
 418  def square(self) -> _NDArray:
 419    return np.square(self.array)
 420
 421  def sqrt(self) -> _NDArray:
 422    return np.sqrt(self.array)
 423
 424  def where(self, if_true: Any, if_false: Any) -> _NDArray:
 425    condition = self.array
 426    return np.where(condition, if_true, if_false)
 427
 428  def transpose(self, axes: Sequence[int]) -> _NDArray:
 429    return np.transpose(self.array, tuple(axes))
 430
 431  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
 432    # Our heuristics: (1) a dimension with small scaling (especially minification) gets priority,
 433    # and (2) timings show preference to resizing dimensions with larger strides first.
 434    # (Of course, tensorflow.Tensor lacks strides, so (2) does not apply.)
 435    # The optimal ordering might be related to the logic in np.einsum_path().  (Unfortunately,
 436    # np.einsum() does not support the sparse multiplications that we require here.)
 437    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
 438    strides: Sequence[int] = self.array.strides
 439    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])
 440
 441    def priority(dim: int) -> float:
 442      scaling = dst_shape[dim] / src_shape[dim]
 443      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)
 444
 445    return sorted(range(len(src_shape)), key=priority)
 446
 447  def premult_with_sparse(
 448      self, sparse: scipy.sparse.csr_matrix, num_threads: int | Literal['auto']
 449  ) -> _NDArray:
 450    assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
 451    # Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
 452    if using_numba:
 453      num_threads2 = min(6, os.cpu_count()) if num_threads == 'auto' else num_threads
 454      src = np.ascontiguousarray(self.array)  # Like .ravel() in _mul_multivector().
 455      dtype = np.result_type(sparse.dtype, src.dtype)
 456      dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
 457      num_scalar_multiplies = len(sparse.data) * src.shape[1]
 458      is_small_size = num_scalar_multiplies < 200_000
 459      if is_small_size or num_threads2 == 1:
 460        _numba_serial_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 461      else:
 462        numba.set_num_threads(num_threads2)
 463        _numba_parallel_csr_dense_mult(sparse.indptr, sparse.indices, sparse.data, src, dst)
 464      return dst
 465
 466    # Note that sicpy.sparse does not use multithreading.  The "@" operation
 467    # calls _spbase.__matmul__() -> _spbase._mul_dispatch() -> _cs_matrix._mul_multivector() ->
 468    # scipy.sparse._sparsetools.csr_matvecs() in
 469    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/csr.h
 470    # which iteratively calls the (in theory, LEVEL 1 BLAS) function axpy() in
 471    # https://github.com/scipy/scipy/blob/main/scipy/sparse/sparsetools/dense.h
 472    return sparse @ self.array
 473
 474  @staticmethod
 475  def concatenate(arrays: Sequence[_NDArray], axis: int) -> _NDArray:
 476    return np.concatenate(arrays, axis)
 477
 478  @staticmethod
 479  def einsum(subscripts: str, *operands: _NDArray) -> _NDArray:
 480    return np.einsum(subscripts, *operands, optimize=True)
 481
 482  @staticmethod
 483  def make_sparse_matrix(
 484      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
 485  ) -> _NDArray:
 486    return scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=shape)
 487
 488
class _TensorflowArraylib(_Arraylib[_TensorflowTensor]):
  """Tensorflow implementation of the array abstraction.

  The `tensorflow` package is imported lazily, when an instance is created or when a static
  method is called, rather than when this module is loaded.
  """

  def __init__(self, array: _NDArray) -> None:
    import tensorflow

    self.tf = tensorflow  # Keep the module handle for use by the instance methods.
    super().__init__(arraylib='tensorflow', array=self.tf.convert_to_tensor(array))

  @staticmethod
  def recognize(array: Any) -> bool:
    # Eager: tensorflow.python.framework.ops.Tensor
    # Non-eager: tensorflow.python.ops.resource_variable_ops.ResourceVariable
    return type(array).__module__.startswith('tensorflow.')

  def numpy(self) -> _NDArray:
    return self.array.numpy()

  def dtype(self) -> _DType:
    return np.dtype(self.array.dtype.as_numpy_dtype)

  def astype(self, dtype: _DTypeLike) -> _TensorflowTensor:
    return self.tf.cast(self.array, dtype)

  def reshape(self, shape: tuple[int, ...]) -> _TensorflowTensor:
    return self.tf.reshape(self.array, shape)

  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TensorflowTensor:
    array = self.array
    if dtype is not None:
      array = self.tf.cast(array, dtype)  # Cast first, then clip in the requested dtype.
    return self.tf.clip_by_value(array, low, high)

  def square(self) -> _TensorflowTensor:
    return self.tf.square(self.array)

  def sqrt(self) -> _TensorflowTensor:
    return self.tf.sqrt(self.array)

  def getitem(self, indices: Any) -> _TensorflowTensor:
    if isinstance(indices, tuple):
      basic = all(isinstance(x, (type(None), type(Ellipsis), int, slice)) for x in indices)
      if not basic:
        # We require tf.gather_nd(), which unfortunately requires broadcast expansion of indices.
        assert all(isinstance(a, np.ndarray) for a in indices)
        assert all(a.ndim == indices[0].ndim for a in indices)
        broadcast_indices = np.broadcast_arrays(*indices)  # list of np.ndarray
        # Move the per-dimension axis last, so that each trailing vector is one multi-index.
        indices_array = np.moveaxis(np.array(broadcast_indices), 0, -1)
        return self.tf.gather_nd(self.array, indices_array)
      # NOTE(review): a tuple of only basic indices falls through to tf.gather() below --
      # confirm that this is intended.
    elif _arr_dtype(indices).type in (np.uint8, np.uint16):
      # Indices of these small unsigned types are cast to int32 before the gather.
      indices = self.tf.cast(indices, np.int32)
    return self.tf.gather(self.array, indices)

  def where(self, if_true: Any, if_false: Any) -> _TensorflowTensor:
    condition = self.array
    return self.tf.where(condition, if_true, if_false)

  def transpose(self, axes: Sequence[int]) -> _TensorflowTensor:
    return self.tf.transpose(self.array, tuple(axes))

  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
    # Note that a tensorflow.Tensor does not have strides.
    # Our heuristic is to process dimension 1 first iff dimension 0 is upsampling.  Improve?
    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
    dims = list(range(len(src_shape)))
    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
      dims[:2] = [1, 0]
    return dims

  def premult_with_sparse(
      self, sparse: tf.sparse.SparseTensor, num_threads: int | Literal['auto']
  ) -> _TensorflowTensor:
    import tensorflow as tf

    del num_threads  # Unused in this implementation.
    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
      # For a complex array, convert the sparse values to the array's dtype before multiplying.
      sparse = sparse.with_values(_arr_astype(sparse.values, _arr_dtype(self.array)))
    return tf.sparse.sparse_dense_matmul(sparse, self.array)

  @staticmethod
  def concatenate(arrays: Sequence[_TensorflowTensor], axis: int) -> _TensorflowTensor:
    import tensorflow as tf

    return tf.concat(arrays, axis)

  @staticmethod
  def einsum(subscripts: str, *operands: _TensorflowTensor) -> _TensorflowTensor:
    import tensorflow as tf

    return tf.einsum(subscripts, *operands, optimize='greedy')

  @staticmethod
  def make_sparse_matrix(
      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
  ) -> _TensorflowTensor:
    import tensorflow as tf

    # Stack into an array with one (row, column) pair per stored value.
    indices = np.vstack((row_ind, col_ind)).T
    return tf.sparse.SparseTensor(indices, data, shape)
 588
 589
 590# pylint: disable=missing-function-docstring
 591
 592
class _TorchArraylib(_Arraylib[_TorchTensor]):
  """Torch implementation of the array abstraction.

  The `torch` package is imported lazily, when an instance is created or when a static
  method is called, rather than when this module is loaded.
  """

  def __init__(self, array: _NDArray) -> None:
    import torch

    self.torch = torch  # Keep the module handle for use by the instance methods.
    super().__init__(arraylib='torch', array=self.torch.as_tensor(array))

  @staticmethod
  def recognize(array: Any) -> bool:
    return type(array).__module__ == 'torch'

  def numpy(self) -> _NDArray:
    return self.array.numpy()

  def dtype(self) -> _DType:
    # Only the torch dtypes listed here are supported; any other raises KeyError.
    numpy_type = {
        self.torch.float32: np.float32,
        self.torch.float64: np.float64,
        self.torch.complex64: np.complex64,
        self.torch.complex128: np.complex128,
        self.torch.uint8: np.uint8,  # No uint16, uint32, uint64.
        self.torch.int16: np.int16,
        self.torch.int32: np.int32,
        self.torch.int64: np.int64,
    }[self.array.dtype]
    return np.dtype(numpy_type)

  def astype(self, dtype: _DTypeLike) -> _TorchTensor:
    # Inverse of the mapping in `dtype()`; an unlisted numpy dtype raises KeyError.
    torch_type = {
        np.float32: self.torch.float32,
        np.float64: self.torch.float64,
        np.complex64: self.torch.complex64,
        np.complex128: self.torch.complex128,
        np.uint8: self.torch.uint8,  # No uint16, uint32, uint64.
        np.int16: self.torch.int16,
        np.int32: self.torch.int32,
        np.int64: self.torch.int64,
    }[np.dtype(dtype).type]
    return self.array.type(torch_type)

  def possibly_make_contiguous(self) -> _TorchTensor:
    return self.array.contiguous()

  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _TorchTensor:
    array = self.array
    # Convert first (if requested), then clip in the requested dtype.
    array = _arr_astype(array, dtype) if dtype is not None else array
    return array.clip(low, high)

  def square(self) -> _TorchTensor:
    return self.array.square()

  def sqrt(self) -> _TorchTensor:
    return self.array.sqrt()

  def getitem(self, indices: Any) -> _TorchTensor:
    if not isinstance(indices, tuple):
      # A non-tuple index is converted to int64 before indexing.
      indices = indices.type(self.torch.int64)
    return self.array[indices]  # pylint: disable=unsubscriptable-object

  def where(self, if_true: Any, if_false: Any) -> _TorchTensor:
    condition = self.array
    return if_true.where(condition, if_false)

  def transpose(self, axes: Sequence[int]) -> _TorchTensor:
    return self.torch.permute(self.array, tuple(axes))

  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
    # Similar to `_NumpyArraylib`.  We access `array.stride()` instead of `array.strides`.
    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
    strides: Sequence[int] = self.array.stride()
    largest_stride_dim = max(range(len(src_shape)), key=lambda dim: strides[dim])

    def priority(dim: int) -> float:
      # Smaller value means processed earlier; the largest-stride dimension gets a bonus.
      scaling = dst_shape[dim] / src_shape[dim]
      return scaling * ((0.49 if scaling < 1.0 else 0.65) if dim == largest_stride_dim else 1.0)

    return sorted(range(len(src_shape)), key=priority)

  def premult_with_sparse(self, sparse: Any, num_threads: int | Literal['auto']) -> _TorchTensor:
    del num_threads  # Unused in this implementation.
    if np.issubdtype(_arr_dtype(self.array), np.complexfloating):
      # For a complex array, convert the sparse matrix to the array's dtype before multiplying.
      sparse = _arr_astype(sparse, _arr_dtype(self.array))
    return sparse @ self.array  # Calls torch.sparse.mm().

  @staticmethod
  def concatenate(arrays: Sequence[_TorchTensor], axis: int) -> _TorchTensor:
    import torch

    return torch.cat(tuple(arrays), axis)

  @staticmethod
  def einsum(subscripts: str, *operands: _TorchTensor) -> _TorchTensor:
    import torch

    operands = tuple(torch.as_tensor(operand) for operand in operands)
    # If any operand is complex, convert every operand to a complex dtype.
    if any(np.issubdtype(_arr_dtype(array), np.complexfloating) for array in operands):
      operands = tuple(
          _arr_astype(array, _complex_precision(_arr_dtype(array))) for array in operands
      )
    return torch.einsum(subscripts, *operands)

  @staticmethod
  def make_sparse_matrix(
      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
  ) -> _TorchTensor:
    import torch

    # Stack into a 2-row array: the row indices, then the column indices (not transposed).
    indices = np.vstack((row_ind, col_ind))
    return torch.sparse_coo_tensor(torch.as_tensor(indices), torch.as_tensor(data), shape)
    # .coalesce() is unnecessary because indices/data are already merged.
 705
 706
 707# pylint: enable=missing-function-docstring
 708
 709
class _JaxArraylib(_Arraylib[_JaxArray]):
  """Jax implementation of the array abstraction.

  The `jax` package is imported lazily, when an instance is created or when a static method
  is called, rather than when this module is loaded.
  """

  def __init__(self, array: _NDArray) -> None:
    import jax.numpy

    self.jnp = jax.numpy  # Keep the module handle for use by the instance methods.
    super().__init__(arraylib='jax', array=self.jnp.asarray(array))

  @staticmethod
  def recognize(array: Any) -> bool:
    # e.g., jaxlib.xla_extension.DeviceArray, jax.interpreters.ad.JVPTracer
    return type(array).__module__.startswith(('jaxlib.', 'jax.'))

  def numpy(self) -> _NDArray:
    # 2023-01-09: jax 0.3.17 "DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead."
    # Whereas array.to_py() and np.asarray(array) may return a non-writable np.ndarray,
    # np.array(array) always returns a writable array but the copy may be more costly.
    # return self.array.to_py()
    return np.asarray(self.array)

  def dtype(self) -> _DType:
    return np.dtype(self.array.dtype)

  def astype(self, dtype: _DTypeLike) -> _JaxArray:
    return self.array.astype(dtype)  # (copy=False is unavailable)

  def possibly_make_contiguous(self) -> _JaxArray:
    # Unlike the base-class default, this always returns a copy.
    return self.array.copy()

  def clip(self, low: Any, high: Any, dtype: _DTypeLike = None) -> _JaxArray:
    array = self.array
    if dtype is not None:
      array = array.astype(dtype)  # (copy=False is unavailable)
    return self.jnp.clip(array, low, high)

  def square(self) -> _JaxArray:
    return self.jnp.square(self.array)

  def sqrt(self) -> _JaxArray:
    return self.jnp.sqrt(self.array)

  def where(self, if_true: Any, if_false: Any) -> _JaxArray:
    condition = self.array
    return self.jnp.where(condition, if_true, if_false)

  def transpose(self, axes: Sequence[int]) -> _JaxArray:
    return self.jnp.transpose(self.array, tuple(axes))

  def best_dims_order_for_resize(self, dst_shape: tuple[int, ...]) -> list[int]:
    # Jax/XLA does not have strides.  Arrays are contiguous, almost always in C order; see
    # https://github.com/google/jax/discussions/7544#discussioncomment-1197038.
    # We use a heuristic similar to `_TensorflowArraylib`.
    src_shape: tuple[int, ...] = self.array.shape[: len(dst_shape)]
    dims = list(range(len(src_shape)))
    # Process dimension 1 before dimension 0 when dimension 0 is being upsampled.
    if len(dims) > 1 and dst_shape[0] / src_shape[0] > 1.0:
      dims[:2] = [1, 0]
    return dims

  def premult_with_sparse(
      self, sparse: jax.experimental.sparse.BCOO, num_threads: int | Literal['auto']
  ) -> _JaxArray:
    del num_threads  # Unused in this implementation.
    return sparse @ self.array  # Calls jax.bcoo_multiply_dense().

  @staticmethod
  def concatenate(arrays: Sequence[_JaxArray], axis: int) -> _JaxArray:
    import jax.numpy as jnp

    return jnp.concatenate(arrays, axis)

  @staticmethod
  def einsum(subscripts: str, *operands: _JaxArray) -> _JaxArray:
    import jax.numpy as jnp

    return jnp.einsum(subscripts, *operands, optimize='greedy')

  @staticmethod
  def make_sparse_matrix(
      data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int]
  ) -> _JaxArray:
    # https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
    import jax.experimental.sparse

    # Stack into an array with one (row, column) pair per stored value.  The flags passed
    # below declare the indices to be sorted and unique.
    indices = np.vstack((row_ind, col_ind)).T
    return jax.experimental.sparse.BCOO(
        (data, indices), shape=shape, indices_sorted=True, unique_indices=True
    )
 798
 799
# All array libraries that this module knows how to handle, keyed by name.  Each key is passed
# to `is_available()` as a package name, and each value is the class implementing the array
# operations for that library.  The subset that is actually installed is `_DICT_ARRAYLIBS`.
_CANDIDATE_ARRAYLIBS = {
    'numpy': _NumpyArraylib,
    'tensorflow': _TensorflowArraylib,
    'torch': _TorchArraylib,
    'jax': _JaxArraylib,
}
 806
 807
 808def is_available(arraylib: str) -> bool:
 809  """Return whether the array library (e.g. 'tensorflow') is available as an installed package."""
 810  return importlib.util.find_spec(arraylib) is not None  # Faster than trying to import it.
 811
 812
# Subset of `_CANDIDATE_ARRAYLIBS` restricted to the libraries for which `is_available()` is
# true; evaluated once, when this module is imported.
_DICT_ARRAYLIBS = {
    arraylib: cls for arraylib, cls in _CANDIDATE_ARRAYLIBS.items() if is_available(arraylib)
}

# Names of the available array libraries, in the order of `_CANDIDATE_ARRAYLIBS`.
ARRAYLIBS = list(_DICT_ARRAYLIBS)
 818"""Array libraries supported automatically in the resize and resampling operations.
 819
 820- The library is selected automatically based on the type of the `array` function parameter.
 821
 822- The class `_Arraylib` provides library-specific implementations of needed basic functions.
 823
 824- The `_arr_*()` functions dispatch the `_Arraylib` methods based on the array type.
 825"""
 826
 827
 828def _as_arr(array: _Array, /) -> _Arraylib[_Array]:
 829  """Return `array` wrapped as an `_Arraylib` for dispatch of functions."""
 830  for cls in _DICT_ARRAYLIBS.values():
 831    if cls.recognize(array):
 832      return cls(array)  # type: ignore[abstract]
 833  raise ValueError(f'{array} {type(array)} {type(array).__module__} unrecognized by {ARRAYLIBS}.')
 834
 835
 836def _arr_arraylib(array: _Array, /) -> str:
 837  """Return the name of the `Arraylib` representing `array`."""
 838  return _as_arr(array).arraylib
 839
 840
 841def _arr_numpy(array: _Array, /) -> _NDArray:
 842  """Return a `numpy` version of `array`."""
 843  return _as_arr(array).numpy()
 844
 845
 846def _arr_dtype(array: _Array, /) -> _DType:
 847  """Return the equivalent of `array.dtype` as a `numpy` `dtype`."""
 848  return _as_arr(array).dtype()
 849
 850
 851def _arr_astype(array: _Array, dtype: _DTypeLike, /) -> _Array:
 852  """Return the equivalent of `array.astype(dtype)` with `numpy` `dtype`."""
 853  return _as_arr(array).astype(dtype)
 854
 855
 856def _arr_reshape(array: _Array, shape: tuple[int, ...], /) -> _Array:
 857  """Return the equivalent of `array.reshape(shape)."""
 858  return _as_arr(array).reshape(shape)
 859
 860
 861def _arr_possibly_make_contiguous(array: _Array, /) -> _Array:
 862  """Return a contiguous copy of `array` or just `array` if already contiguous."""
 863  return _as_arr(array).possibly_make_contiguous()
 864
 865
 866def _arr_clip(array: _Array, low: _Array, high: _Array, /, dtype: _DTypeLike = None) -> _Array:
 867  """Return the equivalent of `array.clip(low, high, dtype)` with `numpy` `dtype`."""
 868  return _as_arr(array).clip(low, high, dtype)
 869
 870
 871def _arr_square(array: _Array, /) -> _Array:
 872  """Return the equivalent of `np.square(array)`."""
 873  return _as_arr(array).square()
 874
 875
 876def _arr_sqrt(array: _Array, /) -> _Array:
 877  """Return the equivalent of `np.sqrt(array)`."""
 878  return _as_arr(array).sqrt()
 879
 880
 881def _arr_getitem(array: _Array, indices: _Array, /) -> _Array:
 882  """Return the equivalent of `array[indices]`."""
 883  return _as_arr(array).getitem(indices)
 884
 885
 886def _arr_where(condition: _Array, if_true: _Array, if_false: _Array, /) -> _Array:
 887  """Return the equivalent of `np.where(condition, if_true, if_false)`."""
 888  return _as_arr(condition).where(if_true, if_false)
 889
 890
 891def _arr_transpose(array: _Array, axes: Sequence[int], /) -> _Array:
 892  """Return the equivalent of `np.transpose(array, axes)`."""
 893  return _as_arr(array).transpose(axes)
 894
 895
 896def _arr_best_dims_order_for_resize(array: _Array, dst_shape: tuple[int, ...], /) -> list[int]:
 897  """Return the best order in which to process dims for resizing `array` to `dst_shape`."""
 898  return _as_arr(array).best_dims_order_for_resize(dst_shape)
 899
 900
 901def _arr_matmul_sparse_dense(
 902    sparse: Any, dense: _Array, /, *, num_threads: int | Literal['auto'] = 'auto'
 903) -> _Array:
 904  """Return the multiplication of the `sparse` and `dense` matrices."""
 905  assert num_threads == 'auto' or num_threads >= 1
 906  return _as_arr(dense).premult_with_sparse(sparse, num_threads)
 907
 908
 909def _arr_concatenate(arrays: Sequence[_Array], axis: int, /) -> _Array:
 910  """Return the equivalent of `np.concatenate(arrays, axis)`."""
 911  arraylib = _arr_arraylib(arrays[0])
 912  return _DICT_ARRAYLIBS[arraylib].concatenate(arrays, axis)  # type: ignore
 913
 914
 915def _arr_einsum(subscripts: str, /, *operands: _Array) -> _Array:
 916  """Return the equivalent of `np.einsum(subscripts, *operands, optimize=True)`."""
 917  arraylib = _arr_arraylib(operands[0])
 918  return _DICT_ARRAYLIBS[arraylib].einsum(subscripts, *operands)  # type: ignore
 919
 920
 921def _arr_swapaxes(array: _Array, axis1: int, axis2: int, /) -> _Array:
 922  """Return the equivalent of `np.swapaxes(array, axis1, axis2)`."""
 923  ndim = len(array.shape)
 924  assert 0 <= axis1 < ndim and 0 <= axis2 < ndim, (axis1, axis2, ndim)
 925  axes = list(range(ndim))
 926  axes[axis1] = axis2
 927  axes[axis2] = axis1
 928  return _arr_transpose(array, axes)
 929
 930
 931def _arr_moveaxis(array: _Array, source: int, destination: int, /) -> _Array:
 932  """Return the equivalent of `np.moveaxis(array, source, destination)`."""
 933  ndim = len(array.shape)
 934  assert 0 <= source < ndim and 0 <= destination < ndim, (source, destination, ndim)
 935  axes = [n for n in range(ndim) if n != source]
 936  axes.insert(destination, source)
 937  return _arr_transpose(array, axes)
 938
 939
 940def _make_sparse_matrix(
 941    data: _NDArray, row_ind: _NDArray, col_ind: _NDArray, shape: tuple[int, int], arraylib: str, /
 942) -> Any:
 943  """Return the equivalent of `scipy.sparse.csr_matrix(data, (row_ind, col_ind), shape=shape)`.
 944  However, indices must be ordered and unique."""
 945  return _DICT_ARRAYLIBS[arraylib].make_sparse_matrix(data, row_ind, col_ind, shape)
 946
 947
 948def _make_array(array: _ArrayLike, arraylib: str, /) -> Any:
 949  """Return an array from the library `arraylib` initialized with the `numpy` `array`."""
 950  return _DICT_ARRAYLIBS[arraylib](np.asarray(array)).array  # type: ignore[abstract]
 951
 952
# Because np.ndarray supports strides, np.moveaxis() and np.transpose() are constant-time.
 954# However, ndarray.reshape() often creates a copy of the array if the data is non-contiguous,
 955# e.g. dim=1 in an RGB image.
 956#
 957# In contrast, tf.Tensor does not support strides, so tf.transpose() returns a new permuted
 958# tensor.  However, tf.reshape() is always efficient.
 959
 960
 961def _block_shape_with_min_size(
 962    shape: tuple[int, ...], min_size: int, compact: bool = True
 963) -> tuple[int, ...]:
 964  """Return shape of block (of size at least `min_size`) to subdivide shape."""
 965  if math.prod(shape) < min_size:
 966    raise ValueError(f'Shape {shape} smaller than min_size {min_size}.')
 967  if compact:
 968    root = int(math.ceil(min_size ** (1 / len(shape))))
 969    block_shape = np.minimum(shape, root)
 970    for dim in range(len(shape)):
 971      if block_shape[dim] == 2 and block_shape.prod() >= min_size * 2:
 972        block_shape[dim] = 1
 973    for dim in range(len(shape) - 1, -1, -1):
 974      if block_shape.prod() < min_size:
 975        block_shape[dim] = shape[dim]
 976  else:
 977    block_shape = np.ones_like(shape)
 978    for dim in range(len(shape) - 1, -1, -1):
 979      if block_shape.prod() < min_size:
 980        block_shape[dim] = min(shape[dim], math.ceil(min_size / block_shape.prod()))
 981  return tuple(block_shape)
 982
 983
 984def _array_split(array: _Array, axis: int, num_sections: int) -> list[Any]:
 985  """Split `array` into `num_sections` along `axis`."""
 986  assert 0 <= axis < len(array.shape)
 987  assert 1 <= num_sections <= array.shape[axis]
 988
 989  if 0:
 990    split = np.array_split(array, num_sections, axis=axis)  # Numpy-specific.
 991
 992  else:
 993    # Adapted from https://github.com/numpy/numpy/blob/main/numpy/lib/shape_base.py#L739-L792.
 994    num_total = array.shape[axis]
 995    num_each, num_extra = divmod(num_total, num_sections)
 996    section_sizes = [0] + num_extra * [num_each + 1] + (num_sections - num_extra) * [num_each]
 997    div_points = np.array(section_sizes).cumsum()
 998    split = []
 999    tmp = _arr_swapaxes(array, axis, 0)
1000    for i in range(num_sections):
1001      split.append(_arr_swapaxes(tmp[div_points[i] : div_points[i + 1]], axis, 0))
1002
1003  return split
1004
1005
def _split_array_into_blocks(array: _Array, block_shape: Sequence[int], start_axis: int = 0) -> Any:
  """Split `array` into nested lists of blocks of size at most `block_shape`.

  The nesting depth equals `len(block_shape) - start_axis`; dimensions of `array` beyond
  `len(block_shape)` are left whole.

  Raises:
    ValueError: If `block_shape` has more dimensions than `array`.
  """
  # See https://stackoverflow.com/a/50305924.  (If the block_shape is known to
  # exactly partition the array, see https://stackoverflow.com/a/16858283.)
  block_ndim = len(block_shape)
  if block_ndim > len(array.shape):
    raise ValueError(f'Block ndim {len(block_shape)} > array ndim {len(array.shape)}.')
  if start_axis == block_ndim:
    return array  # All block dimensions have been processed.

  num_sections = math.ceil(array.shape[start_axis] / block_shape[start_axis])
  pieces = _array_split(array, start_axis, num_sections)
  next_axis = start_axis + 1
  return [_split_array_into_blocks(piece, block_shape, next_axis) for piece in pieces]
1018
1019
1020def _map_function_over_blocks(blocks: Any, func: Callable[[Any], Any]) -> Any:
1021  """Apply `func` to each block in the nested lists of `blocks`."""
1022  if isinstance(blocks, list):
1023    return [_map_function_over_blocks(block, func) for block in blocks]
1024  return func(blocks)
1025
1026
def _merge_array_from_blocks(blocks: Any, axis: int = 0) -> Any:
  """Merge an array from the nested lists of array blocks in `blocks`.

  List nesting level k (counting from the outermost list) is concatenated along dimension
  `axis + k`.
  """
  # More general than np.block() because the blocks can have additional dims.
  if not isinstance(blocks, list):
    return blocks  # A leaf, i.e., an individual block.
  merged_children = [_merge_array_from_blocks(child, axis + 1) for child in blocks]
  return _arr_concatenate(merged_children, axis)
1034
1035
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract base class for grid-types such as `'dual'` and `'primal'`.

  In resampling operations, the grid-type may be specified separately as `src_gridtype` for the
  source domain and `dst_gridtype` for the destination domain.  Moreover, the grid-type may be
  specified per domain dimension.

  A grid-type defines, for one dimension, where the `size` samples are positioned within the
  unit domain and how out-of-range sample indices are mapped back onto valid samples.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Gridtype name."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the necessary minimum number of grid samples."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the domain size in units of inter-sample spacing, for a grid of `size` samples."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Return [0.0, 1.0] coordinates given [0, size - 1] indices.

    This is the inverse of `index_from_point`.
    """

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Return location x given coordinates [0.0, 1.0], where x == 0.0 is the first grid sample
    and x == size - 1.0 is the last grid sample.

    The returned location is generally fractional.
    """

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using boundary reflection."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Map integer sample indices to interior ones using reflect-clamp, i.e., reflection for
    negative indices and clamping to the last sample for large indices."""
1082
1083
class DualGridtype(Gridtype):
  """Grid whose samples lie at the centers of the cells of a uniform partition of the domain.

  In a unit-domain dimension holding N samples, sample i (0 <= i < N) sits at (i + 0.5) / N;
  with N = 4 the positions are [0.125, 0.375, 0.625, 0.875].
  """

  def __init__(self) -> None:
    super().__init__(name='dual')

  def min_size(self) -> int:
    return 1

  def size_in_samples(self, size: int, /) -> int:
    # Each of the `size` cells spans one inter-sample spacing.
    return size

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    cell_center = index + 0.5
    return cell_center / size

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    scaled = point * size
    return scaled - 0.5

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The reflected sequence 0, ..., size-1, size-1, ..., 0 repeats with period 2 * size.
    period = size * 2
    folded = np.mod(index, period)
    return np.where(folded < size, folded, period - 1 - folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    wrapped = np.mod(index, size)
    return wrapped

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    # Mirror negative indices (-1 -> 0, -2 -> 1, ...) and then limit to the last sample.
    mirrored = np.where(index < 0, -1 - index, index)
    return np.minimum(mirrored, size - 1)
1115
1116
class PrimalGridtype(Gridtype):
  """Grid whose samples lie at the vertices of the cells of a uniform partition of the domain.

  In a unit-domain dimension holding N samples, sample i (0 <= i < N) sits at i / (N - 1);
  with N = 4 the positions are [0, 1/3, 2/3, 1].
  """

  def __init__(self) -> None:
    super().__init__(name='primal')

  def min_size(self) -> int:
    return 2

  def size_in_samples(self, size: int, /) -> int:
    # There are `size - 1` intervals between `size` vertices.
    return size - 1

  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    num_intervals = size - 1
    return index / num_intervals

  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    num_intervals = size - 1
    return point * num_intervals

  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    # The reflected sequence 0, ..., size-1, size-2, ..., 1 repeats with period 2 * size - 2.
    period = size * 2 - 2
    folded = np.mod(index, period)
    return np.where(folded < size, folded, period - folded)

  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    # The period is `size - 1`, so the last sample is never referenced.
    wrapped = np.mod(index, size - 1)
    return wrapped

  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    # Mirror negative indices about sample 0 and then limit to the last sample.
    mirrored = np.abs(index)
    return np.minimum(mirrored, size - 1)
1148
1149
# Predefined grid types, keyed by their shortcut name; each value is a shared instance whose
# `name` equals its key.  Used by `_get_gridtype()` to resolve names.
_DICT_GRIDTYPES = {
    'dual': DualGridtype(),
    'primal': PrimalGridtype(),
}

GRIDTYPES = list(_DICT_GRIDTYPES)
1156r"""Shortcut names for the two predefined grid types (specified per dimension):
1157
1158| `gridtype` | `'dual'`<br/>`DualGridtype()`<br/>(default) | `'primal'`<br/>`PrimalGridtype()`<br/>&nbsp; |
1159| --- |:---:|:---:|
1160| Sample positions in 2D<br/>and in 1D at different resolutions | ![Dual](https://github.com/hhoppe/resampler/raw/main/media/dual_grid_small.png) | ![Primal](https://github.com/hhoppe/resampler/raw/main/media/primal_grid_small.png) |
1161| Nesting of samples across resolutions | The samples positions do *not* nest. | The *even* samples remain at coarser scale. |
1162| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
1163| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
1164| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |
1165
1166See the source code for extensibility.
1167"""
1168
1169
def _get_gridtype(gridtype: str | Gridtype) -> Gridtype:
  """Return `gridtype` itself if it is already a `Gridtype`, or else the predefined `Gridtype`
  whose name (an entry of `GRIDTYPES`) is `gridtype`."""
  if isinstance(gridtype, Gridtype):
    return gridtype
  return _DICT_GRIDTYPES[gridtype]
1173
1174
def _get_gridtypes(
    gridtype: str | Gridtype | None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None,
    src_ndim: int,
    dst_ndim: int,
) -> tuple[list[Gridtype], list[Gridtype]]:
  """Return per-dim source and destination grid types given all parameters.

  If none of the three grid-type parameters is given, `'dual'` is used for both domains.  A
  single grid type is broadcast to all dimensions of its domain.

  Raises:
    ValueError: If `gridtype` is specified together with `src_gridtype` or `dst_gridtype`.
  """
  nothing_specified = gridtype is None and src_gridtype is None and dst_gridtype is None
  if nothing_specified:
    gridtype = 'dual'
  if gridtype is not None:
    if src_gridtype is not None:
      raise ValueError('Cannot have both gridtype and src_gridtype.')
    if dst_gridtype is not None:
      raise ValueError('Cannot have both gridtype and dst_gridtype.')
    src_gridtype = dst_gridtype = gridtype

  def expand_per_dim(spec: Any, ndim: int) -> list[Gridtype]:
    return [_get_gridtype(g) for g in np.broadcast_to(np.array(spec), ndim)]

  src_gridtype2 = expand_per_dim(src_gridtype, src_ndim)
  dst_gridtype2 = expand_per_dim(dst_gridtype, dst_ndim)
  return src_gridtype2, dst_gridtype2
1194
1195
@dataclasses.dataclass(frozen=True)
class RemapCoordinates(abc.ABC):
  """Abstract base class for modifying the specified coordinates prior to evaluating the
  reconstruction kernels.

  Instances are callable: subclasses implement `__call__` to map an array of coordinates to
  the array of remapped coordinates.
  """

  @abc.abstractmethod
  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the remapped version of the coordinates in `point`."""
    ...
1204
1205
class NoRemapCoordinates(RemapCoordinates):
  """The coordinates are not remapped."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return `point` itself (the same object), unchanged."""
    return point
1211
1212
class MirrorRemapCoordinates(RemapCoordinates):
  """Reflect the coordinates across the domain boundaries until they fall within the unit
  interval.  The resulting function is continuous but not smooth across the boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    # The mirrored pattern has period 2: [0, 1) is kept as is and [1, 2) is flipped.
    folded = np.mod(point, 2.0)
    return np.where(folded >= 1.0, 2.0 - folded, folded)
1220
1221
class TileRemapCoordinates(RemapCoordinates):
  """The coordinates are mapped to the unit interval using a "modulo 1.0" operation.  The resulting
  function is generally discontinuous across the domain boundaries."""

  def __call__(self, point: _NDArray, /) -> _NDArray:
    """Return the coordinates `point` reduced modulo 1.0, i.e., within [0.0, 1.0)."""
    return np.mod(point, 1.0)
1228
1229
@dataclasses.dataclass(frozen=True)
class ExtendSamples(abc.ABC):
  """Abstract base class for replacing references to grid samples exterior to the unit domain by
  affine combinations of interior sample(s) and possibly the constant value (`cval`).

  Subclasses implement `__call__`; those that involve `cval` construct themselves with
  `uses_cval=True`.
  """

  uses_cval: bool = False
  """True if some exterior samples are defined in terms of `cval`, i.e., if the computed weight
  is non-affine."""

  @abc.abstractmethod
  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Detect references to exterior samples, i.e., entries of `index` that lie outside the
    interval [0, size), and update these indices (and possibly their associated weights) to
    reference only interior samples.  Return `new_index, new_weight`.

    Implementations may modify the `index` and `weight` arrays in place, and may return arrays
    with additional entries along the last dimension.
    """
1246
1247
class ReflectExtendSamples(ExtendSamples):
  """Replace each exterior sample by the interior sample found by reflecting across the domain
  boundaries."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through untouched.
    reflected_index = gridtype.reflect(index, size)
    return reflected_index, weight
1256
1257
class WrapExtendSamples(ExtendSamples):
  """Repeat the interior samples periodically.  For a `'primal'` grid, the last
  sample is ignored as its value is replaced by the first sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through untouched.
    wrapped_index = gridtype.wrap(index, size)
    return wrapped_index, weight
1267
1268
class ClampExtendSamples(ExtendSamples):
  """Replace each exterior sample by the nearest interior sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Independent of `gridtype`: indices are limited to [0, size - 1]; weights are unchanged.
    last = size - 1
    clamped_index = index.clip(0, last)
    return clamped_index, weight
1277
1278
class ReflectClampExtendSamples(ExtendSamples):
  """Extend the grid samples from [0, 1] into [-1, 0] using reflection, and let grid samples
  outside [-1, 1] take the value of the nearest sample."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Only the indices change; the weights are passed through untouched.
    remapped_index = gridtype.reflect_clamp(index, size)
    return remapped_index, weight
1288
1289
class BorderExtendSamples(ExtendSamples):
  """Give every exterior sample the constant value (`cval`)."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Exterior references get zero weight and are redirected to a valid sample (first or last).
    # Both `index` and `weight` are modified in place.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    return index, weight
1306
1307
class ValidExtendSamples(ExtendSamples):
  """Give all domain samples weight 1 and all outside samples weight 0.
  Compute a weighted reconstruction and divide by the reconstructed weight."""

  def __init__(self) -> None:
    super().__init__(uses_cval=True)

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    # Exterior references get zero weight and are redirected to a valid sample (first or last).
    # Both `index` and `weight` are modified in place.
    below = index < 0
    weight[below] = 0.0
    index[below] = 0
    above = index >= size
    weight[above] = 0.0
    index[above] = size - 1
    # Renormalize the remaining weights along the last dimension, except where they sum to zero.
    total = weight.sum(axis=-1)[..., None]
    np.divide(weight, total, out=weight, where=(total != 0.0))
    return index, weight
1328
1329
class LinearExtendSamples(ExtendSamples):
  """Linearly extrapolate beyond boundary samples."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `new_index, new_weight` in which exterior references are replaced by references
    to the two samples nearest the corresponding boundary.

    For `size >= 2`, the returned arrays have 4 more entries along the last dimension than the
    inputs, and the input `index` and `weight` arrays are also modified in place.
    """
    # With fewer than 2 samples there is no pair to extrapolate from; fall back to reflection.
    if size < 2:
      index = gridtype.reflect(index, size)
      return index, weight
    # For each boundary, define new columns in index and weight arrays to represent the last and
    # next-to-last samples.  When we later construct the sparse resize matrix, we will sum the
    # duplicate index entries.
    low = index < 0
    high = index >= size
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 4), weight.dtype)
    # Low boundary: the line through samples 0 and 1 has value (1 - x) * f[0] + x * f[1] at x.
    x = index
    w[..., -4] = ((1 - x) * weight).sum(where=low, axis=-1)
    w[..., -3] = ((x) * weight).sum(where=low, axis=-1)
    # High boundary: the same formula, with x measured backwards from the last sample.
    x = (size - 1) - index
    w[..., -2] = ((x) * weight).sum(where=high, axis=-1)
    w[..., -1] = ((1 - x) * weight).sum(where=high, axis=-1)
    # The exterior entries are now accounted for in the 4 new columns, so neutralize them.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-4] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-4] = index
    # Let matrix (including zero values) be banded.
    # (A new column with zero weight refers to the first index of its row instead.)
    new_index[..., -4:] = np.where(w[..., -4:] != 0.0, [0, 1, size - 2, size - 1], index[..., :1])
    index = new_index
    return index, weight
1363
1364
class QuadraticExtendSamples(ExtendSamples):
  """Quadratically extrapolate beyond boundary samples."""

  def __call__(
      self, index: _NDArray, weight: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Return `new_index, new_weight` in which exterior references are replaced by references
    to the three samples nearest the corresponding boundary.

    For `size >= 3`, the returned arrays have 6 more entries along the last dimension than the
    inputs, and the input `index` and `weight` arrays are also modified in place.
    """
    # [Keys 1981] suggests this as x[-1] = 3*x[0] - 3*x[1] + x[2], calling it "cubic precision",
    # but it seems just quadratic.
    # With fewer than 3 samples there is no triple to extrapolate from; fall back to reflection.
    if size < 3:
      index = gridtype.reflect(index, size)
      return index, weight
    low = index < 0
    high = index >= size
    w = np.empty((*weight.shape[:-1], weight.shape[-1] + 6), weight.dtype)
    # Low boundary: coefficients at x of the quadratic through samples 0, 1, and 2.
    x = index
    w[..., -6] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=low, axis=-1)
    w[..., -5] = (((-x + 2) * x) * weight).sum(where=low, axis=-1)
    w[..., -4] = (((0.5 * x - 0.5) * x) * weight).sum(where=low, axis=-1)
    # High boundary: the same coefficients, with x measured backwards from the last sample.
    x = (size - 1) - index
    w[..., -3] = (((0.5 * x - 0.5) * x) * weight).sum(where=high, axis=-1)
    w[..., -2] = (((-x + 2) * x) * weight).sum(where=high, axis=-1)
    w[..., -1] = (((0.5 * x - 1.5) * x + 1) * weight).sum(where=high, axis=-1)
    # The exterior entries are now accounted for in the 6 new columns, so neutralize them.
    weight[low] = 0.0
    index[low] = 0
    weight[high] = 0.0
    index[high] = size - 1
    w[..., :-6] = weight
    weight = w
    new_index = np.empty(w.shape, index.dtype)
    new_index[..., :-6] = index
    # Let matrix (including zero values) be banded.
    # (A new column with zero weight refers to the first index of its row instead.)
    new_index[..., -6:] = np.where(
        w[..., -6:] != 0.0, [0, 1, 2, size - 3, size - 2, size - 1], index[..., :1]
    )
    index = new_index
    return index, weight
1401
1402
@dataclasses.dataclass(frozen=True)
class OverrideExteriorValue:
  """Base class for setting the value outside some domain extent to a constant value (`cval`).

  Subclasses override `__call__`, typically computing a signed distance to the boundary of their
  extent and delegating to `override_using_signed_distance`.
  """

  boundary_antialiasing: bool = True
  """Antialias the pixel values adjacent to the boundary of the extent."""

  uses_cval: bool = False
  """Modify some weights to introduce references to the `cval` constant value."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For all `point` outside some extent, modify the weight to be zero."""

  def override_using_signed_distance(
      self, weight: _NDArray, point: _NDArray, signed_distance: _NDArray, /
  ) -> None:
    """Reduce, in place, the sample weights of points that are "outside" (positive
    `signed_distance`), which effectively assigns them the constant value `cval`."""
    if np.all(signed_distance <= 0.0):
      return  # All points are inside the extent, so there is nothing to override.

    use_antialiasing = self.boundary_antialiasing and min(point.shape) >= 2
    if not use_antialiasing:
      # Hard cutoff: discard the weights of every outside point.
      weight[signed_distance > 0.0, :] = 0.0
      return

    # For discontinuous coordinate mappings, we may need to somehow ignore
    # the large finite differences computed across the map discontinuities.
    gradient_norm = np.linalg.norm(np.atleast_2d(np.gradient(point)), axis=0)
    distance_in_samples = signed_distance / (gradient_norm + 1e-20)
    # Opacity is in linear space, which is correct if Gamma is set.
    opacity = (0.5 - distance_in_samples).clip(0.0, 1.0)
    weight *= opacity[..., None]
1437
1438
class NoOverrideExteriorValue(OverrideExteriorValue):
  """The function value is not overridden."""

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    """Do nothing, leaving `weight` unmodified."""
    pass
1444
1445
class UnitDomainOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the unit interval [0, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Signed distance to the boundaries at 0.0 and 1.0: negative inside, positive outside.
    distance_from_center = abs(point - 0.5)
    self.override_using_signed_distance(weight, point, distance_from_center - 0.5)
1455
1456
class PlusMinusOneOverrideExteriorValue(OverrideExteriorValue):
  """Replace values outside the interval [-1, 1] by the constant `cval`."""

  def __init__(self, **kwargs: Any) -> None:
    super().__init__(uses_cval=True, **kwargs)

  def __call__(self, weight: _NDArray, point: _NDArray, /) -> None:
    # Signed distance to the boundaries at -1.0 and 1.0: negative inside, positive outside.
    distance_from_origin = abs(point)
    self.override_using_signed_distance(weight, point, distance_from_origin - 1.0)
1466
1467
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Domain boundary rules.  These define the reconstruction over the source domain near and beyond
  the domain boundaries.  The rules may be specified separately for each domain dimension.

  A rule combines three components: `coord_remap`, `extend_samples`, and `override_value`.
  """

  name: str = ''
  """Boundary rule name."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Modify specified coordinates prior to evaluating the reconstruction kernels."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Define the value of each grid sample outside the unit domain as an affine combination of
  interior sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Set the value outside some extent to a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    samples_use_cval = self.extend_samples.uses_cval
    return samples_use_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return the coordinates `point` after applying `coord_remap`, for use in evaluating the
    filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples, then apply the exterior
    override to the resulting weights.  Return `new_index, new_weight`."""
    new_index, new_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(new_weight, point)
    return new_index, new_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight to zero to assign `cval`."""
    override = self.override_value
    override(weight, point)
1508
1509
# Predefined boundary rules, keyed by their shortcut name.  Each `Boundary` is constructed with
# a `name` equal to its key; components that are not listed keep the `Boundary` defaults.
_DICT_BOUNDARIES = {
    'reflect': Boundary('reflect', extend_samples=ReflectExtendSamples()),
    'wrap': Boundary('wrap', extend_samples=WrapExtendSamples()),
    'tile': Boundary(
        'tile', coord_remap=TileRemapCoordinates(), extend_samples=ReflectExtendSamples()
    ),
    'clamp': Boundary('clamp', extend_samples=ClampExtendSamples()),
    'border': Boundary('border', extend_samples=BorderExtendSamples()),
    'natural': Boundary(
        'natural',
        extend_samples=ValidExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear_constant': Boundary(
        'linear_constant',
        extend_samples=LinearExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'quadratic_constant': Boundary(
        'quadratic_constant',
        extend_samples=QuadraticExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'reflect_clamp': Boundary('reflect_clamp', extend_samples=ReflectClampExtendSamples()),
    'constant': Boundary(
        'constant',
        extend_samples=ReflectExtendSamples(),
        override_value=UnitDomainOverrideExteriorValue(),
    ),
    'linear': Boundary('linear', extend_samples=LinearExtendSamples()),
    'quadratic': Boundary('quadratic', extend_samples=QuadraticExtendSamples()),
}
1542
# The names appear in the insertion order of `_DICT_BOUNDARIES`.
BOUNDARIES = [*_DICT_BOUNDARIES]
1544"""Shortcut names for some predefined boundary rules (as defined by `_DICT_BOUNDARIES`):
1545
1546| name                   | a.k.a. / comments |
1547|------------------------|-------------------|
1548| `'reflect'`            | *reflected*, *symm*, *symmetric*, *mirror*, *grid-mirror* |
1549| `'wrap'`               | *periodic*, *repeat*, *grid-wrap* |
1550| `'tile'`               | like `'reflect'` within unit domain, then tile discontinuously |
1551| `'clamp'`              | *clamped*, *nearest*, *edge*, *clamp-to-edge*, repeat last sample |
1552| `'border'`             | *grid-constant*, use `cval` for samples outside unit domain |
1553| `'natural'`            | *renormalize* using only interior samples, use `cval` outside domain |
1554| `'reflect_clamp'`      | *mirror-clamp-to-edge* |
1555| `'constant'`           | like `'reflect'` but replace by `cval` outside unit domain |
1556| `'linear'`             | extrapolate from 2 last samples |
1557| `'quadratic'`          | extrapolate from 3 last samples |
1558| `'linear_constant'`    | like `'linear'` but replace by `cval` outside unit domain |
1559| `'quadratic_constant'` | like `'quadratic'` but replace by `cval` outside unit domain |
1560
1561These boundary rules may be specified per dimension.  See the source code for extensibility
1562using the classes `RemapCoordinates`, `ExtendSamples`, and `OverrideExteriorValue`.
1563
1564**Boundary rules illustrated in 1D:**
1565
1566<center>
1567<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_1D.png" width="100%"/>
1568</center>
1569
1570**Boundary rules illustrated in 2D:**
1571
1572<center>
1573<img src="https://github.com/hhoppe/resampler/raw/main/media/boundary_rules_in_2D.png" width="100%"/>
1574</center>
1575"""
1576
_OFTUSED_BOUNDARIES = [
    'reflect',
    'wrap',
    'tile',
    'clamp',
    'border',
    'natural',
    'linear_constant',
    'quadratic_constant',
]
"""A useful subset of `BOUNDARIES` for visualization in figures."""
1581
1582
def _get_boundary(boundary: str | Boundary, /) -> Boundary:
  """Return `boundary` itself if it is a `Boundary`, else look up its name in `BOUNDARIES`."""
  if isinstance(boundary, Boundary):
    return boundary
  return _DICT_BOUNDARIES[boundary]
1586
1587
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract base class for filter kernel functions.

  Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  Subclasses must implement `__call__` to evaluate the kernel.  The class is a frozen
  dataclass, so the fields declared below cannot be reassigned after construction.

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Filter kernel name."""

  radius: float
  """Max absolute value of x for which self(x) is nonzero."""

  interpolating: bool = True
  """True if self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """True if the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """True if the convolution of the kernel with a Dirac comb reproduces the
  unity function."""

  unit_integral: bool = True
  """True if the integral of the kernel function is 1."""

  requires_digital_filter: bool = False
  """True if the filter needs a pre/post digital filter for interpolation."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Return evaluation of filter kernel at locations x.

    Args:
      x: Locations at which to evaluate the kernel.

    Returns:
      The kernel values at `x`.
    """
1627
1628
class ImpulseFilter(Filter):
  """Dirac delta (impulse) kernel; see https://en.wikipedia.org/wiki/Dirac_delta_function.

  The kernel cannot be evaluated directly: calling the filter always raises.
  """

  def __init__(self) -> None:
    super().__init__(
        name='impulse',
        radius=1e-20,
        continuous=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    message = 'The Impulse is infinitely narrow, so cannot be directly evaluated.'
    raise AssertionError(message)
1637
1638
class BoxFilter(Filter):
  """Box (rectangular) kernel; see https://en.wikipedia.org/wiki/Box_function.

  The kernel function has value 1.0 over the half-open interval [-.5, .5).
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    use_asymmetric = True
    if not use_asymmetric:
      # Symmetric alternative: value 0.5 exactly at both edges of the support.
      distance = np.abs(x)
      return np.where(distance < 0.5, 1.0, np.where(distance == 0.5, 0.5, 0.0))
    # Half-open support [-0.5, 0.5), so a sample at an edge is counted exactly once.
    values = np.asarray(x)
    inside = (-0.5 <= values) & (values < 0.5)
    return np.where(inside, 1.0, 0.0)
1655
1656
class TrapezoidFilter(Filter):
  """Filter for antialiased "area-based" filtering.

  Args:
    radius: Specifies the support [-radius, radius] of the filter, where 0.5 < radius <= 1.0.
      The special case `radius = None` is a placeholder that indicates that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  This filter resembles the BoxFilter but has linearly sloped sides.  Its value is 1.0 for
  abs(x) <= 1.0 - radius, then falls linearly to 0.0 over 1.0 - radius <= abs(x) <= radius,
  always passing through value 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is None:
      # Placeholder instance; it is named 'trapezoid' and stores a radius of 0.0.
      super().__init__(name='trapezoid', radius=0.0)
    elif 0.5 < radius <= 1.0:
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Line through (0.5, 0.5) and (radius, 0.0), clipped to the range [0.0, 1.0].
    intercept = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (intercept - slope * x).clip(0.0, 1.0)
1683
1684
class TriangleFilter(Filter):
  """Triangle kernel; see https://en.wikipedia.org/wiki/Triangle_function.

  It is also called the hat or tent function, and yields piecewise-linear
  (or bilinear, or trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    return (1.0 - distance).clip(0.0, 1.0)
1697
1698
class CubicFilter(Filter):
  """Family of cubic filters parameterized by two scalar parameters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: optional filter name; if `None`, the name is derived from `b` and `c`.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b = b
    self.c = c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    b, c = self.b, self.c
    # Coefficients of the inner piece (|x| < 1), which has no linear term.
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Coefficients of the outer piece (1 <= |x| < 2).
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # Horner evaluation.  (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
    # twice as slow; see also https://stackoverflow.com/questions/24065904)
    inner = ((f3 * x + f2) * x) * x + f0
    outer = ((g3 * x + g2) * x + g1) * x + g0
    return np.where(x < 1.0, inner, np.where(x < 2.0, outer, 0.0))
1736
1737
class CatmullRomFilter(CubicFilter):
  """Cubic filter with cubic precision (b=0, c=0.5), named 'cubic'.  Also known as Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    super().__init__(
        b=0,
        c=0.5,
        name='cubic',
    )
1752
1753
class MitchellFilter(CubicFilter):
  """Cubic filter with b = c = 1/3; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    super().__init__(
        b=1 / 3,
        c=1 / 3,
        name='mitchell',
    )
1763
1764
class SharpCubicFilter(CubicFilter):
  """Cubic filter (b=0, c=0.75) that is sharper than the Catmull-Rom filter.

  Used by some tools including OpenCV and Photoshop.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    super().__init__(
        b=0,
        c=0.75,
        name='sharpcubic',
    )
1776
1777
class LanczosFilter(Filter):
  """High-quality filter: sinc function modulated by a sinc window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
      window = _sinc(distance / radius)  # Zero-phase function w_0(x).
      kernel = _sinc(distance) * window
      return np.where(distance < radius, kernel, 0.0)

    # The kernel (possibly wrapped by the sampling cache) is what `__call__` forwards to.
    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
1805
1806
class GeneralizedHammingFilter(Filter):
  """Sinc function modulated by a Hamming window.

  Args:
    radius: Specifies the support window [-radius, radius] over which the filter is nonzero.
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    assert 0.0 < a0 < 1.0
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * distance)
    kernel = _sinc(distance) * window
    return np.where(distance < self.radius, kernel, 0.0)
1838
1839
class KaiserFilter(Filter):
  """Sinc function modulated by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level.
    sampled: If True, use a discretized approximation for improved speed.
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    assert beta >= 0.0
    super().__init__(
        name=f'kaiser_{radius}_{beta}',
        radius=radius,
        partition_of_unity=False,
        unit_integral=False,
    )
    # The sampling range is rounded up to an integer because `radius` may be fractional.
    extent = math.ceil(radius)

    @_cache_sampled_1d_function(xmin=-extent, xmax=extent, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      x = np.abs(x)
      # Clipping keeps the square-root argument non-negative for |x| > radius.
      inside = (1.0 - np.square(x / radius)).clip(0.0, 1.0)
      window = np.i0(beta * np.sqrt(inside)) / np.i0(beta)
      return np.where(x <= radius + 1e-6, _sinc(x) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
1873
1874
class BsplineFilter(Filter):
  """B-spline of a non-negative degree.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0; this
      piecewise-constant kernel is marked as not continuous.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  Raises:
    ValueError: If `degree` is negative.

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    interpolating = degree <= 1
    super().__init__(
        name=f'bspline{degree}',
        radius=radius,
        interpolating=interpolating,
        # The degree-0 (box-like) kernel jumps at +-0.5, so it is not C^0; this matches the
        # `continuous=degree >= 1` setting of `CardinalBsplineFilter`.
        continuous=degree >= 1,
    )
    # Basis element over the uniform knots 0, 1, ..., degree + 1; its support is [0, 2 * radius].
    t = list(range(degree + 2))
    self._bspline = scipy.interpolate.BSpline.basis_element(t)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    # Shift by `radius` to center the basis element at zero.
    return np.where(x < self.radius, self._bspline(x + self.radius), 0.0)
1900
1901
class CardinalBsplineFilter(Filter):
  """Interpolating B-spline, achieved with aid of digital pre or post filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    self.degree = degree
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    needs_digital_filter = degree >= 2
    is_continuous = degree >= 1
    super().__init__(
        name=f'cardinal{degree}',
        radius=radius,
        requires_digital_filter=needs_digital_filter,
        continuous=is_continuous,
    )
    # Basis element over the uniform knots 0, 1, ..., degree + 1.
    knots = list(range(degree + 2))
    bspline = scipy.interpolate.BSpline.basis_element(knots)

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Shift by `radius` to center the basis element at zero.
      return np.where(distance < radius, bspline(distance + radius), 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)
1937
1938
class OmomsFilter(Filter):
  """OMOMS interpolating filter, with aid of digital pre or post filter.

  Args:
    degree: The polynomial degree of the filter segments; only 3 and 5 are supported.

  Optimal MOMS (maximal-order-minimal-support) function; see [Blu and Thevenaz, MOMS: Maximal-order
  interpolation of minimal support, 2001].
  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf
  """

  def __init__(self, *, degree: int) -> None:
    if degree not in (3, 5):
      raise ValueError(f'Degree {degree} not supported.')
    super().__init__(
        name=f'omoms{degree}',
        radius=(degree + 1) / 2,
        requires_digital_filter=True,
    )
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    degree = self.degree
    # Each piece is a polynomial in |x| written in Horner form, valid on one unit interval.
    if degree == 3:
      v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
      v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
      return np.where(x < 1.0, v01, np.where(x < 2.0, v12, 0.0))
    if degree == 5:
      v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
      v12 = (
          (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
      ) * x + 839 / 2640
      v23 = (
          (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
      ) * x + 5707 / 2640
      return np.where(x < 1.0, v01, np.where(x < 2.0, v12, np.where(x < 3.0, v23, 0.0)))
    raise ValueError(self.degree)
1974
1975
class GaussianFilter(Filter):
  r"""Gaussian kernel; see https://en.wikipedia.org/wiki/Gaussian_function.

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  https://cadxfem.org/inf/ResamplingFilters.pdf
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    # The support is truncated at 8 standard deviations (rounded up), which is sufficiently large.
    truncation_radius = np.ceil(8.0 * standard_deviation)
    super().__init__(
        name=f'gaussian_{standard_deviation:.3f}',
        radius=truncation_radius,
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    sigma = self.standard_deviation
    density = np.exp(np.square(x / sigma) / -2.0) / (np.sqrt(math.tau) * sigma)
    return np.where(x < self.radius, density, 0.0)
2006
2007
class NarrowBoxFilter(Filter):
  """Compact footprint, used for visualization of grid sample location.

  Args:
    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        continuous=False,
        unit_integral=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    values = np.asarray(x)
    half_width = self.radius
    # Value 1.0 on the half-open interval [-radius, radius), and 0.0 elsewhere.
    inside = (-half_width <= values) & (values < half_width)
    return np.where(inside, 1.0, 0.0)
2030
2031
# Shortcut name (a key of `_DICT_FILTERS`) of the default filter; presumably used where the
# caller does not specify a filter -- confirm against the call sites.
_DEFAULT_FILTER = 'lanczos3'

# Predefined filter instances, keyed by shortcut name.  The insertion order matters: `FILTERS`
# is built from the keys that precede 'hamming1'.
_DICT_FILTERS = {
    'impulse': ImpulseFilter(),
    'box': BoxFilter(),
    'trapezoid': TrapezoidFilter(),
    'triangle': TriangleFilter(),
    'cubic': CatmullRomFilter(),
    'sharpcubic': SharpCubicFilter(),
    'lanczos3': LanczosFilter(radius=3),
    'lanczos5': LanczosFilter(radius=5),
    'lanczos10': LanczosFilter(radius=10),
    'cardinal3': CardinalBsplineFilter(degree=3),
    'cardinal5': CardinalBsplineFilter(degree=5),
    'omoms3': OmomsFilter(degree=3),
    'omoms5': OmomsFilter(degree=5),
    'hamming3': GeneralizedHammingFilter(radius=3, a0=25 / 46),
    'kaiser3': KaiserFilter(radius=3.0, beta=7.12),
    'gaussian': GaussianFilter(),
    'bspline3': BsplineFilter(degree=3),
    'mitchell': MitchellFilter(),
    'narrowbox': NarrowBoxFilter(),
    # Not in FILTERS:
    'hamming1': GeneralizedHammingFilter(radius=1, a0=0.54),
    'hann3': GeneralizedHammingFilter(radius=3, a0=0.5),
    'lanczos4': LanczosFilter(radius=4),
}
2059
# Keep the keys of `_DICT_FILTERS` up to, but excluding, the first unlisted entry 'hamming1'.
FILTERS = [*itertools.takewhile(lambda name: name != 'hamming1', _DICT_FILTERS)]
2061r"""Shortcut names for some predefined filter kernels (specified per dimension).
2062The names expand to:
2063
2064| name           | `Filter`                      | a.k.a. / comments |
2065|----------------|-------------------------------|-------------------|
2066| `'impulse'`    | `ImpulseFilter()`             | *nearest* |
2067| `'box'`        | `BoxFilter()`                 | non-antialiased box, e.g. ImageMagick |
2068| `'trapezoid'`  | `TrapezoidFilter()`           | *area* antialiasing, e.g. `cv.INTER_AREA` |
2069| `'triangle'`   | `TriangleFilter()`            | *linear*  (*bilinear* in 2D), spline `order=1` |
2070| `'cubic'`      | `CatmullRomFilter()`          | *catmullrom*, *keys*, *bicubic* |
2071| `'sharpcubic'` | `SharpCubicFilter()`          | `cv.INTER_CUBIC`, `torch 'bicubic'` |
2072| `'lanczos3'`   | `LanczosFilter`(radius=3)     | support window [-3, 3] |
2073| `'lanczos5'`   | `LanczosFilter`(radius=5)     | [-5, 5] |
2074| `'lanczos10'`  | `LanczosFilter`(radius=10)    | [-10, 10] |
2075| `'cardinal3'`  | `CardinalBsplineFilter`(degree=3) | *spline interpolation*, `order=3`, *GF* |
2076| `'cardinal5'`  | `CardinalBsplineFilter`(degree=5) | *spline interpolation*, `order=5`, *GF* |
2077| `'omoms3'`     | `OmomsFilter`(degree=3)       | non-$C^1$, [-3, 3], *GF* |
2078| `'omoms5'`     | `OmomsFilter`(degree=5)       | non-$C^1$, [-5, 5], *GF* |
2079| `'hamming3'`   | `GeneralizedHammingFilter`(...) | (radius=3, a0=25/46) |
2080| `'kaiser3'`    | `KaiserFilter`(radius=3.0, beta=7.12) | |
2081| `'gaussian'`   | `GaussianFilter()`            | non-interpolating, default $\sigma=1.25/3$ |
2082| `'bspline3'`   | `BsplineFilter`(degree=3)     | non-interpolating |
2083| `'mitchell'`   | `MitchellFilter()`            | *mitchellcubic* |
2084| `'narrowbox'`  | `NarrowBoxFilter()`           | for visualization of sample positions |
2085
2086The comment label *GF* denotes a [generalized filter](https://hhoppe.com/proj/filtering/), formed
2087as the composition of a finitely supported kernel and a discrete inverse convolution.
2088
2089**Some example filter kernels:**
2090
2091<center>
2092<img src="https://github.com/hhoppe/resampler/raw/main/media/filter_summary.png" width="100%"/>
2093</center>
2094
2095<br/>A more extensive set of filters is presented [here](#plots_of_filters) in the
2096[notebook](https://colab.research.google.com/github/hhoppe/resampler/blob/main/resampler_notebook.ipynb),
2097together with visualizations and analyses of the filter properties.
2098See the source code for extensibility.
2099"""
2100
2101
def _get_filter(filter: str | Filter, /) -> Filter:
  """Return `filter` itself if it is a `Filter`, else look up its name in `_DICT_FILTERS`."""
  if isinstance(filter, Filter):
    return filter
  return _DICT_FILTERS[filter]
2105
2106
def _to_float_01(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Scale uint to the range [0.0, 1.0], and clip float to [0.0, 1.0].

  Args:
    array: Input whose dtype is uint8, uint16, uint32, or a floating type.
    dtype: Floating-point dtype of the result.
  """
  source_dtype = _arr_dtype(array)
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)
  if source_dtype.type in (np.uint8, np.uint16, np.uint32):
    max_value = np.iinfo(source_dtype).max
    if _arr_arraylib(array) == 'numpy':
      assert isinstance(array, np.ndarray)  # Help mypy.
      # Multiply by the reciprocal, producing the output dtype directly.
      return np.multiply(array, 1 / max_value, dtype=dtype)
    return _arr_astype(array, dtype) / max_value
  assert np.issubdtype(source_dtype, np.floating)
  return _arr_clip(array, 0.0, 1.0, dtype)
2121
2122
def _from_float(array: _Array, /, dtype: _DTypeLike) -> _Array:
  """Convert a float in range [0.0, 1.0] to uint or float type.

  For uint8, uint16, and uint32 destinations, values are scaled by the maximum of `dtype` and
  offset by 0.5 before the cast.
  """
  assert np.issubdtype(_arr_dtype(array), np.floating)
  dtype = np.dtype(dtype)
  if dtype.type in (np.uint8, np.uint16, np.uint32):
    # The scale factor is a float64 for uint32 and a float32 for the narrower types.
    scale_type = np.float64 if dtype.type == np.uint32 else np.float32
    scaled = array * scale_type(np.iinfo(dtype).max) + 0.5
    return cast(_Array, _arr_astype(scaled, dtype))
  assert np.issubdtype(dtype, np.floating)
  return _arr_astype(array, dtype)
2135
2136
@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract base class for transfer functions on sample values.

  Image/video content is often stored using a color component transfer function.
  See https://en.wikipedia.org/wiki/Gamma_correction.

  Converts between integer types and [0.0, 1.0] internal value range.

  Subclasses must implement both `decode` and `encode`.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Decode source sample values into floating-point, possibly nonlinearly.

    Uint source values are mapped to the range [0.0, 1.0].

    Args:
      array: Source sample values.
      dtype: Floating-point type of the decoded result (default `np.float32`).
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Encode float signal into destination samples, possibly nonlinearly.

    Uint destination values are mapped from the range [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.

    Args:
      array: Float signal to encode.
      dtype: Type of the destination samples.
    """
2166
2167
class IdentityGamma(Gamma):
  """Identity component transfer function."""

  def __init__(self) -> None:
    super().__init__('identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.inexact)
    source_is_uint = np.issubdtype(_arr_dtype(array), np.unsignedinteger)
    # Only unsigned-integer sources are rescaled to [0.0, 1.0]; others are merely cast.
    return _to_float_01(array, dtype) if source_is_uint else _arr_astype(array, dtype)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.number)
    if np.issubdtype(dtype, np.unsignedinteger):
      clipped = _arr_clip(array, 0.0, 1.0)
      return _from_float(clipped, dtype)
    if np.issubdtype(dtype, np.integer):
      # Signed integer destination: offset by 0.5 before the cast.
      array = array + 0.5
    return _arr_astype(array, dtype)
2189
2190
class PowerGamma(Gamma):
  """Gamma correction using a power function.

  Args:
    power: Exponent applied when decoding; its reciprocal is applied when encoding.
  """

  def __init__(self, power: float) -> None:
    super().__init__(name=f'power_{power}')
    self.power = power

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    is_square = self.power == 2
    if _arr_dtype(array) == np.uint8 and not is_square:
      # Decode all 256 possible uint8 values once, then index that table with the input.
      arraylib = _arr_arraylib(array)
      table_values = self.decode(np.arange(256, dtype=dtype) / 255)
      decode_table = _make_array(table_values, arraylib)
      return _arr_getitem(decode_table, array)

    array = _to_float_01(array, dtype)
    if is_square:
      return _arr_square(array)
    return array**self.power

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    clipped = _arr_clip(array, 0.0, 1.0)
    if self.power == 2:
      encoded = _arr_sqrt(clipped)
    else:
      encoded = clipped ** (1.0 / self.power)
    return _from_float(encoded, dtype)
2213
2214
class SrgbGamma(Gamma):
  """Gamma correction using sRGB; see https://en.wikipedia.org/wiki/SRGB."""

  def __init__(self) -> None:
    super().__init__(name='srgb')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.floating)
    if _arr_dtype(array) == np.uint8:
      # Decode all 256 possible uint8 values once, then index that table with the input.
      arraylib = _arr_arraylib(array)
      table_values = self.decode(np.arange(256, dtype=dtype) / 255)
      decode_table = _make_array(table_values, arraylib)
      return _arr_getitem(decode_table, array)

    x = _to_float_01(array, dtype)
    power_segment = ((x + 0.055) / 1.055) ** 2.4
    linear_segment = x / 12.92
    return _arr_where(x > 0.04045, power_segment, linear_segment)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    x = _arr_clip(array, 0.0, 1.0)
    # Unfortunately, exponentiation is slow, and np.digitize() is even slower.
    # pytype: disable=wrong-arg-types
    power_segment = x ** (1.0 / 2.4) * 1.055 - (0.055 - 1e-17)
    linear_segment = x * 12.92
    x = _arr_where(x > 0.0031308, power_segment, linear_segment)
    # pytype: enable=wrong-arg-types
    return _from_float(x, dtype)
2239
2240
# Predefined component transfer functions, keyed by shortcut name.
_DICT_GAMMAS = {
    'identity': IdentityGamma(),
    'power2': PowerGamma(2.0),
    'power22': PowerGamma(2.2),
    'srgb': SrgbGamma(),
}
2247
# The names appear in the insertion order of `_DICT_GAMMAS`.
GAMMAS = [*_DICT_GAMMAS]
2249r"""Shortcut names for some predefined gamma-correction schemes:
2250
2251| name | `Gamma` | Decoding function<br/> (linear space from stored value) | Encoding function<br/> (stored value from linear space) |
2252|---|---|:---:|:---:|
2253| `'identity'` | `IdentityGamma()` | $l = e$ | $e = l$ |
2254| `'power2'` | `PowerGamma`(2.0) | $l = e^{2.0}$ | $e = l^{1/2.0}$ |
2255| `'power22'` | `PowerGamma`(2.2) | $l = e^{2.2}$ | $e = l^{1/2.2}$ |
2256| `'srgb'` | `SrgbGamma()` | $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ | $e = l^{1/2.4} * 1.055 - 0.055$ |
2257
2258See the source code for extensibility.
2259"""
2260
2261
def _get_gamma(gamma: str | Gamma, /) -> Gamma:
  """Return `gamma` itself if it is a `Gamma`, else look up its name in `GAMMAS`."""
  if isinstance(gamma, Gamma):
    return gamma
  return _DICT_GAMMAS[gamma]
2265
2266
def _get_src_dst_gamma(
    gamma: str | Gamma | None,
    src_gamma: str | Gamma | None,
    dst_gamma: str | Gamma | None,
    src_dtype: _DType,
    dst_dtype: _DType,
) -> tuple[Gamma, Gamma]:
  """Resolve the source and destination transfer functions.

  If none of `gamma`, `src_gamma`, and `dst_gamma` is given, the default is 'power2' when both
  dtypes are unsigned integers and 'identity' when neither is.  A given `gamma` is used for both
  source and destination.

  Raises:
    ValueError: If no gamma is given and exactly one of the dtypes is an unsigned integer, or if
      `gamma` is combined with `src_gamma` or `dst_gamma`.
  """
  none_given = gamma is None and src_gamma is None and dst_gamma is None
  if none_given:
    src_uint = np.issubdtype(src_dtype, np.unsignedinteger)
    dst_uint = np.issubdtype(dst_dtype, np.unsignedinteger)
    if src_uint != dst_uint:
      raise ValueError(f'Gamma must be specified because {src_dtype=} and {dst_dtype=}.')
    # The default might ideally be 'srgb' but that conversion is costlier.
    gamma = 'power2' if src_uint else 'identity'

  if gamma is not None:
    if src_gamma is not None:
      raise ValueError('Cannot specify both gamma and src_gamma.')
    if dst_gamma is not None:
      raise ValueError('Cannot specify both gamma and dst_gamma.')
    src_gamma = gamma
    dst_gamma = gamma

  assert src_gamma and dst_gamma
  return _get_gamma(src_gamma), _get_gamma(dst_gamma)
2294
2295
def _create_resize_matrix(
    src_size: int,
    dst_size: int,
    src_gridtype: Gridtype,
    dst_gridtype: Gridtype,
    boundary: Boundary,
    filter: Filter,
    prefilter: Filter | None = None,
    scale: float = 1.0,
    translate: float = 0.0,
    dtype: _DTypeLike = np.float64,
    arraylib: str = 'numpy',
) -> tuple[_Array, _Array | None]:
  """Build the sparse 1D resampling matrix mapping `src_size` samples to `dst_size` samples.

  Row `i` of the matrix holds the weights with which the source samples are combined to form
  destination sample `i`, after the boundary rule has been applied.  When the rule involves the
  constant value `cval`, the weights of a row need not sum to one; the shortfall is reported
  separately as `cval_weight`.

  Args:
    src_size: The number of samples within the source 1D domain.
    dst_size: The number of samples within the destination 1D domain.
    src_gridtype: Placement of the samples in the source domain grid.
    dst_gridtype: Placement of the output samples in the destination domain grid.
    boundary: The reconstruction boundary rule.
    filter: The reconstruction kernel (used for upsampling/magnification).
    prefilter: The prefilter kernel (used for downsampling/minification).  If it is `None`,
      `filter` is used.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    translate: Offset applied when mapping the scaled source domain onto the destination domain.
    dtype: Precision of computed resize matrix entries.  Must be a floating type.
    arraylib: Representation of output.  Must be an element of `ARRAYLIBS`.

  Returns:
    sparse_matrix: Matrix whose rows express output sample values as affine combinations of the
      source sample values.
    cval_weight: Optional vector expressing the additional contribution of the constant value
      (`cval`) to the combination in each row of `sparse_matrix`.  It equals one minus the sum of
      the weights in each matrix row.

  Raises:
    ValueError: If `src_size` or `dst_size` is below the minimum size of its gridtype.
  """
  if src_size < src_gridtype.min_size():
    raise ValueError(f'Source size {src_size} is too small for resize.')
  if dst_size < dst_gridtype.min_size():
    raise ValueError(f'Destination size {dst_size} is too small for resize.')
  if prefilter is None:
    prefilter = filter
  dtype = np.dtype(dtype)
  assert np.issubdtype(dtype, np.floating)

  # Ratio of destination to source sample density; below one, samples are being discarded.
  dst_extent = dst_gridtype.size_in_samples(dst_size)
  src_extent = src_gridtype.size_in_samples(src_size)
  scaling = dst_extent / src_extent * scale
  is_minification = scaling < 1.0
  kernel = prefilter if is_minification else filter
  if kernel.name == 'trapezoid':
    # The trapezoid kernel is rebuilt with a radius that depends on the scaling.
    kernel = TrapezoidFilter(radius=0.5 + 0.5 * min(scaling, 1.0 / scaling))
  radius = kernel.radius
  # Under minification the kernel is stretched by `1 / scaling`, so it spans more source samples.
  support = radius * 2 / scaling if is_minification else radius * 2
  num_samples = int(np.ceil(support))

  # Destination sample locations in unit domain [0, 1], then mapped back into the source domain.
  dst_position = dst_gridtype.point_from_index(np.arange(dst_size, dtype=dtype), dst_size)
  src_position = boundary.preprocess_coordinates((dst_position - translate) / scale)

  # Fractional source index of each destination sample, and the first source sample of its window.
  src_float_index = src_gridtype.index_from_point(src_position, src_size)
  rounding_offset = 0.5 if num_samples % 2 == 1 else 0.0
  src_first_index = (
      np.floor(src_float_index + rounding_offset).astype(np.int32) - (num_samples - 1) // 2
  )
  window = np.arange(num_samples, dtype=np.int32)
  src_index = src_first_index[:, None] + window  # (dst_size, num_samples)

  # Evaluate the kernel at the offset between each destination sample and its source samples.
  if kernel.name == 'impulse':
    weight = np.ones(src_index.shape, dtype)
  else:
    offset = src_float_index[:, None] - src_index.astype(dtype)
    if is_minification:
      weight = kernel(offset * scaling) * scaling
    else:  # Either same size or magnification.
      weight = kernel(offset)
  weight = weight.astype(dtype, copy=False)

  if kernel.name != 'narrowbox' and (is_minification or not kernel.partition_of_unity):
    weight = weight / weight.sum(axis=-1, keepdims=True)

  src_index, weight = boundary.apply(src_index, weight, src_position, src_size, src_gridtype)

  # Flatten to (row * src_size + column) keys, drop zero weights, and sum the weights of entries
  # that share a key by multiplying with a 0/1 "merge" matrix.
  row_of_entry = np.indices(src_index.shape)[0]
  flat_key = (src_index + row_of_entry * src_size).ravel()
  flat_weight = weight.ravel()
  keep = flat_weight != 0.0
  flat_key, flat_weight = flat_key[keep], flat_weight[keep]
  unique_key, inverse = np.unique(flat_key, return_inverse=True)
  num_entries = len(flat_key)
  merger = scipy.sparse.csr_matrix(
      (np.ones(num_entries, np.float32), (inverse, np.arange(num_entries))),
      shape=(len(unique_key), num_entries),
  )
  data = merger * flat_weight  # Merged values.
  row_ind, col_ind = unique_key // src_size, unique_key % src_size  # Merged indices.
  resize_matrix = _make_sparse_matrix(data, row_ind, col_ind, (dst_size, src_size), arraylib)

  if boundary.uses_cval or kernel.name == 'narrowbox':
    cval_weight = _make_array(1.0 - weight.sum(axis=-1), arraylib)
  else:
    cval_weight = None

  return resize_matrix, cval_weight
2413
2414
def _apply_digital_filter_1d(
    array: _Array,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    /,
    *,
    axis: int = 0,
) -> _Array:
  """Apply inverse convolution along one dimension of `array`, for any supported array library.

  The result holds the coefficients which, when convolved with the (continuous) `filter` under
  the given `gridtype` and `boundary`, interpolate the original array values.  The computation is
  always delegated to `_apply_digital_filter_1d_numpy`; for tensorflow, torch, and jax it is
  wrapped in that library's custom-gradient mechanism, with the backward pass computed by the
  same numpy routine invoked with `compute_backward=True`.

  Args:
    array: Input values, as an array of one of the supported libraries.
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule along `axis`.
    cval: Constant value used by some boundary rules.
    filter: Kernel; it must have `requires_digital_filter` set.
    axis: Dimension of `array` along which the filter is applied.

  Returns:
    The filtered array, in the same array library as `array`.
  """
  assert filter.requires_digital_filter
  arraylib = _arr_arraylib(array)

  def solve_numpy(values: _NDArray, compute_backward: bool) -> _NDArray:
    """Run the numpy solver; `compute_backward=True` gives the gradient computation."""
    return _apply_digital_filter_1d_numpy(
        values, gridtype, boundary, cval, filter, axis, compute_backward
    )

  if arraylib == 'tensorflow':
    import tensorflow as tf

    def run_forward(values: _NDArray) -> _NDArray:
      return solve_numpy(values, False)

    def run_backward(values: _NDArray) -> _NDArray:
      return solve_numpy(values, True)

    @tf.custom_gradient  # type: ignore[misc]
    def tensorflow_inverse_convolution(x: _TensorflowTensor) -> _TensorflowTensor:
      # The wrapped callables capture gridtype, boundary, etc., yet they are not stateful:
      # fresh closures are created on every invocation of _apply_digital_filter_1d.
      y = tf.numpy_function(run_forward, [x], x.dtype, stateful=False)

      def grad(grad_output: _TensorflowTensor) -> _TensorflowTensor:
        return tf.numpy_function(run_backward, [grad_output], x.dtype, stateful=False)

      return y, grad

    return tensorflow_inverse_convolution(array)

  if arraylib == 'torch':
    import torch.autograd

    class InverseConvolution(torch.autograd.Function):  # type: ignore[misc] # pylint: disable=abstract-method
      """Differentiable wrapper for _apply_digital_filter_1d."""

      @staticmethod
      def forward(ctx: Any, *args: _TorchTensor, **kwargs: Any) -> _TorchTensor:
        del ctx
        assert not kwargs
        (x,) = args
        return torch.as_tensor(solve_numpy(x.detach().numpy(), False))

      @staticmethod
      def backward(ctx: Any, *grad_outputs: _TorchTensor) -> _TorchTensor:
        del ctx
        (grad_output,) = grad_outputs
        return torch.as_tensor(solve_numpy(grad_output.detach().numpy(), True))

    return InverseConvolution.apply(array)

  if arraylib == 'jax':
    import jax
    import jax.numpy as jnp
    # Implementing this inverse convolution natively in Jax seems rather difficult:
    # https://jax.readthedocs.io/en/latest/jax.scipy.html omits scipy.signal.filtfilt(), and
    # including a (non-traceable) numpy function requires jax.experimental.host_callback and/or
    # a new jax.core.Primitive (which allows differentiability).  See
    # https://github.com/google/jax/issues/1142#issuecomment-544286585
    # https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb
    # https://github.com/google/jax/issues/5934

    @jax.custom_gradient  # type: ignore[misc]
    def jax_inverse_convolution(x: _JaxArray) -> _JaxArray:
      # The conversion to a numpy array makes this function non-traceable, so jit and grad fail.
      y = jnp.asarray(solve_numpy(np.asarray(x), False))

      def grad(grad_output: _JaxArray) -> _JaxArray:
        return jnp.asarray(solve_numpy(np.asarray(grad_output), True))

      return y, grad

    return jax_inverse_convolution(array)

  assert arraylib == 'numpy'
  assert isinstance(array, np.ndarray)  # Help mypy.
  return solve_numpy(array, False)
2516
2517
def _apply_digital_filter_1d_numpy(
    array: _NDArray,
    gridtype: Gridtype,
    boundary: Boundary,
    cval: _ArrayLike,
    filter: Filter,
    axis: int,
    compute_backward: bool,
    /,
) -> _NDArray:
  """Version of `_apply_digital_filter_1d` specialized to numpy array.

  Solves, along `axis`, the linear system whose matrix rows contain the filter values sampled at
  integer offsets (after the boundary rule is applied).  Three strategies are used: scipy's
  `spline_filter1d()` when the filter/boundary/gridtype combination is supported by it, a banded
  solver when the matrix is banded, and a sparse LU factorization otherwise.

  Args:
    array: Input values; its dtype must be inexact (real or complex floating point).
    gridtype: Placement of the samples along `axis`.
    boundary: Boundary rule along `axis`.
    cval: Constant value used by boundary rules that have `uses_cval` set.
    filter: Kernel whose integer-offset samples define the system matrix.
    axis: Dimension of `array` along which the system is solved.
    compute_backward: If True, solve with the transposed matrix (used for gradients) and skip
      the `cval` adjustment.

  Returns:
    An array with the same shape as `array`, containing the solution along `axis`.
  """
  assert np.issubdtype(array.dtype, np.inexact)
  cval = np.asarray(cval).astype(array.dtype, copy=False)

  # Use fast spline_filter1d() if we have a compatible gridtype, boundary, and filter:
  # (The dict maps (boundary name, gridtype name) to the corresponding scipy.ndimage mode; `.get`
  # yields None for unsupported combinations.)
  mode = {
      ('reflect', 'dual'): 'reflect',
      ('reflect', 'primal'): 'mirror',
      ('wrap', 'dual'): 'grid-wrap',
      ('wrap', 'primal'): 'wrap',
  }.get((boundary.name, gridtype.name))
  filter_is_compatible = isinstance(filter, CardinalBsplineFilter)
  use_split_filter1d = filter_is_compatible and mode
  if use_split_filter1d:
    assert isinstance(filter, CardinalBsplineFilter)  # Help mypy.
    assert filter.degree >= 2
    # compute_backward=True is same: matrix is symmetric and cval is unused.
    return scipy.ndimage.spline_filter1d(
        array, axis=axis, order=filter.degree, mode=mode, output=array.dtype
    )

  # General path: work with the filtered axis moved to the front.
  array_dim = np.moveaxis(array, axis, 0)
  # `l` is the half-width of the discrete filter: the kernel is sampled at integers -l..l.
  l = original_l = math.ceil(filter.radius) - 1
  x = np.arange(-l, l + 1, dtype=array.real.dtype)
  values = filter(x)
  size = array_dim.shape[0]
  # Row i initially references columns i-l .. i+l, each row carrying the same filter values.
  src_index = np.arange(size)[:, None] + np.arange(len(values)) - l
  weight = np.full((size, len(values)), values)
  # NOTE(review): a constant 0.5 is passed as the position argument of boundary.apply();
  # presumably a placeholder location inside the unit domain -- confirm against Boundary.apply.
  src_position = np.broadcast_to(0.5, len(values))
  src_index, weight = boundary.apply(src_index, weight, src_position, size, gridtype)
  if gridtype.name == 'primal' and boundary.name == 'wrap':
    # Overwrite redundant last row to preserve unreferenced last sample and thereby make the
    # matrix non-singular.
    src_index[-1] = [size - 1] + [0] * (src_index.shape[1] - 1)
    weight[-1] = [1.0] + [0.0] * (weight.shape[1] - 1)
  # Largest distance of any referenced column from the diagonal.
  bandwidth = abs(src_index - np.arange(size)[:, None]).max()
  is_banded = bandwidth <= l + 1  # Add one for quadratic boundary and l == 1.
  # Currently, matrix is always banded unless boundary.name == 'wrap'.

  # Assemble the (size x size) sparse system matrix from the per-row indices and weights.
  data = weight.reshape(-1).astype(array.dtype, copy=False)
  row_ind = np.arange(size).repeat(src_index.shape[1])
  col_ind = src_index.reshape(-1)
  matrix = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(size, size))
  if compute_backward:
    matrix = matrix.T

  if boundary.uses_cval and not compute_backward:
    # Rows whose weights do not sum to one get the remainder from `cval`; move that known
    # contribution to the right-hand side before solving.
    cval_weight = 1.0 - np.asarray(matrix.sum(axis=-1))[:, 0]
    if array_dim.ndim == 2:  # Handle the case that we have array_flat.
      cval = np.tile(cval.reshape(-1), array_dim[0].size // cval.size)
    array_dim = array_dim - cval_weight.reshape(-1, *(1,) * array_dim[0].ndim) * cval

  if is_banded:
    matrix = matrix.todia()
    assert np.all(np.diff(matrix.offsets) == 1)  # Consecutive, often [-l, l].
    # From here on, `l` and `u` are the numbers of stored sub- and super-diagonals.
    l, u = -matrix.offsets[0], matrix.offsets[-1]
    assert l <= original_l + 1 and u <= original_l + 1, (l, u, original_l)
    options = dict(check_finite=False, overwrite_ab=True, overwrite_b=False)
    if _is_symmetric(matrix):
      # Pass only the main diagonal and the superdiagonals, highest offset first.
      array_dim = scipy.linalg.solveh_banded(matrix.data[-1 : l - 1 : -1], array_dim, **options)
    else:
      # Pass all diagonals, reversed so that the highest offset comes first.
      array_dim = scipy.linalg.solve_banded((l, u), matrix.data[::-1], array_dim, **options)

  else:
    # Non-banded matrix: sparse LU factorization without column reordering.
    lu = scipy.sparse.linalg.splu(matrix.tocsc(), permc_spec='NATURAL')
    assert all(s <= size * len(values) for s in (lu.L.nnz, lu.U.nnz))  # Sparse.
    # The solve is applied to a 2D view (size, -1) and the result reshaped back.
    array_dim = lu.solve(array_dim.reshape(array_dim.shape[0], -1)).reshape(array_dim.shape)

  return np.moveaxis(array_dim, 0, axis)
2597
2598
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.

  Raises:
    ValueError: If `array` is not numeric, if `shape` has no dimensions or more dimensions than
      `array`, or if `dim_order` is not a permutation of `range(len(shape))`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  # Normalize every per-dimension parameter to a sequence of length `len(shape2)`.
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  # Drop the raw parameters so that only the normalized "...2" versions can be used below.
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # The input can be returned untouched only if nothing at all would change.  With identical
  # sizes and gridtypes and unit scale, the operation is not a minification, so the kernel that
  # would be applied is `filter` (not `prefilter`); this matches the per-dimension skip test
  # below.  The output dtype must also equal the input dtype, as the result is documented to
  # have data type `dtype`.
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in filter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
      and dtype == array_dtype
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  # Decode the stored values using the source gamma, at the intermediate precision.
  array = src_gamma2.decode(array, precision)

  # Special case: 2D downsampling by integer factors with a box-like prefilter on dual grids.
  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    # `typing` is imported at module level, so `typing.cast` is always resolvable.
    array = _downsample_in_2d_using_box_filter(array, typing.cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  # Resize one dimension at a time, each as a (sparse matrix) x (flattened array) product.
  for dim in dim_order:
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    def get_is_minification() -> bool:
      src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
      dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
      return dst_in_samples / src_in_samples * scale2[dim] < 1.0

    is_minification = get_is_minification()
    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Bring the current dimension to the front and flatten all others into columns.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    # For magnification, the digital filter is applied on the source grid before resampling.
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      # Add the contribution of the constant value for rows whose weights do not sum to one.
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    # For minification, the digital filter is applied on the destination grid after resampling.
    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  # Encode the result using the destination gamma and convert to the requested dtype.
  array = dst_gamma2.encode(array, dtype)
  return array
2850
2851
# Private alias bound to the `resize` function object defined above.  The `resize_in_*` wrappers,
# `_resize_possibly_in_arraylib`, and the jitted variant all call through this name.
# NOTE(review): presumably this lets the public name `resize` be rebound elsewhere without
# affecting these helpers -- confirm.
_original_resize = resize
2853
2854
def resize_in_arraylib(array: _NDArray, /, *args: Any, arraylib: str, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` using the array library `arraylib` from `ARRAYLIBS`.

  The input must be a numpy array.  It is converted to `arraylib`, resized there, and the result
  is converted back to a numpy array.  All other arguments are forwarded to `resize()`.
  """
  _check_eq(_arr_arraylib(array), 'numpy')
  converted = _make_array(array, arraylib)
  resized = _original_resize(converted, *args, **kwargs)
  return _arr_numpy(resized)
2859
2860
def resize_in_numpy(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` with `numpy` as the computation library."""
  return resize_in_arraylib(array, *args, **kwargs, arraylib='numpy')
2864
2865
def resize_in_tensorflow(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` with `tensorflow` as the computation library."""
  return resize_in_arraylib(array, *args, **kwargs, arraylib='tensorflow')
2869
2870
def resize_in_torch(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` with `torch` as the computation library."""
  return resize_in_arraylib(array, *args, **kwargs, arraylib='torch')
2874
2875
def resize_in_jax(array: _NDArray, /, *args: Any, **kwargs: Any) -> _NDArray:
  """Run `resize()` on a numpy `array` with `jax` as the computation library."""
  return resize_in_arraylib(array, *args, **kwargs, arraylib='jax')
2879
2880
def _resize_possibly_in_arraylib(
    array: _Array, /, *args: Any, arraylib: str, **kwargs: Any
) -> _AnyArray:
  """If `array` is from numpy, evaluate `resize()` using the array library from `ARRAYLIBS`.

  A numpy input is converted to `arraylib`, resized there, and converted back to numpy.  An input
  from any other array library is resized directly in its own library and `arraylib` is ignored.
  All remaining arguments are forwarded to `resize()`.
  """
  if _arr_arraylib(array) != 'numpy':
    return _original_resize(array, *args, **kwargs)
  # `typing` is imported at module level, so `typing.cast` is always resolvable.
  converted = _make_array(typing.cast(_ArrayLike, array), arraylib)
  return _arr_numpy(_original_resize(converted, *args, **kwargs))
2890
2891
@functools.cache
def _create_jaxjit_resize() -> Callable[..., _Array]:
  """Return `resize` wrapped in `jax.jit`, importing `jax` and building the wrapper only once.

  The `shape` argument (positional index 1) and every keyword-only parameter of `resize` are
  declared static for the jit compilation.
  """
  import jax

  keyword_names = list(_original_resize.__kwdefaults__)
  jitted: Any = jax.jit(_original_resize, static_argnums=(1,), static_argnames=keyword_names)
  return jitted
2901
2902
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Compute `resize` using the (lazily created, cached) Jax-jitted version of the function."""
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable
2906
2907
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resize `array` to resolution `shape` while preserving its aspect ratio.

  The per-dimension `scale` and `translate` passed to `resize` are derived so that one common
  scaling factor is used in every dimension and the source is centered in the output.  The
  behavior mirrors CSS `object-fit`.  Output regions not covered by the source are governed by
  `boundary`, whose default here is `'natural'`.

  Args:
    array: Regular grid of source sample values.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    object_fit: Like CSS `object-fit`.  If `'contain'`, `array` is resized uniformly to fit within
      `shape`. If `'cover'`, `array` is resized to fully cover `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The default is `'natural'`, which assigns
      `cval` to output points that map outside the source unit domain.
    scale: Parameter may not be specified.
    translate: Parameter may not be specified.
    **kwargs: Additional parameters for `resize` function (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is given a non-default value, or if `shape` has no
      dimensions or more dimensions than `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  num_dims = len(shape)
  if not 0 < num_dims <= len(array.shape):
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, num_dims, num_dims
  )

  # Scaling that a plain (non-uniform) resize would apply in each dimension.
  raw_scales = []
  for dim, dst_size in enumerate(shape):
    dst_extent = dst_gridtype2[dim].size_in_samples(dst_size)
    src_extent = src_gridtype2[dim].size_in_samples(array.shape[dim])
    raw_scales.append(dst_extent / src_extent)

  # 'contain' uses the smallest of these scalings in all dimensions; 'cover' uses the largest.
  pick_scale = {'contain': min, 'cover': max}[object_fit]
  uniform_scale = pick_scale(raw_scales)
  scale2 = uniform_scale / np.array(raw_scales)
  # Center the scaled source within the destination unit domain.
  translate = (1.0 - scale2) / 2
  return resize(array, shape, boundary=boundary, scale=scale2, translate=translate, **kwargs)
2981
2982
# Sentinel passed as `max_block_size` when `resample()` calls itself on each coordinate block.
# Being negative, it disables further splitting; `resample()` also uses it to skip the source
# decoding and digital prefiltering that were already applied before the split.
_MAX_BLOCK_SIZE_RECURSING = -999  # Special value to indicate re-invocation on partitioned blocks.
2984
2985
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width, 1` (the trailing axis holds the single 1D coordinate).

  As a convenience, if `array` is 1D and `coords` is a 1D array with more than one element,
  `coords` is interpreted as a list of 1D points, i.e. as having shape `(num, 1)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
      It must have a floating-point type.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, `coords` is not floating-point, `coords` has more
      coordinate dimensions than `array` has dimensions, or the adaptive `'trapezoid'` filter is
      requested.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    # Convenience: for a 1D source, interpret a 1D `coords` as a list of 1D points.
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    # The Jacobian is only validated here; it is not otherwise used because prefiltering
    # (minification) is not implemented below.
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # On a recursive per-block invocation, the source has already been decoded and prefiltered.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )

  if math.prod(resampled_shape) > max_block_size > 0:
    # Partition the output points into blocks and resample each block separately.
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          jacobian=jacobian,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # The weight given to actual samples is the product, over the *grid* (source) dimensions, of
    # the per-dimension weight sums; the remainder is assigned to `cval`.  There is one `weight`
    # entry per grid dimension, so the product must range over `grid_ndim` (which can differ from
    # `resampled_ndim`, e.g. when sampling a 2D image along a 1D list of points).
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array
3292
3293
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a new grid whose points are mapped into the source by `matrix`.

  Each point of the destination grid (in its unit domain) is transformed into the source unit
  domain, and `array` is interpolated there using `resample`.  The transform is either linear,
    `source_point = matrix @ destination_point`,
  or affine, with the last matrix column acting as an offset vector,
    `source_point = matrix @ (destination_point, 1.0)`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The number of grid dimensions is determined from
      `matrix.shape[0]`; the remaining dimensions are for each sample value and are all
      linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination grid dimensions
      may be different from that of the source grid.
    matrix: 2D array for a linear or affine transform from unit-domain destination points
      (in a space with `len(shape)` dimensions) into unit-domain source points (in a space with
      `matrix.shape[0]` dimensions).  If the matrix has `len(shape) + 1` columns, the last column
      is the affine offset (i.e., translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each dimension in `matrix.shape[0]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `len(shape)`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters for `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or has a number
      of columns other than `len(shape)` or `len(shape) + 1`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  matrix = np.asarray(matrix)
  dst_ndim = len(shape)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  # An extra trailing column holds the translation of an affine transform.
  has_offset = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not has_offset and num_columns != dst_ndim:
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  src_gridtypes, dst_gridtypes = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  if prefilter is None:
    prefilter = filter
  src_filters = [_get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)]
  dst_prefilters = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), dst_ndim)]
  if dtype is None:
    dtype = _arr_dtype(array)
  else:
    dtype = np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples along each axis, then as a full grid.
  axis_positions = [
      dst_gridtypes[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_grid = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part to every destination point; the source axis ends up last.
  linear_part = matrix[:, :-1] if has_offset else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_grid, 1), 0, -1)
  if has_offset:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtypes,
      filter=src_filters,
      prefilter=dst_prefilters,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )
3406
3407
def _resize_using_resample(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    scale: _ArrayLike = 1.0,
    translate: _ArrayLike = 0.0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    fallback: bool = False,
    **kwargs: Any,
) -> _Array:
  """Implement `resize` via the more general `resample_affine`, as a debug tool.

  The `scale` and `translate` parameters are converted into an affine matrix mapping destination
  points to source points.  If `fallback` is set and the request involves minification or the
  `'trapezoid'` filter, the call is forwarded to `_original_resize` instead.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  ndim = len(shape)
  scale = np.broadcast_to(scale, ndim)
  translate = np.broadcast_to(translate, ndim)
  # TODO: let resample() do prefiltering for proper downsampling.
  shrinks_grid = np.any(np.array(shape) < array.shape[:ndim])
  has_minification = shrinks_grid or np.any(scale < 1.0)
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), ndim)]
  has_auto_trapezoid = any(f.name == 'trapezoid' for f in filter2)
  needs_original = has_minification or has_auto_trapezoid
  if fallback and needs_original:
    return _original_resize(array, shape, scale=scale, translate=translate, filter=filter, **kwargs)
  # Invert `dst = src * scale + translate` to obtain `src = dst / scale - translate / scale`.
  matrix = np.column_stack([np.diag(1.0 / scale), -translate / scale])
  return resample_affine(array, shape, matrix, filter=filter, **kwargs)
3434
3435
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Return a 3x3 homogeneous matrix taking destination unit-domain points to the source domain.

  The matrix rotates (and scales) about the domain center `(0.5, 0.5)`, compensating for
  non-square source and destination resolutions so that the rotation is not skewed.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 array whose last row is `[0, 0, 1]`.
  """

  def homogeneous_translation(offset: _NDArray) -> _NDArray:
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def homogeneous_scaling(factors: _NDArray) -> _NDArray:
    return np.diag(tuple(factors) + (1.0,))

  def homogeneous_rotation(theta: float) -> _NDArray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])

  src_shape = np.asarray(src_shape)
  if new_shape is None:
    new_shape = src_shape
  else:
    new_shape = np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  center = np.array([0.5, 0.5])
  # Factors are listed in the order they are multiplied (the rightmost acts on the point first).
  from_origin = homogeneous_translation(center)
  src_aspect = homogeneous_scaling(min(src_shape) / src_shape)
  rotation = homogeneous_rotation(angle)
  dst_aspect = homogeneous_scaling(scale * new_shape / min(new_shape))
  to_origin = homogeneous_translation(-center)
  matrix = from_origin @ src_aspect @ rotation @ dst_aspect @ to_origin
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix
3482
3483
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The rotated image, with spatial resolution `new_shape`.  If `num_rotations` is zero, `image`
    is returned unchanged.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # The matrix depends on the resolution of the *current* source.  After the first pass the
    # image has resolution `new_shape`, so the matrix is rebuilt on every pass; reusing the
    # first matrix would apply a stale aspect-ratio correction when `new_shape` differs from
    # the original image shape.
    matrix = rotation_about_center_in_2d(image.shape[:2], angle, new_shape=new_shape, scale=scale)
    # Drop the homogeneous last row to obtain the 2x3 affine matrix.
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image
3511
3512
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Resize a floating-point 1D signal or 2D image using `PIL.Image.resize`.

  The parameters mirror those of `resize`.  The input may be 1D, 2D, or 3D (with channels in the
  last axis, each resized independently).  Only `boundary='natural'` is accepted, and `cval` is
  ignored.

  Raises:
    ValueError: If `boundary` is not `'natural'` or `filter` has no Pillow equivalent.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval  # Accepted for signature compatibility only.
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  _check_eq(len(shape), 1 if array.ndim < 2 else 2)
  if array.ndim == 1:
    # Handle a 1D signal as an image with a single row.
    as_row = pil_image_resize(array[None], (1, *shape), filter=filter)
    return as_row[0]

  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
  modes = PIL.Image.Resampling
  filters = {
      'impulse': modes.NEAREST,
      'box': modes.BOX,
      'triangle': modes.BILINEAR,
      'hamming1': modes.HAMMING,
      'cubic': modes.BICUBIC,
      'lanczos3': modes.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  pil_size = shape[::-1]  # Pillow expects (width, height).

  def resize_plane(plane: _NDArray) -> _NDArray:
    resized = PIL.Image.fromarray(plane).resize(pil_size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  # Resize each channel separately, then restack the channels along the last axis.
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])
3557
3558
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Resize a 1D signal or 2D image using OpenCV's `cv.resize`.

  The parameters mirror those of `resize`.  The input may be 1D, 2D, or 3D (with channels in the
  last axis).  Only `boundary='clamp'` is accepted, and `cval` is ignored.

  Raises:
    ValueError: If `boundary` is not `'clamp'` or `filter` has no OpenCV equivalent.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval  # Accepted for signature compatibility only.
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 1 if array.ndim < 2 else 2)
  if array.ndim == 1:
    # Handle a 1D signal as an image with a single row.
    as_row = cv_resize(array[None], (1, *shape), filter=filter)
    return as_row[0]

  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  interpolation = filters[filter]
  cv_size = shape[::-1]  # OpenCV expects (width, height).
  result = cv.resize(array, cv_size, interpolation=interpolation)
  lost_channel_axis = array.ndim == 3 and result.ndim == 2
  if not lost_channel_axis:
    return result
  assert array.shape[2] == 1
  return result[..., None]  # Add back the last dimension dropped by cv.resize().
3595
3596
def scipy_ndimage_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _NDArray:
  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`.

  Args:
    array: Input grid; the first `len(shape)` dimensions are resampled and any remaining
      trailing dimensions are carried through unchanged.
    shape: Target size of the resampled leading dimensions.
    filter: One of 'box', 'triangle', or 'cardinal2' .. 'cardinal5' (spline order 0 .. 5).
    boundary: One of 'reflect', 'wrap', 'clamp', 'border'.
    cval: Constant value passed to `map_coordinates`.
    scale: Scaling applied to the destination coordinates (scalar or per-dimension).
    translate: Offset applied to the destination coordinates (scalar or per-dimension).

  Returns:
    The result of `scipy.ndimage.map_coordinates`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]
  num_resized = len(shape)
  shape_all = shape + array.shape[num_resized:]
  # Sample coordinates are fractional, so they must be stored in a floating dtype.  Reusing an
  # integer (or bool) array dtype would truncate them on assignment and could overflow the indices.
  coords_dtype = array.dtype if np.issubdtype(array.dtype, np.floating) else np.float64
  coords = np.moveaxis(np.indices(shape_all, coords_dtype), 0, -1)
  # Map destination sample centers into source index space (half-integer sample convention).
  coords[..., :num_resized] = (
      (coords[..., :num_resized] + 0.5) / shape - np.asarray(translate)
  ) / np.asarray(scale) * np.array(array.shape)[:num_resized] - 0.5
  coords = np.moveaxis(coords, -1, 0)
  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)
3627
3628
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Resize `array` with `skimage.transform.resize`, mirroring the `resize` parameters.

  Args:
    array: Input grid; the first `len(shape)` dimensions are resized.
    shape: Target size of the leading dimensions.
    filter: One of 'box', 'triangle', or 'cardinal2' .. 'cardinal5' (spline order 0 .. 5).
    boundary: One of 'reflect', 'wrap', 'clamp', 'border'.
    cval: Constant value forwarded to `skimage.transform.resize`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  import skimage.transform

  source = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= source.ndim
  filters = {'box': 0, 'triangle': 1}
  filters.update((f'cardinal{degree}', degree) for degree in range(2, 6))
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = dict(reflect='symmetric', wrap='wrap', clamp='edge', border='constant')
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  # Trailing (non-resized) dimensions keep their original extent.
  untouched_dims = source.shape[len(shape) :]
  shape_all = shape + untouched_dims
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  return skimage.transform.resize(
      source, shape_all, order=filters[filter], mode=boundaries[boundary], cval=cval, clip=False
  )  # type: ignore[no-untyped-call]
3658
3659
# Maps each `resampler` filter name to the `method` string given to `tf.image.resize`.
_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER = dict(
    impulse='nearest',
    trapezoid='area',
    triangle='bilinear',
    mitchell='mitchellcubic',
    cubic='bicubic',
    lanczos3='lanczos3',
    lanczos5='lanczos5',
    # GaussianFilter(0.5): 'gaussian',  # radius_4 > desired_radius_3.
)
3670
3671
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Resize `array` with `tf.image.resize`, mirroring the `resize` parameters.

  Args:
    array: Input with 1, 2, or 3 dimensions.
    shape: Target size; one entry for 1D input, otherwise two entries.
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'.
    cval: Accepted for signature compatibility and discarded.
    antialias: Forwarded to `tf.image.resize`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  del array
  ndim = len(tensor.shape)
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)
  if ndim == 1:
    # Add a leading unit axis, recurse as 2D, then remove that axis.
    promoted = tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)
    return promoted[0]
  if ndim == 2:
    # Add a trailing unit channel axis, recurse as 3D, then remove it.
    promoted = tf_image_resize(tensor[..., None], shape, filter=filter, antialias=antialias)
    return promoted[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)
3704
3705
# Maps each `resampler` filter name to the `mode` string given to
# `torch.nn.functional.interpolate`.
_TORCH_INTERPOLATE_MODE_FROM_FILTER = dict(
    impulse='nearest-exact',  # ('nearest' matches buggy OpenCV's INTER_NEAREST)
    trapezoid='area',
    triangle='bilinear',
    sharpcubic='bicubic',
)
3712
3713
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Resize `array` with `torch.nn.functional.interpolate`, mirroring the `resize` parameters.

  Args:
    array: Input with 1, 2, or 3 dimensions.
    shape: Target size; one entry for 1D input, otherwise two entries.
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'.
    cval: Accepted for signature compatibility and discarded.
    antialias: Forwarded to `torch.nn.functional.interpolate`.

  Raises:
    ValueError: If `filter` or `boundary` is not supported.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  a = torch.as_tensor(array)
  del array
  ndim = a.ndim
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]

  # Arrange the data as a 4D tensor with two leading axes ahead of the two resized axes.
  if ndim == 1:
    batched, size = a[None, None, None], (1, *shape)
  elif ndim == 2:
    batched, size = a[None, None], shape
  else:
    batched, size = a.moveaxis(2, 0)[None], shape

  # For upsampling, BILINEAR antialias is same PSNR and slower,
  #  and BICUBIC antialias is worse PSNR and faster.
  # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
  # Default align_corners=None corresponds to False which is what we desire.
  resized = torch.nn.functional.interpolate(batched, size, mode=mode, antialias=antialias)

  # Undo the axis arrangement applied above.
  if ndim == 1:
    return resized[0, 0, 0]
  if ndim == 2:
    return resized[0, 0]
  return resized[0].moveaxis(0, 2)
3754
3755
def jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`.

  Args:
    array: Input grid; the first `len(shape)` dimensions are resampled.
    shape: Target size of the resampled leading dimensions.
    filter: One of 'triangle', 'cubic', 'lanczos3', 'lanczos5'.
    boundary: Must be 'natural'.
    cval: Must be zero whenever `scale` is non-unity or `translate` is nonzero.
    scale: Scalar or per-dimension scaling (a sequence or array is accepted).
    translate: Scalar or per-dimension translation (a sequence or array is accepted).

  Returns:
    The result of `jax.image.scale_and_translate`.

  Raises:
    ValueError: If `filter` or `boundary` is unsupported, or if `cval` is nonzero together with
      a non-identity `scale` or `translate`.
  """
  import jax.image
  import jax.numpy as jnp

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  # Compare elementwise: a plain `scale != 1.0` is ambiguous for multi-element arrays and is
  # always True for lists/tuples, even when every entry is the identity value.
  scale_array = np.asarray(scale)
  translate_array = np.asarray(translate)
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.
  if np.any(scale_array != 1.0) and cval != 0.0:
    raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
  if np.any(translate_array != 0.0) and cval != 0.0:
    raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  assert len(shape) <= array2.ndim
  # Dimensions beyond `len(shape)` are given extent 1 in the shape passed to JAX, as before.
  completed_shape = shape + (1,) * (array2.ndim - len(shape))
  spatial_dims = list(range(len(shape)))
  # Convert the relative scale/translate of `resize` into the output-sample units used by JAX.
  scale2 = np.broadcast_to(scale_array, len(shape))
  scale2 = scale2 / np.array(array2.shape[: len(shape)]) * np.array(shape)
  translate2 = np.broadcast_to(translate_array, len(shape))
  translate2 = translate2 * np.array(shape)
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )
3795
3796
# Every resizer wrapper defined in this module, keyed by the name of the library function it
# invokes.  Entries are kept regardless of whether the underlying package is installed.
_CANDIDATE_RESIZERS = dict((
    ('resampler.resize', resize),
    ('PIL.Image.resize', pil_image_resize),
    ('cv.resize', cv_resize),
    ('scipy.ndimage.map_coordinates', scipy_ndimage_resize),
    ('skimage.transform.resize', skimage_transform_resize),
    ('tf.image.resize', tf_image_resize),
    ('torch.nn.functional.interpolate', torch_nn_resize),
    ('jax.image.scale_and_translate', jax_image_resize),
))
3807
3808
def _resizer_is_available(library_function: str) -> bool:
  """Return whether the resizer is available as an installed package.

  Args:
    library_function: Dotted name such as `'cv.resize'`; only the part before the first dot
      is used to identify the package.

  Returns:
    True if the corresponding top-level module can be located by the import system.
  """
  # `import importlib` alone does not guarantee that the `util` submodule is loaded.
  import importlib.util

  top_name = library_function.split('.', 1)[0]
  # Translate display aliases to importable module names.  'PIL' is deliberately not remapped:
  # the distribution is named "Pillow" but the importable module is `PIL`.
  module = {'cv': 'cv2', 'tf': 'tensorflow'}.get(top_name, top_name)
  return importlib.util.find_spec(module) is not None
3814
3815
# The subset of `_CANDIDATE_RESIZERS` for which `_resizer_is_available` returns True.
_RESIZERS = {
    name: candidate
    for name, candidate in _CANDIDATE_RESIZERS.items()
    if _resizer_is_available(name)
}
3821
3822
def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
  """Return the filter supported by `resizer` (i.e., `*_resize`) that is closest to `filter`.

  The generic names 'box_like', 'cubic_like', and 'high_quality' are translated to a concrete
  filter name for the given resizer; any other name is returned unchanged.
  """
  if filter == 'box_like':
    fallback = 'box'
    per_resizer = {
        cv_resize: 'trapezoid',
        skimage_transform_resize: 'box',
        tf_image_resize: 'trapezoid',
        torch_nn_resize: 'trapezoid',
    }
  elif filter == 'cubic_like':
    fallback = 'cubic'
    per_resizer = {
        cv_resize: 'sharpcubic',
        scipy_ndimage_resize: 'cardinal3',
        skimage_transform_resize: 'cardinal3',
        torch_nn_resize: 'sharpcubic',
    }
  elif filter == 'high_quality':
    fallback = 'lanczos5'
    per_resizer = {
        pil_image_resize: 'lanczos3',
        cv_resize: 'lanczos4',
        scipy_ndimage_resize: 'cardinal5',
        skimage_transform_resize: 'cardinal5',
        torch_nn_resize: 'sharpcubic',
    }
  else:
    return filter
  return per_resizer.get(resizer, fallback)
3850
3851
3852# For Emacs:
3853# Local Variables:
3854# fill-column: 100
3855# End:
using_numba = False
def is_available(arraylib: str) -> bool:
def is_available(arraylib: str) -> bool:
  """Report whether the named array library (e.g. 'tensorflow') can be found as a package."""
  # Looking up the module spec is faster than trying to import the library.
  spec = importlib.util.find_spec(arraylib)
  return spec is not None

Return whether the array library (e.g. 'tensorflow') is available as an installed package.

ARRAYLIBS: list[str] = ['numpy']

Array libraries supported automatically in the resize and resampling operations.

  • The library is selected automatically based on the type of the array function parameter.

  • The class _Arraylib provides library-specific implementations of needed basic functions.

  • The _arr_*() functions dispatch the _Arraylib methods based on the array type.

@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Gridtype(abc.ABC):
  """Abstract interface shared by grid-types such as `'dual'` and `'primal'`.

  A resampling operation may use one grid-type for the source domain (`src_gridtype`) and another
  for the destination domain (`dst_gridtype`), and each may also vary per domain dimension.

  Examples:
    `resize(source, shape, gridtype='primal')`  # Sets both src and dst to be `'primal'` grids.

    `resize(source, shape, src_gridtype=['dual', 'primal'],
            dst_gridtype='dual')`  # Source is `'dual'` in dim0 and `'primal'` in dim1.
  """

  name: str
  """Name of the grid-type."""

  @abc.abstractmethod
  def min_size(self) -> int:
    """Return the smallest number of grid samples that is permitted."""

  @abc.abstractmethod
  def size_in_samples(self, size: int, /) -> int:
    """Return the extent of the domain, measured in inter-sample spacings."""

  @abc.abstractmethod
  def point_from_index(self, index: _NDArray, size: int, /) -> _NDArray:
    """Convert indices in [0, size - 1] to coordinates in [0.0, 1.0]."""

  @abc.abstractmethod
  def index_from_point(self, point: _NDArray, size: int, /) -> _NDArray:
    """Convert coordinates in [0.0, 1.0] to a location x, where x == 0.0 is the first grid
    sample and x == size - 1.0 is the last grid sample."""

  @abc.abstractmethod
  def reflect(self, index: _NDArray, size: int, /) -> _NDArray:
    """Bring integer sample indices into the interior by boundary reflection."""

  @abc.abstractmethod
  def wrap(self, index: _NDArray, size: int, /) -> _NDArray:
    """Bring integer sample indices into the interior by wrapping."""

  @abc.abstractmethod
  def reflect_clamp(self, index: _NDArray, size: int, /) -> _NDArray:
    """Bring integer sample indices into the interior by reflect-clamp."""

Abstract base class for grid-types such as 'dual' and 'primal'.

In resampling operations, the grid-type may be specified separately as src_gridtype for the source domain and dst_gridtype for the destination domain. Moreover, the grid-type may be specified per domain dimension.

Examples:

resize(source, shape, gridtype='primal') # Sets both src and dst to be 'primal' grids.

resize(source, shape, src_gridtype=['dual', 'primal'], dst_gridtype='dual') # Source is 'dual' in dim0 and 'primal' in dim1.

GRIDTYPES: list[str] = ['dual', 'primal']

Shortcut names for the two predefined grid types (specified per dimension):

gridtype 'dual'
DualGridtype()
(default)
'primal'
PrimalGridtype()
 
Sample positions in 2D
and in 1D at different resolutions
| | Dual | Primal |
|---|---|---|
| Nesting of samples across resolutions | The sample positions do not nest. | The even samples remain at coarser scale. |
| Number $N_\ell$ of samples (per-dimension) at resolution level $\ell$ | $N_\ell=2^\ell$ | $N_\ell=2^\ell+1$ |
| Position of sample index $i$ within domain $[0, 1]$ | $\frac{i + 0.5}{N}$ ("half-integer" coordinates) | $\frac{i}{N-1}$ |
| Image resolutions ($N_\ell\times N_\ell$) for dyadic scales | $1\times1, ~~2\times2, ~~4\times4, ~~8\times8, ~\ldots$ | $2\times2, ~~3\times3, ~~5\times5, ~~9\times9, ~\ldots$ |

See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Boundary:
@dataclasses.dataclass(frozen=True)
class Boundary:
  """Rules governing the reconstruction of the source domain near and beyond its boundaries.

  A separate rule may be chosen for each domain dimension."""

  name: str = ''
  """Name of the boundary rule."""

  coord_remap: RemapCoordinates = NoRemapCoordinates()
  """Adjusts the given coordinates before the reconstruction kernels are evaluated."""

  extend_samples: ExtendSamples = ReflectExtendSamples()
  """Expresses each grid sample outside the unit domain as an affine combination of interior
  sample(s) and possibly the constant value (`cval`)."""

  override_value: OverrideExteriorValue = NoOverrideExteriorValue()
  """Replaces the value outside some extent by a constant value (`cval`)."""

  @property
  def uses_cval(self) -> bool:
    """True if weights may be non-affine, involving the constant value (`cval`)."""
    extension_uses_cval = self.extend_samples.uses_cval
    return extension_uses_cval or self.override_value.uses_cval

  def preprocess_coordinates(self, point: _NDArray, /) -> _NDArray:
    """Return the coordinates after applying `coord_remap`, for use by the filter kernels."""
    # Antialiasing across the tile boundaries may be feasible but seems hard.
    return self.coord_remap(point)

  def apply(
      self, index: _NDArray, weight: _NDArray, point: _NDArray, size: int, gridtype: Gridtype, /
  ) -> tuple[_NDArray, _NDArray]:
    """Replace exterior samples by combinations of interior samples.

    Runs `extend_samples` and then `override_reconstruction`, returning the resulting
    `(index, weight)` pair.
    """
    extended_index, extended_weight = self.extend_samples(index, weight, size, gridtype)
    self.override_reconstruction(extended_weight, point)
    return extended_index, extended_weight

  def override_reconstruction(self, weight: _NDArray, point: _NDArray, /) -> None:
    """For points outside an extent, modify weight to zero to assign `cval`.

    Delegates to `override_value`; nothing is returned.
    """
    self.override_value(weight, point)

Domain boundary rules. These define the reconstruction over the source domain near and beyond the domain boundaries. The rules may be specified separately for each domain dimension.

BOUNDARIES: list[str] = ['reflect', 'wrap', 'tile', 'clamp', 'border', 'natural', 'linear_constant', 'quadratic_constant', 'reflect_clamp', 'constant', 'linear', 'quadratic']

Shortcut names for some predefined boundary rules (as defined by _DICT_BOUNDARIES):

name a.k.a. / comments
'reflect' reflected, symm, symmetric, mirror, grid-mirror
'wrap' periodic, repeat, grid-wrap
'tile' like 'reflect' within unit domain, then tile discontinuously
'clamp' clamped, nearest, edge, clamp-to-edge, repeat last sample
'border' grid-constant, use cval for samples outside unit domain
'natural' renormalize using only interior samples, use cval outside domain
'reflect_clamp' mirror-clamp-to-edge
'constant' like 'reflect' but replace by cval outside unit domain
'linear' extrapolate from 2 last samples
'quadratic' extrapolate from 3 last samples
'linear_constant' like 'linear' but replace by cval outside unit domain
'quadratic_constant' like 'quadratic' but replace by cval outside unit domain

These boundary rules may be specified per dimension. See the source code for extensibility using the classes RemapCoordinates, ExtendSamples, and OverrideExteriorValue.

Boundary rules illustrated in 1D:

Boundary rules illustrated in 2D:

@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Filter(abc.ABC):
  """Abstract interface shared by all filter kernel functions.

  Every kernel is taken to be a zero-phase filter, meaning that it is symmetric over a support
  interval [-radius, radius].  (Some sites instead define kernels over the interval [0, N]
  where N = 2 * radius.)

  Portions of this code are adapted from the C++ library in
  https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

  See also https://hhoppe.com/proj/filtering/.
  """

  name: str
  """Name of the filter kernel."""

  radius: float
  """Largest absolute value of x at which self(x) is nonzero."""

  interpolating: bool = True
  """Whether self(0) == 1.0 and self(i) == 0.0 for all nonzero integers i."""

  continuous: bool = True
  """Whether the kernel function has $C^0$ continuity."""

  partition_of_unity: bool = True
  """Whether convolving the kernel with a Dirac comb reproduces the unity function."""

  unit_integral: bool = True
  """Whether the kernel function integrates to 1."""

  requires_digital_filter: bool = False
  """Whether interpolation with this filter needs a pre/post digital filter."""

  @abc.abstractmethod
  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the filter kernel at the locations x and return the values."""

Abstract base class for filter kernel functions.

Each kernel is assumed to be a zero-phase filter, i.e., to be symmetric in a support interval [-radius, radius]. (Some sites instead define kernels over the interval [0, N] where N = 2 * radius.)

Portions of this code are adapted from the C++ library in https://github.com/hhoppe/Mesh-processing-library/blob/main/libHh/Filter.cpp

See also https://hhoppe.com/proj/filtering/.

class ImpulseFilter(Filter):
class ImpulseFilter(Filter):
  """Impulse kernel; see https://en.wikipedia.org/wiki/Dirac_delta_function."""

  def __init__(self) -> None:
    properties = dict(name='impulse', radius=1e-20, continuous=False, partition_of_unity=False)
    super().__init__(**properties)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Always raise, because this kernel cannot be evaluated pointwise."""
    raise AssertionError('The Impulse is infinitely narrow, so cannot be directly evaluated.')

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BoxFilter(Filter):
class BoxFilter(Filter):
  """Box kernel; see https://en.wikipedia.org/wiki/Box_function.

  The kernel evaluates to 1.0 on the half-open interval [-.5, .5) and to 0.0 elsewhere.
  """

  def __init__(self) -> None:
    super().__init__(name='box', radius=0.5, continuous=False)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    # Hard-wired choice between the half-open (asymmetric) box and a symmetric alternative
    # that takes the value 0.5 exactly at abs(x) == 0.5.
    use_asymmetric = True
    if not use_asymmetric:
      magnitude = np.abs(x)
      return np.where(magnitude < 0.5, 1.0, np.where(magnitude == 0.5, 0.5, 0.0))
    values = np.asarray(x)
    inside = (-0.5 <= values) & (values < 0.5)
    return np.where(inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TrapezoidFilter(Filter):
class TrapezoidFilter(Filter):
  """Kernel for antialiased "area-based" filtering.

  Args:
    radius: The filter support is [-radius, radius], with 0.5 < radius <= 1.0.
      Passing `radius = None` creates a placeholder, indicating that the filter will be
      replaced by a trapezoid of the appropriate radius (based on scaling) for correct
      antialiasing in both minification and magnification.

  The shape resembles the BoxFilter but with linearly sloped sides.  The value is 1.0 where
  abs(x) <= 1.0 - radius, then falls linearly to 0.0 over 1.0 - radius <= abs(x) <= radius,
  always passing through value 0.5 at x = 0.5.
  """

  def __init__(self, *, radius: float | None = None) -> None:
    if radius is not None:
      if not 0.5 < radius <= 1.0:
        raise ValueError(f'Radius {radius} is outside the range (0.5, 1.0].')
      super().__init__(name=f'trapezoid_{radius}', radius=radius)
    else:
      # Placeholder instance; it records a radius of 0.0.
      super().__init__(name='trapezoid', radius=0.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    magnitude = np.abs(x)
    assert 0.5 < self.radius <= 1.0
    # Linear ramp through value 0.5 at abs(x) == 0.5, clipped to the range [0, 1].
    intercept = 0.5 + 0.25 / (self.radius - 0.5)
    slope = 0.5 / (self.radius - 0.5)
    return (intercept - slope * magnitude).clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class TriangleFilter(Filter):
class TriangleFilter(Filter):
  """Triangle kernel; see https://en.wikipedia.org/wiki/Triangle_function.

  It is also called the hat or tent function, and is the kernel used for piecewise-linear
  (or bilinear, or trilinear, ...) interpolation.
  """

  def __init__(self) -> None:
    super().__init__(name='triangle', radius=1.0)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    falloff = 1.0 - np.abs(x)
    return falloff.clip(0.0, 1.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CubicFilter(Filter):
class CubicFilter(Filter):
  """Two-parameter family of cubic filters.

  Args:
    b: first scalar parameter.
    c: second scalar parameter.
    name: Optional filter name; when omitted, a name is generated from `b` and `c`.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer graphics.
  Computer Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]

  - The filter has quadratic precision iff b + 2 * c == 1.
  - The filter is interpolating iff b == 0.
  - (b=1, c=0) is the (non-interpolating) cubic B-spline basis;
  - (b=1/3, c=1/3) is the Mitchell filter;
  - (b=0, c=0.5) is the Catmull-Rom spline (which has cubic precision);
  - (b=0, c=0.75) is the "sharper cubic" used in Photoshop and OpenCV.
  """

  def __init__(self, *, b: float, c: float, name: str | None = None) -> None:
    if name is None:
      name = f'cubic_b{b}_c{c}'
    super().__init__(name=name, radius=2.0, interpolating=(b == 0))
    self.b, self.c = b, c

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.abs(x)
    b, c = self.b, self.c
    # Polynomial coefficients of the inner piece, used where abs(x) < 1.
    f3 = 2 - 9 / 6 * b - c
    f2 = -3 + 2 * b + c
    f0 = 1 - 1 / 3 * b
    # Polynomial coefficients of the outer piece, used where 1 <= abs(x) < 2.
    g3 = -b / 6 - c
    g2 = b + 5 * c
    g1 = -2 * b - 8 * c
    g0 = 8 / 6 * b + 4 * c
    # (np.polynomial.polynomial.polyval(x, [f0, 0, f2, f3]) is almost
    # twice as slow; see also https://stackoverflow.com/questions/24065904)
    inner = ((f3 * x + f2) * x) * x + f0
    outer = ((g3 * x + g2) * x + g1) * x + g0
    return np.where(x < 1.0, inner, np.where(x < 2.0, outer, 0.0))

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CatmullRomFilter(CubicFilter):
class CatmullRomFilter(CubicFilter):
  """Cubic filter with cubic precision, named 'cubic'.  Also known as Keys filter.

  [E. Catmull, R. Rom.  A class of local interpolating splines.  Computer aided geometric
  design, 1974]
  [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline#Catmull%E2%80%93Rom_spline)

  [R. G. Keys.  Cubic convolution interpolation for digital image processing.
  IEEE Trans. on Acoustics, Speech, and Signal Processing, 29(6), 1981.]
  https://ieeexplore.ieee.org/document/1163711/.
  """

  def __init__(self) -> None:
    # Cubic family member with b=0 and c=0.5.
    super().__init__(name='cubic', b=0, c=0.5)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class MitchellFilter(CubicFilter):
class MitchellFilter(CubicFilter):
  """Mitchell cubic filter; see https://doi.org/10.1145/378456.378514.

  [D. P. Mitchell and A. N. Netravali.  Reconstruction filters in computer graphics.  Computer
  Graphics (Proceedings of ACM SIGGRAPH 1988), 22(4):221-228, 1988.]
  """

  def __init__(self) -> None:
    # Cubic family member with b = c = 1/3.
    super().__init__(name='mitchell', b=1 / 3, c=1 / 3)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class SharpCubicFilter(CubicFilter):
class SharpCubicFilter(CubicFilter):
  """Cubic filter that is sharper than the Catmull-Rom filter.

  Some tools, including OpenCV and Photoshop, use it.

  See https://en.wikipedia.org/wiki/Mitchell%E2%80%93Netravali_filters and
  https://entropymine.com/resamplescope/notes/photoshop/.
  """

  def __init__(self) -> None:
    # Cubic family member with b=0 and c=0.75.
    super().__init__(name='sharpcubic', b=0, c=0.75)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class LanczosFilter(Filter):
class LanczosFilter(Filter):
  """High-quality filter: a sinc function multiplied by a sinc window.

  Args:
    radius: The filter is nonzero only within the support window [-radius, radius].
    sampled: If True, use a discretized approximation for improved speed.

  See https://en.wikipedia.org/wiki/Lanczos_kernel.
  """

  def __init__(self, *, radius: int, sampled: bool = True) -> None:
    super().__init__(
        name=f'lanczos_{radius}', radius=radius, partition_of_unity=False, unit_integral=False
    )

    @_cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Note that window[n] = sinc(2*n/N - 1), with 0 <= n <= N.
      # But, x = n - N/2, or equivalently, n = x + N/2, with -N/2 <= x <= N/2.
      window = _sinc(distance / radius)  # Zero-phase function w_0(x).
      windowed_sinc = _sinc(distance) * window
      return np.where(distance < radius, windowed_sinc, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the kernel at x using the function built in `__init__`."""
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GeneralizedHammingFilter(Filter):
class GeneralizedHammingFilter(Filter):
  """Sinc function multiplied by a Hamming window.

  Args:
    radius: The filter is nonzero only within the support window [-radius, radius].
    a0: Scalar parameter, where 0.0 < a0 < 1.0.  The case of a0=0.5 is the Hann filter.

  See https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows,
  and hamming() in https://github.com/scipy/scipy/blob/main/scipy/signal/windows/_windows.py.

  Note that `'hamming3'` is `(radius=3, a0=25/46)`, which is close to but different from
  `a0=0.54`.

  See also np.hamming() and np.hanning().
  """

  def __init__(self, *, radius: int, a0: float) -> None:
    super().__init__(
        name=f'hamming_{radius}',
        radius=radius,
        partition_of_unity=False,  # 1:1.00242  av=1.00188  sd=0.00052909
        unit_integral=False,  # 1.00188
    )
    assert 0.0 < a0 < 1.0
    self.a0 = a0

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    # Note that window[n] = a0 - (1 - a0) * cos(2 * pi * n / N), 0 <= n <= N.
    # With n = x + N/2, we get the zero-phase function w_0(x):
    window = self.a0 + (1.0 - self.a0) * np.cos(np.pi / self.radius * distance)
    windowed_sinc = _sinc(distance) * window
    return np.where(distance < self.radius, windowed_sinc, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class KaiserFilter(Filter):
class KaiserFilter(Filter):
  """Sinc function multiplied by a Kaiser-Bessel window.

  See https://en.wikipedia.org/wiki/Kaiser_window, and example use in:
  [Karras et al. 2021.  Alias-free generative adversarial networks.
  https://arxiv.org/pdf/2106.12423.pdf].

  See also np.kaiser().

  Args:
    radius: Value L/2 in the definition.  It may be fractional for a (digital) resizing filter
      (sample spacing s != 1) with an even number of samples (dual grid), e.g., Eq. (6)
      in [Karras et al. 2021] --- this affects the precise shape of the window function.
    beta: Determines the trade-off between main-lobe width and side-lobe level.
    sampled: If True, use a discretized approximation for improved speed.
  """

  def __init__(self, *, radius: float, beta: float, sampled: bool = True) -> None:
    assert beta >= 0.0
    super().__init__(
        name=f'kaiser_{radius}_{beta}', radius=radius, partition_of_unity=False, unit_integral=False
    )
    # The sampling range is rounded up to whole numbers since `radius` may be fractional.
    sampling_extent = math.ceil(radius)

    @_cache_sampled_1d_function(xmin=-sampling_extent, xmax=sampling_extent, enable=sampled)
    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Clip so that the square root stays real for abs(x) > radius.
      one_minus_square = (1.0 - np.square(distance / radius)).clip(0.0, 1.0)
      window = np.i0(beta * np.sqrt(one_minus_square)) / np.i0(beta)
      return np.where(distance <= radius + 1e-6, _sinc(distance) * window, 0.0)

    self._function = _eval

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    """Evaluate the kernel at x using the function built in `__init__`."""
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class BsplineFilter(Filter):
class BsplineFilter(Filter):
  """Non-interpolating (for degree >= 2) B-spline kernel of a given non-negative degree.

  Args:
    degree: The polynomial degree of the B-spline segments.
      With `degree=0`, it is like `BoxFilter` except with f(0.5) = f(-0.5) = 0.
      With `degree=1`, it is identical to `TriangleFilter`.
      With `degree >= 2`, it is no longer interpolating.

  Raises:
    ValueError: If `degree` is negative.

  See [Carl de Boor.  A practical guide to splines.  Springer, 2001.]
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.BSpline.html
  """

  def __init__(self, *, degree: int) -> None:
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    super().__init__(
        name=f'bspline{degree}',
        radius=(degree + 1) / 2,
        interpolating=degree <= 1,
    )
    # Uniform integer knots 0, 1, ..., degree + 1 define a single basis element.
    knots = list(range(degree + 2))
    self._bspline = scipy.interpolate.BSpline.basis_element(knots)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    distance = np.abs(x)
    # The basis element lives on [0, degree + 1]; shift by the radius to center it at zero.
    shifted = distance + self.radius
    return np.where(distance < self.radius, self._bspline(shifted), 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class CardinalBsplineFilter(Filter):
class CardinalBsplineFilter(Filter):
  """B-spline kernel made interpolating by means of a digital pre or post filter.

  Args:
    degree: The polynomial degree of the B-spline segments.
    sampled: If True, use a discretized approximation for improved speed.

  Raises:
    ValueError: If `degree` is negative.

  See [Hou and Andrews.  Cubic splines for image interpolation and digital filtering, 1978] and
  [Unser et al.  Fast B-spline transforms for continuous image representation and interpolation,
  1991].
  """

  def __init__(self, *, degree: int, sampled: bool = True) -> None:
    self.degree = degree
    if degree < 0:
      raise ValueError(f'Bspline of degree {degree} is invalid.')
    radius = (degree + 1) / 2
    super().__init__(
        name=f'cardinal{degree}',
        radius=radius,
        requires_digital_filter=degree >= 2,
        continuous=degree >= 1,
    )
    # Uniform integer knots 0, 1, ..., degree + 1 define a single basis element.
    knots = list(range(degree + 2))
    bspline = scipy.interpolate.BSpline.basis_element(knots)

    def _eval(x: _ArrayLike) -> _NDArray:
      distance = np.abs(x)
      # Shift by the radius because the basis element is supported on [0, degree + 1].
      return np.where(distance < radius, bspline(distance + radius), 0.0)

    wrap = _cache_sampled_1d_function(xmin=-radius, xmax=radius, enable=sampled)
    self._function = wrap(_eval)

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    return self._function(x)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class OmomsFilter(Filter):
class OmomsFilter(Filter):
  """Piecewise-polynomial OMOMS kernel; interpolating with aid of a digital pre or post filter.

  OMOMS stands for optimal MOMS (maximal-order-minimal-support); see [Blu and Thevenaz, MOMS:
  Maximal-order interpolation of minimal support, 2001].
  https://infoscience.epfl.ch/record/63074/files/blu0101.pdf

  Args:
    degree: The polynomial degree of the filter segments.  Only 3 and 5 are supported.

  Raises:
    ValueError: If `degree` is neither 3 nor 5.
  """

  def __init__(self, *, degree: int) -> None:
    if degree not in (3, 5):
      raise ValueError(f'Degree {degree} not supported.')
    super().__init__(name=f'omoms{degree}', radius=(degree + 1) / 2, requires_digital_filter=True)
    self.degree = degree

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    # The kernel is even, so each polynomial segment is expressed in terms of |x|.
    x = np.abs(x)

    if self.degree == 3:
      v01 = ((0.5 * x - 1.0) * x + 3 / 42) * x + 26 / 42
      v12 = ((-7 / 42 * x + 1.0) * x - 85 / 42) * x + 58 / 42
      outer = np.where(x < 2.0, v12, 0.0)
      return np.where(x < 1.0, v01, outer)

    if self.degree == 5:
      v01 = ((((-1 / 12 * x + 1 / 4) * x - 5 / 99) * x - 9 / 22) * x - 1 / 792) * x + 229 / 440
      v12 = (
          (((1 / 24 * x - 3 / 8) * x + 505 / 396) * x - 83 / 44) * x + 1351 / 1584
      ) * x + 839 / 2640
      v23 = (
          (((-1 / 120 * x + 1 / 8) * x - 299 / 396) * x + 101 / 44) * x - 27811 / 7920
      ) * x + 5707 / 2640
      outer = np.where(x < 3.0, v23, 0.0)
      middle = np.where(x < 2.0, v12, outer)
      return np.where(x < 1.0, v01, middle)

    # Only reachable if `self.degree` was modified after construction.
    raise ValueError(self.degree)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class GaussianFilter(Filter):
class GaussianFilter(Filter):
  r"""Normalized Gaussian kernel; see https://en.wikipedia.org/wiki/Gaussian_function.

  The kernel is truncated to zero at a radius of `ceil(8 * standard_deviation)`.

  Args:
    standard_deviation: Sets the Gaussian $\sigma$.  The default value is 1.25/3.0, which
      creates a kernel that is as-close-as-possible to a partition of unity.
  """

  DEFAULT_STANDARD_DEVIATION = 1.25 / 3.0
  """This value creates a kernel that is as-close-as-possible to a partition of unity; see
  mesh_processing/test/GridOp_test.cpp: `0.93503:1.06497     av=1           sd=0.0459424`.
  Another possibility is 0.5, as suggested on p. 4 of [Ken Turkowski.  Filters for common
  resampling tasks, 1990] for kernels with a support of 3 pixels.
  https://cadxfem.org/inf/ResamplingFilters.pdf
  """

  def __init__(self, *, standard_deviation: float = DEFAULT_STANDARD_DEVIATION) -> None:
    support_radius = np.ceil(8.0 * standard_deviation)  # Sufficiently large.
    super().__init__(
        name=f'gaussian_{standard_deviation:.3f}',
        radius=support_radius,
        interpolating=False,
        partition_of_unity=False,
    )
    self.standard_deviation = standard_deviation

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    sigma = self.standard_deviation
    distance = np.abs(x)
    # Gaussian probability density: exp(-x^2 / (2 sigma^2)) / (sqrt(2 pi) sigma).
    density = np.exp(np.square(distance / sigma) / -2.0) / (np.sqrt(math.tau) * sigma)
    return np.where(distance < self.radius, density, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

class NarrowBoxFilter(Filter):
class NarrowBoxFilter(Filter):
  """Small box kernel of unit height, used for visualization of grid sample location.

  The kernel evaluates to 1.0 on the half-open interval [-radius, radius) and 0.0 elsewhere.

  Args:
    radius: Specifies the support [-radius, radius] of the narrow box function.  (The default
      value 0.199 is an inexact 0.2 to avoid numerical ambiguities.)
  """

  def __init__(self, *, radius: float = 0.199) -> None:
    super().__init__(
        name='narrowbox',
        radius=radius,
        continuous=False,
        unit_integral=False,
        partition_of_unity=False,
    )

  def __call__(self, x: _ArrayLike, /) -> _NDArray:
    x = np.asarray(x)
    half_width = self.radius
    is_inside = (x >= -half_width) & (x < half_width)
    return np.where(is_inside, 1.0, 0.0)

Filter(name: 'str', radius: 'float', interpolating: 'bool' = Ellipsis, continuous: 'bool' = Ellipsis, partition_of_unity: 'bool' = Ellipsis, unit_integral: 'bool' = Ellipsis, requires_digital_filter: 'bool' = Ellipsis)

FILTERS: list[str] = ['impulse', 'box', 'trapezoid', 'triangle', 'cubic', 'sharpcubic', 'lanczos3', 'lanczos5', 'lanczos10', 'cardinal3', 'cardinal5', 'omoms3', 'omoms5', 'hamming3', 'kaiser3', 'gaussian', 'bspline3', 'mitchell', 'narrowbox']

Shortcut names for some predefined filter kernels (specified per dimension). The names expand to:

name Filter a.k.a. / comments
'impulse' ImpulseFilter() nearest
'box' BoxFilter() non-antialiased box, e.g. ImageMagick
'trapezoid' TrapezoidFilter() area antialiasing, e.g. cv.INTER_AREA
'triangle' TriangleFilter() linear (bilinear in 2D), spline order=1
'cubic' CatmullRomFilter() catmullrom, keys, bicubic
'sharpcubic' SharpCubicFilter() cv.INTER_CUBIC, torch 'bicubic'
'lanczos3' LanczosFilter(radius=3) support window [-3, 3]
'lanczos5' LanczosFilter(radius=5) [-5, 5]
'lanczos10' LanczosFilter(radius=10) [-10, 10]
'cardinal3' CardinalBsplineFilter(degree=3) spline interpolation, order=3, GF
'cardinal5' CardinalBsplineFilter(degree=5) spline interpolation, order=5, GF
'omoms3' OmomsFilter(degree=3) non-$C^1$, [-2, 2], GF
'omoms5' OmomsFilter(degree=5) non-$C^1$, [-3, 3], GF
'hamming3' GeneralizedHammingFilter(...) (radius=3, a0=25/46)
'kaiser3' KaiserFilter(radius=3.0, beta=7.12)
'gaussian' GaussianFilter() non-interpolating, default $\sigma=1.25/3$
'bspline3' BsplineFilter(degree=3) non-interpolating
'mitchell' MitchellFilter() mitchellcubic
'narrowbox' NarrowBoxFilter() for visualization of sample positions

The comment label GF denotes a generalized filter, formed as the composition of a finitely supported kernel and a discrete inverse convolution.

Some example filter kernels:


A more extensive set of filters is presented here in the notebook, together with visualizations and analyses of the filter properties. See the source code for extensibility.

@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
@dataclasses.dataclass(frozen=True)
class Gamma(abc.ABC):
  """Abstract interface for a component transfer function applied to sample values.

  Stored image/video content commonly uses a (possibly nonlinear) color component transfer
  function; see https://en.wikipedia.org/wiki/Gamma_correction.

  Subclasses implement `decode` (stored samples to float values) and `encode` (float values to
  stored samples).  Integer types are converted to and from the internal [0.0, 1.0] value range.
  """

  name: str
  """Name of component transfer function."""

  @abc.abstractmethod
  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Return the source samples in `array` converted to floating-point type `dtype`.

    The conversion may be nonlinear.  Uint source values are mapped to the range [0.0, 1.0].
    """

  @abc.abstractmethod
  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Return the float signal in `array` converted to destination samples of type `dtype`.

    The conversion may be nonlinear.  Uint destination values are mapped from the range
    [0.0, 1.0].

    Note that non-integer destination types are not clipped to the range [0.0, 1.0].
    If that is desired, it can be performed as a postprocess using `output.clip(0.0, 1.0)`.
    """

Abstract base class for transfer functions on sample values.

Image/video content is often stored using a color component transfer function. See https://en.wikipedia.org/wiki/Gamma_correction.

Converts between integer types and [0.0, 1.0] internal value range.

class IdentityGamma(Gamma):
class IdentityGamma(Gamma):
  """Linear (identity) component transfer function; only the value range/type is converted."""

  def __init__(self) -> None:
    super().__init__('identity')

  def decode(self, array: _Array, /, dtype: _DTypeLike = np.float32) -> _Array:
    """Convert `array` to inexact type `dtype`; uint sources are rescaled via `_to_float_01`."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.inexact)
    source_is_uint = np.issubdtype(_arr_dtype(array), np.unsignedinteger)
    return _to_float_01(array, dtype) if source_is_uint else _arr_astype(array, dtype)

  def encode(self, array: _Array, /, dtype: _DTypeLike) -> _Array:
    """Convert float `array` to numeric type `dtype`; uint outputs are clipped to [0.0, 1.0]."""
    dtype = np.dtype(dtype)
    assert np.issubdtype(dtype, np.number)
    if not np.issubdtype(dtype, np.integer):
      return _arr_astype(array, dtype)  # Non-integer destinations are neither clipped nor rounded.
    if np.issubdtype(dtype, np.unsignedinteger):
      return _from_float(_arr_clip(array, 0.0, 1.0), dtype)
    # Signed integer destination: offset by one half before the cast.
    # NOTE(review): if the cast truncates toward zero, negative values are mis-rounded -- confirm.
    return _arr_astype(array + 0.5, dtype)

Gamma(name: 'str')

GAMMAS: list[str] = ['identity', 'power2', 'power22', 'srgb']

Shortcut names for some predefined gamma-correction schemes:

name Gamma Decoding function
(linear space from stored value)
Encoding function
(stored value from linear space)
'identity' IdentityGamma() $l = e$ $e = l$
'power2' PowerGamma(2.0) $l = e^{2.0}$ $e = l^{1/2.0}$
'power22' PowerGamma(2.2) $l = e^{2.2}$ $e = l^{1/2.2}$
'srgb' SrgbGamma() $l = \left(\left(e + 0.055\right) / 1.055\right)^{2.4}$ $e = l^{1/2.4} * 1.055 - 0.055$

See the source code for extensibility.

def resize( array: Array, /, shape: Iterable[int], *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, precision: DTypeLike = None, dtype: DTypeLike = None, dim_order: Iterable[int] | None = None, num_threads: Union[int, Literal['auto']] = 'auto') -> Array:
def resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    dim_order: Iterable[int] | None = None,
    num_threads: int | Literal['auto'] = 'auto',
) -> _Array:
  """Resample `array` (a grid of sample values) onto a grid with resolution `shape`.

  The source `array` is any object recognized by `ARRAYLIBS`.  It is interpreted as a grid
  with `len(shape)` domain coordinate dimensions, where each grid sample value has shape
  `array.shape[len(shape):]`.

  Some examples:

  - A grayscale image has `array.shape = height, width` and resizing it with `len(shape) == 2`
    produces a new image of scalar values.
  - An RGB image has `array.shape = height, width, 3` and resizing it with `len(shape) == 2`
    produces a new image of RGB values.
  - A 3D grid of 3x3 Jacobians has `array.shape = Z, Y, X, 3, 3` and resizing it with
    `len(shape) == 3` produces a new 3D grid of Jacobians.

  This function also allows scaling and translation from the source domain to the output domain
  through the parameters `scale` and `translate`.  For more general transforms, see `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  Its first `len(shape)` dimensions are the domain
      coordinate dimensions.  Each grid dimension must be at least 1 for a `'dual'` grid or
      at least 2 for a `'primal'` grid.
    shape: The number of grid samples in each coordinate dimension of the output array.  The source
      `array` must have at least as many dimensions as `len(shape)`.
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    boundary: The reconstruction boundary rule for each dimension in `shape`, specified as either
      a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses `'reflect'`
      for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto `array.shape[len(shape):]`.
    filter: The reconstruction kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during upsampling (i.e., magnification).
    prefilter: The prefilter kernel for each dimension in `shape`, specified as either a filter
      name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.  The default
      `'lanczos3'` is good for natural images.  For vector graphics images, `'trapezoid'` is better
      because it avoids ringing artifacts.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from
      `array` and when creating output grid samples.  It is specified as either a name in `GAMMAS`
      or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default is
      `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    scale: Scaling factor applied to each dimension of the source domain when it is mapped onto
      the destination domain.
    translate: Offset applied to each dimension of the scaled source domain when it is mapped onto
      the destination domain.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    dim_order: Override the automatically selected order in which the grid dimensions are resized.
      Must contain a permutation of `range(len(shape))`.
    num_threads: Used to determine multithread parallelism if `array` is from `numpy`.  If set to
      `'auto'`, it is selected automatically.  Otherwise, it must be a positive integer.

  Returns:
    An array of the same class as the source `array`, with shape `shape + array.shape[len(shape):]`
      and data type `dtype`.  If the requested operation leaves the samples unchanged (same shape,
      gridtypes, gamma, and data type, with interpolating prefilters and no scale/translate),
      the input `array` itself is returned.

  Raises:
    ValueError: If `array` is not numeric, if `shape` is empty or longer than `array.shape`, or if
      `dim_order` is not a permutation of `range(len(shape))`.

  **Example of image upsampling:**

  >>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
  >>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_upsampled.png"/>
  </center>

  **Example of image downsampling:**

  >>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
  >>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
  >>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
  >>> downsampled = resize(array, (24, 48))

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_array_downsampled2.png"/>
  </center>

  **Unit test:**

  >>> result = resize([1.0, 4.0, 5.0], shape=(4,))
  >>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  array_dtype = _arr_dtype(array)
  if not np.issubdtype(array_dtype, np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  shape2 = tuple(shape)
  array_ndim = len(array.shape)
  if not 0 < len(shape2) <= array_ndim:
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape2}.')
  src_shape = array.shape[: len(shape2)]
  src_gridtype2, dst_gridtype2 = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, len(shape2), len(shape2)
  )
  # Broadcast all per-dimension parameters to one entry per resized dimension.
  boundary2 = np.broadcast_to(np.array(boundary), len(shape2))
  cval = np.broadcast_to(cval, array.shape[len(shape2) :])
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), len(shape2))]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), len(shape2))]
  dtype = array_dtype if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, array_dtype, dtype)
  scale2 = np.broadcast_to(np.array(scale), len(shape2))
  translate2 = np.broadcast_to(np.array(translate), len(shape2))
  # Drop the unnormalized parameters so that they cannot be used accidentally below.
  del (shape, src_gridtype, dst_gridtype, boundary, filter, prefilter)
  del (src_gamma, dst_gamma, scale, translate)
  precision = _get_precision(precision, [array_dtype, dtype], [])
  weight_precision = _real_precision(precision)

  # The input can be returned as is only if the output data type is also unchanged; otherwise the
  # decode/encode steps below are still required to produce an array of type `dtype`.
  is_noop = (
      all(src == dst for src, dst in zip(src_shape, shape2))
      and all(gt1 == gt2 for gt1, gt2 in zip(src_gridtype2, dst_gridtype2))
      and all(f.interpolating for f in prefilter2)
      and dtype == array_dtype
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
      and src_gamma2 == dst_gamma2
  )
  if is_noop:
    return array

  if dim_order is None:
    dim_order = _arr_best_dims_order_for_resize(array, shape2)
  else:
    dim_order = tuple(dim_order)
    if sorted(dim_order) != list(range(len(shape2))):
      raise ValueError(f'{dim_order} not a permutation of {list(range(len(shape2)))}.')

  array = src_gamma2.decode(array, precision)

  can_use_fast_box_downsampling = (
      using_numba
      and arraylib == 'numpy'
      and len(shape2) == 2
      and array_ndim in (2, 3)
      and all(src > dst for src, dst in zip(src_shape, shape2))
      and all(src % dst == 0 for src, dst in zip(src_shape, shape2))
      and all(gridtype.name == 'dual' for gridtype in src_gridtype2)
      and all(gridtype.name == 'dual' for gridtype in dst_gridtype2)
      and all(f.name in ('box', 'trapezoid') for f in prefilter2)
      and np.all(scale2 == 1.0)
      and np.all(translate2 == 0.0)
  )
  if can_use_fast_box_downsampling:
    assert isinstance(array, np.ndarray)  # Help mypy.
    array = _downsample_in_2d_using_box_filter(array, cast(Any, shape2))
    array = dst_gamma2.encode(array, dtype)
    return array

  # Multidimensional resize can be expressed using einsum() with multiple per-dim resize matrices,
  # e.g., as in jax.image.resize().  A benefit is to seek the optimal order of multiplications.
  # However, efficiency often requires sparse resize matrices, which are unsupported in einsum().
  # Sparse tensors requested for tf.einsum: https://github.com/tensorflow/tensorflow/issues/43497
  # https://github.com/tensor-compiler/taco: C++ library that computes tensor algebra expressions
  # on sparse and dense tensors; however it does not interoperate with tensorflow, torch, or jax.

  for dim in dim_order:
    # A dimension may be skipped only if the source and destination sample positions coincide,
    # which also requires that both grids use the same gridtype on this dimension.
    skip_resize_on_this_dim = (
        shape2[dim] == array.shape[dim]
        and src_gridtype2[dim] == dst_gridtype2[dim]
        and scale2[dim] == 1.0
        and translate2[dim] == 0.0
        and filter2[dim].interpolating
    )
    if skip_resize_on_this_dim:
      continue

    # Minification occurs if the (scaled) destination sample density is lower than the source's.
    src_in_samples = src_gridtype2[dim].size_in_samples(array.shape[dim])
    dst_in_samples = dst_gridtype2[dim].size_in_samples(shape2[dim])
    is_minification = dst_in_samples / src_in_samples * scale2[dim] < 1.0

    boundary_dim = boundary2[dim]
    if boundary_dim == 'auto':
      boundary_dim = 'clamp' if is_minification else 'reflect'
    boundary_dim = _get_boundary(boundary_dim)
    resize_matrix, cval_weight = _create_resize_matrix(
        array.shape[dim],
        shape2[dim],
        src_gridtype=src_gridtype2[dim],
        dst_gridtype=dst_gridtype2[dim],
        boundary=boundary_dim,
        filter=filter2[dim],
        prefilter=prefilter2[dim],
        scale=scale2[dim],
        translate=translate2[dim],
        dtype=weight_precision,
        arraylib=arraylib,
    )

    # Move the current dimension to the front and flatten the rest, to apply a 2D matrix product.
    array_dim: _Array = _arr_moveaxis(array, dim, 0)
    array_flat = _arr_reshape(array_dim, (array_dim.shape[0], -1))
    array_flat = _arr_possibly_make_contiguous(array_flat)
    # For magnification, the digital filter is applied on the source samples (before resizing).
    if not is_minification and filter2[dim].requires_digital_filter:
      array_flat = _apply_digital_filter_1d(
          array_flat, src_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )

    array_flat = _arr_matmul_sparse_dense(resize_matrix, array_flat, num_threads=num_threads)
    if cval_weight is not None:
      cval_flat = np.broadcast_to(cval, array_dim.shape[1:]).reshape(-1)
      if np.issubdtype(array_dtype, np.complexfloating):
        cval_weight = _arr_astype(cval_weight, array_dtype)  # (Only necessary for 'tensorflow'.)
      array_flat += cval_weight[:, None] * cval_flat

    # For minification, the digital filter is applied on the output samples (after resizing).
    if is_minification and filter2[dim].requires_digital_filter:  # use prefilter2[dim]?
      array_flat = _apply_digital_filter_1d(
          array_flat, dst_gridtype2[dim], boundary_dim, cval, filter2[dim]
      )
    array_dim = _arr_reshape(array_flat, (array_flat.shape[0], *array_dim.shape[1:]))
    array = _arr_moveaxis(array_dim, 0, dim)

  array = dst_gamma2.encode(array, dtype)
  return array

Resample array (a grid of sample values) onto a grid with resolution shape.

The source array is any object recognized by ARRAYLIBS. It is interpreted as a grid with len(shape) domain coordinate dimensions, where each grid sample value has shape array.shape[len(shape):].

Some examples:

  • A grayscale image has array.shape = height, width and resizing it with len(shape) == 2 produces a new image of scalar values.
  • An RGB image has array.shape = height, width, 3 and resizing it with len(shape) == 2 produces a new image of RGB values.
  • A 3D grid of 3x3 Jacobians has array.shape = Z, Y, X, 3, 3 and resizing it with len(shape) == 3 produces a new 3D grid of Jacobians.

This function also allows scaling and translation from the source domain to the output domain through the parameters scale and translate. For more general transforms, see resample.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. Its first len(shape) dimensions are the domain coordinate dimensions. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto array.shape[len(shape):].
  • filter: The reconstruction kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during upsampling (i.e., magnification).
  • prefilter: The prefilter kernel for each dimension in shape, specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter. The default 'lanczos3' is good for natural images. For vector graphics images, 'trapezoid' is better because it avoids ringing artifacts.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • scale: Scaling factor applied to each dimension of the source domain when it is mapped onto the destination domain.
  • translate: Offset applied to each dimension of the scaled source domain when it is mapped onto the destination domain.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • dim_order: Override the automatically selected order in which the grid dimensions are resized. Must contain a permutation of range(len(shape)).
  • num_threads: Used to determine multithread parallelism if array is from numpy. If set to 'auto', it is selected automatically. Otherwise, it must be a positive integer.
Returns:

An array of the same class as the source array, with shape shape + array.shape[len(shape):] and data type dtype.

Example of image upsampling:

>>> array = np.random.default_rng(1).random((4, 6, 3))  # 4x6 RGB image.
>>> upsampled = resize(array, (128, 192))  # To 128x192 resolution.

Example of image downsampling:

>>> yx = (np.moveaxis(np.indices((96, 192)), 0, -1) + (0.5, 0.5)) / 96
>>> radius = np.linalg.norm(yx - (0.75, 0.5), axis=-1)
>>> array = np.cos((radius + 0.1) ** 0.5 * 70.0) * 0.5 + 0.5
>>> downsampled = resize(array, (24, 48))

Unit test:

>>> result = resize([1.0, 4.0, 5.0], shape=(4,))
>>> assert np.allclose(result, [0.74240461, 2.88088827, 4.68647155, 5.02641199])
def jaxjit_resize(array: Array, /, *args: Any, **kwargs: Any) -> Array:
def jaxjit_resize(array: _Array, /, *args: Any, **kwargs: Any) -> _Array:
  """Evaluate `resize` on `array` using a Jax-jitted version of the resize function.

  All positional and keyword arguments are forwarded unchanged to the function obtained from
  `_create_jaxjit_resize()`.
  """
  jitted_resize = _create_jaxjit_resize()
  return jitted_resize(array, *args, **kwargs)  # pylint: disable=not-callable

Compute resize but with resize function jitted using Jax.

def uniform_resize( array: Array, /, shape: Iterable[int], *, object_fit: Literal['contain', 'cover'] = 'contain', gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, boundary: str | Boundary | Iterable[str | Boundary] = 'natural', scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0, **kwargs: Any) -> Array:
def uniform_resize(
    array: _Array,
    /,
    shape: Iterable[int],
    *,
    object_fit: Literal['contain', 'cover'] = 'contain',
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    boundary: str | Boundary | Iterable[str | Boundary] = 'natural',  # Instead of 'auto' default.
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
    **kwargs: Any,
) -> _Array:
  """Resize `array` to resolution `shape` while keeping its aspect ratio.

  This is a thin front end to `resize`: it derives per-dimension `scale` and `translate` values
  so that every dimension is scaled by the same factor and the result is centered in the
  destination domain, in the manner of CSS `object-fit`.  Output points that fall outside the
  source domain are governed by `boundary`, whose default here is `'natural'`.

  Args:
    array: Regular grid of source sample values.
    shape: Number of grid samples along each coordinate dimension of the output.  `array` must
      have at least `len(shape)` dimensions.
    object_fit: Like CSS `object-fit`.  With `'contain'`, the uniformly scaled `array` fits
      entirely within `shape`; with `'cover'`, it fully covers `shape`.
    gridtype: Placement of samples on all dimensions of both the source and output grids.
    src_gridtype: Placement of the samples in the source domain grid for each dimension.
    dst_gridtype: Placement of the samples in the output domain grid for each dimension.
    boundary: Reconstruction boundary rule for each dimension in `shape`, given as a name in
      `BOUNDARIES` or a `Boundary` instance.  The default `'natural'` assigns `cval` to output
      points that map outside the source unit domain.
    scale: Must be left at its default; it is computed internally.
    translate: Must be left at its default; it is computed internally.
    **kwargs: Additional parameters forwarded to `resize` (including `cval`).

  Returns:
    An array with shape `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `scale` or `translate` is specified, or if `shape` is empty or has more
      dimensions than `array`.

  >>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
  array([[0., 1., 1., 0.],
         [0., 1., 1., 0.]])

  >>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
  array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
         [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])

  >>> a = np.arange(6.0).reshape(2, 3)
  >>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
  array([[0.5, 1.5],
         [3.5, 4.5]])
  """
  if scale != 1.0 or translate != 0.0:
    raise ValueError('`uniform_resize()` does not accept `scale` or `translate` parameters.')
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  num_dims = len(shape)
  if num_dims == 0 or num_dims > len(array.shape):
    raise ValueError(f'Shape {array.shape} cannot be resized to {shape}.')

  src_gridtypes, dst_gridtypes = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, num_dims, num_dims
  )

  # Scale factor that each dimension would get under an ordinary (non-uniform) resize.
  per_dim_scales = []
  for dim, dst_size in enumerate(shape):
    dst_extent = dst_gridtypes[dim].size_in_samples(dst_size)
    src_extent = src_gridtypes[dim].size_in_samples(array.shape[dim])
    per_dim_scales.append(dst_extent / src_extent)

  # 'contain' takes the smallest of those factors, 'cover' the largest.
  pick_scale = {'contain': min, 'cover': max}[object_fit]
  uniform_scale = pick_scale(per_dim_scales)

  # Express the uniform factor relative to each dimension's own factor, then center the result.
  relative_scale = uniform_scale / np.array(per_dim_scales)
  centering_offset = (1.0 - relative_scale) / 2
  return resize(
      array, shape, boundary=boundary, scale=relative_scale, translate=centering_offset, **kwargs
  )

Resample array onto a grid with resolution shape but with uniform scaling.

Calls function resize with scale and translate set such that the aspect ratio of array is preserved. The effect is similar to CSS object-fit: contain. The parameter boundary (whose default is changed to 'natural') determines the values assigned outside the source domain.

Arguments:
  • array: Regular grid of source sample values.
  • shape: The number of grid samples in each coordinate dimension of the output array. The source array must have at least as many dimensions as len(shape).
  • object_fit: Like CSS object-fit. If 'contain', array is resized uniformly to fit within shape. If 'cover', array is resized to fully cover shape.
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids.
  • src_gridtype: Placement of the samples in the source domain grid for each dimension.
  • dst_gridtype: Placement of the samples in the output domain grid for each dimension.
  • boundary: The reconstruction boundary rule for each dimension in shape, specified as either a name in BOUNDARIES or a Boundary instance. The default is 'natural', which assigns cval to output points that map outside the source unit domain.
  • scale: Parameter may not be specified.
  • translate: Parameter may not be specified.
  • **kwargs: Additional parameters for resize function (including cval).
Returns:

An array with shape shape + array.shape[len(shape):].

>>> uniform_resize(np.ones((2, 2)), (2, 4), filter='trapezoid')
array([[0., 1., 1., 0.],
       [0., 1., 1., 0.]])
>>> uniform_resize(np.ones((4, 8)), (2, 7), filter='trapezoid')
array([[0. , 0.5, 1. , 1. , 1. , 0.5, 0. ],
       [0. , 0.5, 1. , 1. , 1. , 0.5, 0. ]])
>>> a = np.arange(6.0).reshape(2, 3)
>>> uniform_resize(a, (2, 2), filter='trapezoid', object_fit='cover')
array([[0.5, 1.5],
       [3.5, 4.5]])
def resample( array: Array, /, coords: ArrayLike, *, gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual', boundary: str | Boundary | Iterable[str | Boundary] = 'auto', cval: ArrayLike = 0, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, gamma: str | Gamma | None = None, src_gamma: str | Gamma | None = None, dst_gamma: str | Gamma | None = None, jacobian: Optional[ArrayLike] = None, precision: DTypeLike = None, dtype: DTypeLike = None, max_block_size: int = 40000, debug: bool = False) -> Array:
def resample(
    array: _Array,
    /,
    coords: _ArrayLike,
    *,
    gridtype: str | Gridtype | Iterable[str | Gridtype] = 'dual',
    boundary: str | Boundary | Iterable[str | Boundary] = 'auto',
    cval: _ArrayLike = 0,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    gamma: str | Gamma | None = None,
    src_gamma: str | Gamma | None = None,
    dst_gamma: str | Gamma | None = None,
    jacobian: _ArrayLike | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    max_block_size: int = 40_000,
    debug: bool = False,
) -> _Array:
  """Interpolate `array` (a grid of samples) at specified unit-domain coordinates `coords`.

  The last dimension of `coords` contains unit-domain coordinates at which to interpolate the
  domain grid samples in `array`.

  The number of coordinates (`coords.shape[-1]`) determines how to interpret `array`: its first
  `coords.shape[-1]` dimensions define the grid, and the remaining dimensions describe each grid
  sample (e.g., scalar, vector, tensor).

  Concretely, the grid has shape `array.shape[:coords.shape[-1]]` and each grid sample has shape
  `array.shape[coords.shape[-1]:]`.

  Examples include:

  - Resample a grayscale image with `array.shape = height, width` onto a new grayscale image with
    `new.shape = height2, width2` by using `coords.shape = height2, width2, 2`.

  - Resample an RGB image with `array.shape = height, width, 3` onto a new RGB image with
    `new.shape = height2, width2, 3` by using `coords.shape = height2, width2, 2`.

  - Sample an RGB image at `num` 2D points along a line segment by using `coords.shape = num, 2`.

  - Sample an RGB image at a single 2D point by using `coords.shape = (2,)`.

  - Sample a 3D grid of 3x3 Jacobians with `array.shape = nz, ny, nx, 3, 3` along a 2D plane by
    using `coords.shape = height, width, 3`.

  - Map a grayscale image through a color map by using `array.shape = 256, 3` and
    `coords.shape = height, width, 1`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`.
      The array must have numeric type.  The coordinate dimensions appear first, and
      each grid sample may have an arbitrary shape.  Each grid dimension must be at least 1 for
      a `'dual'` grid or at least 2 for a `'primal'` grid.
    coords: Grid of points at which to resample `array`.  The point coordinates are in the last
      dimension of `coords`.  The domain associated with the source grid is a unit hypercube,
      i.e. with a range [0, 1] on each coordinate dimension.  The output grid has shape
      `coords.shape[:-1]` and each of its grid samples has shape `array.shape[coords.shape[-1]:]`.
      As a special case, a 1D `coords` with more than one element applied to a 1D `array` is
      interpreted as a sequence of 1D points.
    gridtype: Placement of the samples in the source domain grid for each dimension, specified as
      either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`.
    boundary: The reconstruction boundary rule for each dimension in `coords.shape[-1]`, specified
      as either a name in `BOUNDARIES` or a `Boundary` instance.  The special value `'auto'` uses
      `'reflect'` for upsampling and `'clamp'` for downsampling.
    cval: Constant value used beyond the samples by some boundary rules.  It must be broadcastable
      onto the shape `array.shape[coords.shape[-1]:]`.
    filter: The reconstruction kernel for each dimension in `coords.shape[-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each dimension in `coords.shape[:-1]`, specified as either
      a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    gamma: Component transfer functions (e.g., gamma correction) applied when reading samples
      from `array` and when creating output grid samples.  It is specified as either a name in
      `GAMMAS` or a `Gamma` instance.  If both `array.dtype` and `dtype` are `uint`, the default
      is `'power2'`.  If both are non-`uint`, the default is `'identity'`.  Otherwise, `gamma` or
      `src_gamma`/`dst_gamma` must be set.   Gamma correction assumes that float values are in the
      range [0.0, 1.0].
    src_gamma: Component transfer function used to "decode" `array` samples.
      Parameters `gamma` and `src_gamma` cannot both be set.
    dst_gamma: Component transfer function used to "encode" the output samples.
      Parameters `gamma` and `dst_gamma` cannot both be set.
    jacobian: Optional array, which must be broadcastable onto the shape
      `coords.shape[:-1] + (coords.shape[-1], coords.shape[-1])`, storing for each point in the
      output grid the Jacobian matrix of the map from the unit output domain to the unit source
      domain.  If omitted, it is estimated by computing finite differences on `coords`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype`, `coords.dtype`, and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    max_block_size: If nonzero, maximum number of grid points in `coords` before the resampling
      evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
    debug: Show internal information.

  Returns:
    A new sample grid of shape `coords.shape[:-1]`, represented as an array of shape
    `coords.shape[:-1] + array.shape[coords.shape[-1]:]`, of the same array library type as
    the source array.

  Raises:
    ValueError: If `array` is not numeric, `coords` is not floating-point, `coords` has more
      coordinate dimensions than `array` has dimensions, or a `'trapezoid'` filter is requested.

  Note: Minification prefiltering is not currently implemented in this function, so `prefilter`
  and `jacobian` are validated (broadcast to their expected shapes) but do not affect the result.

  **Example of resample operation:**

  <center>
  <img src="https://github.com/hhoppe/resampler/raw/main/media/example_warp_coords.png"/>
  </center>

  For reference, the identity resampling for a scalar-valued grid with the default grid-type
  `'dual'` is:

  >>> array = np.random.default_rng(0).random((5, 7, 3))
  >>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
  >>> new_array = resample(array, coords)
  >>> assert np.allclose(new_array, array)

  It is more efficient to use the function `resize` for the special case where the `coords` are
  obtained as simple scaling and translation of a new regular grid over the source domain:

  >>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
  >>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
  >>> coords = (coords - translate) / scale
  >>> resampled = resample(array, coords)
  >>> resized = resize(array, new_shape, scale=scale, translate=translate)
  >>> assert np.allclose(resampled, resized)
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  arraylib = _arr_arraylib(array)
  if len(array.shape) == 0:
    array = array[None]
  coords = np.atleast_1d(coords)
  if not np.issubdtype(_arr_dtype(array), np.number):
    raise ValueError(f'Type {array.dtype} is not numeric.')
  if not np.issubdtype(coords.dtype, np.floating):
    raise ValueError(f'Type {coords.dtype} is not floating.')
  array_ndim = len(array.shape)
  # A 1D vector of several coordinates on a 1D array is a list of 1D points, not a single nD point.
  if coords.ndim == 1 and coords.shape[0] > 1 and array_ndim == 1:
    coords = coords[:, None]
  grid_ndim = coords.shape[-1]
  grid_shape = array.shape[:grid_ndim]
  sample_shape = array.shape[grid_ndim:]
  resampled_ndim = coords.ndim - 1
  resampled_shape = coords.shape[:-1]
  if grid_ndim > array_ndim:
    raise ValueError(
        f'There are more coordinate dimensions ({grid_ndim}) in {coords=} than in {array.shape=}.'
    )
  gridtype2 = [_get_gridtype(g) for g in np.broadcast_to(np.array(gridtype), grid_ndim)]
  boundary2 = np.broadcast_to(np.array(boundary), grid_ndim).tolist()
  cval = np.broadcast_to(cval, sample_shape)
  prefilter = filter if prefilter is None else prefilter
  filter2 = [_get_filter(f) for f in np.broadcast_to(np.array(filter), grid_ndim)]
  prefilter2 = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter), resampled_ndim)]
  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  src_gamma2, dst_gamma2 = _get_src_dst_gamma(gamma, src_gamma, dst_gamma, _arr_dtype(array), dtype)
  del gridtype, boundary, filter, prefilter, src_gamma, dst_gamma
  if jacobian is not None:
    # Validates the shape; the Jacobian is not otherwise consumed (no prefiltering yet).
    jacobian = np.broadcast_to(jacobian, resampled_shape + (coords.shape[-1],) * 2)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [coords.dtype])
  weight_precision = _real_precision(precision)
  coords = coords.astype(weight_precision, copy=False)
  is_minification = False  # Current limitation; no prefiltering!
  assert max_block_size >= 0 or max_block_size == _MAX_BLOCK_SIZE_RECURSING
  for dim in range(grid_ndim):
    if boundary2[dim] == 'auto':
      boundary2[dim] = 'clamp' if is_minification else 'reflect'
    boundary2[dim] = _get_boundary(boundary2[dim])

  # Gamma-decode and apply any digital filter only once, at the outermost call; recursive
  # per-block calls receive the already-prepared array.
  if max_block_size != _MAX_BLOCK_SIZE_RECURSING:
    array = src_gamma2.decode(array, precision)
    for dim in range(grid_ndim):
      assert not is_minification
      if filter2[dim].requires_digital_filter:
        array = _apply_digital_filter_1d(
            array, gridtype2[dim], boundary2[dim], cval, filter2[dim], axis=dim
        )

  if math.prod(resampled_shape) > max_block_size > 0:
    block_shape = _block_shape_with_min_size(resampled_shape, max_block_size)
    if debug:
      print(f'(resample: splitting coords into blocks {block_shape}).')
    coord_blocks = _split_array_into_blocks(coords, block_shape)

    def process_block(coord_block: _NDArray) -> _Array:
      return resample(
          array,
          coord_block,
          gridtype=gridtype2,
          boundary=boundary2,
          cval=cval,
          filter=filter2,
          prefilter=prefilter2,
          src_gamma='identity',
          dst_gamma=dst_gamma2,
          # The Jacobian was validated above against the full output shape and is not otherwise
          # used; forwarding the full-size array would fail to broadcast onto a smaller block.
          jacobian=None,
          precision=precision,
          dtype=dtype,
          max_block_size=_MAX_BLOCK_SIZE_RECURSING,
      )

    result_blocks = _map_function_over_blocks(coord_blocks, process_block)
    array = _merge_array_from_blocks(result_blocks)
    return array

  # A concrete example of upsampling:
  #   array = np.ones((5, 7, 3))  # source RGB image has height=5 width=7
  #   coords = np.random.default_rng(0).random((8, 9, 2))  # output RGB image has height=8 width=9
  #   resample(array, coords, filter=('cubic', 'lanczos3'))
  #   grid_shape = 5, 7  grid_ndim = 2
  #   resampled_shape = 8, 9  resampled_ndim = 2
  #   sample_shape = (3,)
  #   src_float_index.shape = 8, 9
  #   src_first_index.shape = 8, 9
  #   sample_index.shape = (4,) for dim == 0, then (6,) for dim == 1
  #   weight = [shape(8, 9, 4), shape(8, 9, 6)]
  #   src_index = [shape(8, 9, 4), shape(8, 9, 6)]

  # Both:[shape(8, 9, 4), shape(8, 9, 6)]
  weight: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  src_index: list[_NDArray] = [np.array([]) for _ in range(grid_ndim)]
  uses_cval = False
  all_num_samples = []  # will be [4, 6]

  for dim in range(grid_ndim):
    src_size = grid_shape[dim]  # scalar
    coords_dim = coords[..., dim]  # (8, 9)
    radius = filter2[dim].radius  # scalar
    num_samples = int(np.ceil(radius * 2))  # scalar
    all_num_samples.append(num_samples)

    boundary_dim = boundary2[dim]
    coords_dim = boundary_dim.preprocess_coordinates(coords_dim)

    # Sample positions mapped back to source unit domain [0, 1].
    src_float_index = gridtype2[dim].index_from_point(coords_dim, src_size)  # (8, 9)
    src_first_index = (
        np.floor(src_float_index + (0.5 if num_samples % 2 == 1 else 0.0)).astype(np.int32)
        - (num_samples - 1) // 2
    )  # (8, 9)

    sample_index = np.arange(num_samples, dtype=np.int32)  # (4,) then (6,)
    src_index[dim] = src_first_index[..., None] + sample_index  # (8, 9, 4) then (8, 9, 6)
    if filter2[dim].name == 'trapezoid':
      # (It might require changing the filter radius at every sample.)
      raise ValueError('resample() cannot use adaptive `trapezoid` filter.')
    if filter2[dim].name == 'impulse':
      weight[dim] = np.ones_like(src_index[dim], weight_precision)
    else:
      x = src_float_index[..., None] - src_index[dim].astype(weight_precision)
      weight[dim] = filter2[dim](x).astype(weight_precision, copy=False)
      if filter2[dim].name != 'narrowbox' and (
          is_minification or not filter2[dim].partition_of_unity
      ):
        weight[dim] = weight[dim] / weight[dim].sum(axis=-1)[..., None]

    src_index[dim], weight[dim] = boundary_dim.apply(
        src_index[dim], weight[dim], coords_dim, src_size, gridtype2[dim]
    )
    if boundary_dim.uses_cval or filter2[dim].name == 'narrowbox':
      uses_cval = True

  # Gather the samples.

  # Recall that src_index = [shape(8, 9, 4), shape(8, 9, 6)].
  src_index_expanded = []
  for dim in range(grid_ndim):
    src_index_dim = np.moveaxis(
        src_index[dim].reshape(src_index[dim].shape + (1,) * (grid_ndim - 1)),
        resampled_ndim,
        resampled_ndim + dim,
    )
    src_index_expanded.append(src_index_dim)
  indices = tuple(src_index_expanded)  # (shape(8, 9, 4, 1), shape(8, 9, 1, 6))
  samples = _arr_getitem(array, indices)  # (8, 9, 4, 6, 3)

  # Indirectly derive samples.ndim (which is unavailable during Tensorflow grad computation).
  samples_ndim = resampled_ndim + grid_ndim + len(sample_shape)

  # Compute an Einstein summation over the samples and each of the per-dimension weights.

  def label(dims: Iterable[int]) -> str:
    return ''.join(chr(ord('a') + i) for i in dims)

  operands = [samples]  # (8, 9, 4, 6, 3)
  assert samples_ndim < 26  # Letters 'a' through 'z'.
  labels = [label(range(samples_ndim))]  # ['abcde']
  for dim in range(grid_ndim):
    operands.append(weight[dim])  # (8, 9, 4), then (8, 9, 6)
    labels.append(label(list(range(resampled_ndim)) + [resampled_ndim + dim]))  # 'abc' then 'abd'
  output_label = label(
      list(range(resampled_ndim)) + list(range(resampled_ndim + grid_ndim, samples_ndim))
  )  # 'abe'
  subscripts = ','.join(labels) + '->' + output_label  # 'abcde,abc,abd->abe'
  array = _arr_einsum(subscripts, *operands)  # (8, 9, 3)

  # Gathering `samples` is the memory bottleneck.  It would be ideal if the gather() and einsum()
  # computations could be fused.  In Jax, https://github.com/google/jax/issues/3206 suggests
  # that this may become possible.  In any case, for large outputs it helps to partition the
  # evaluation over output tiles (using max_block_size).

  if uses_cval:
    # `weight` holds one entry per *grid* dimension, so the product of the per-dimension weight
    # sums must range over `grid_ndim` (not `resampled_ndim`, which differs whenever the output
    # grid and the source grid have different numbers of dimensions).
    cval_weight = 1.0 - np.multiply.reduce(
        [weight[dim].sum(axis=-1) for dim in range(grid_ndim)]
    )  # (8, 9)
    cval_weight_reshaped = cval_weight.reshape(cval_weight.shape + (1,) * len(sample_shape))
    array += _make_array((cval_weight_reshaped * cval).astype(precision, copy=False), arraylib)

  array = dst_gamma2.encode(array, dtype)
  return array

Interpolate array (a grid of samples) at specified unit-domain coordinates coords.

The last dimension of coords contains unit-domain coordinates at which to interpolate the domain grid samples in array.

The number of coordinates (coords.shape[-1]) determines how to interpret array: its first coords.shape[-1] dimensions define the grid, and the remaining dimensions describe each grid sample (e.g., scalar, vector, tensor).

Concretely, the grid has shape array.shape[:coords.shape[-1]] and each grid sample has shape array.shape[coords.shape[-1]:].

Examples include:

  • Resample a grayscale image with array.shape = height, width onto a new grayscale image with new.shape = height2, width2 by using coords.shape = height2, width2, 2.

  • Resample an RGB image with array.shape = height, width, 3 onto a new RGB image with new.shape = height2, width2, 3 by using coords.shape = height2, width2, 2.

  • Sample an RGB image at num 2D points along a line segment by using coords.shape = num, 2.

  • Sample an RGB image at a single 2D point by using coords.shape = (2,).

  • Sample a 3D grid of 3x3 Jacobians with array.shape = nz, ny, nx, 3, 3 along a 2D plane by using coords.shape = height, width, 3.

  • Map a grayscale image through a color map by using array.shape = 256, 3 and coords.shape = height, width, 1.

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The coordinate dimensions appear first, and each grid sample may have an arbitrary shape. Each grid dimension must be at least 1 for a 'dual' grid or at least 2 for a 'primal' grid.
  • coords: Grid of points at which to resample array. The point coordinates are in the last dimension of coords. The domain associated with the source grid is a unit hypercube, i.e. with a range [0, 1] on each coordinate dimension. The output grid has shape coords.shape[:-1] and each of its grid samples has shape array.shape[coords.shape[-1]:].
  • gridtype: Placement of the samples in the source domain grid for each dimension, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual'.
  • boundary: The reconstruction boundary rule for each dimension in coords.shape[-1], specified as either a name in BOUNDARIES or a Boundary instance. The special value 'auto' uses 'reflect' for upsampling and 'clamp' for downsampling.
  • cval: Constant value used beyond the samples by some boundary rules. It must be broadcastable onto the shape array.shape[coords.shape[-1]:].
  • filter: The reconstruction kernel for each dimension in coords.shape[-1], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in coords.shape[:-1], specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • gamma: Component transfer functions (e.g., gamma correction) applied when reading samples from array and when creating output grid samples. It is specified as either a name in GAMMAS or a Gamma instance. If both array.dtype and dtype are uint, the default is 'power2'. If both are non-uint, the default is 'identity'. Otherwise, gamma or src_gamma/dst_gamma must be set. Gamma correction assumes that float values are in the range [0.0, 1.0].
  • src_gamma: Component transfer function used to "decode" array samples. Parameters gamma and src_gamma cannot both be set.
  • dst_gamma: Component transfer function used to "encode" the output samples. Parameters gamma and dst_gamma cannot both be set.
  • jacobian: Optional array, which must be broadcastable onto the shape coords.shape[:-1] + (coords.shape[-1], coords.shape[-1]), storing for each point in the output grid the Jacobian matrix of the map from the unit output domain to the unit source domain. If omitted, it is estimated by computing finite differences on coords.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype, coords.dtype, and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • max_block_size: If nonzero, maximum number of grid points in coords before the resampling evaluation gets partitioned into smaller blocks for reduced memory usage and better caching.
  • debug: Show internal information.
Returns:

A new sample grid of shape coords.shape[:-1], represented as an array of shape coords.shape[:-1] + array.shape[coords.shape[-1]:], of the same array library type as the source array.

Example of resample operation:

For reference, the identity resampling for a scalar-valued grid with the default grid-type 'dual' is:

>>> array = np.random.default_rng(0).random((5, 7, 3))
>>> coords = (np.moveaxis(np.indices(array.shape), 0, -1) + 0.5) / array.shape
>>> new_array = resample(array, coords)
>>> assert np.allclose(new_array, array)

It is more efficient to use the function resize for the special case where the coords are obtained as simple scaling and translation of a new regular grid over the source domain:

>>> scale, translate, new_shape = (1.1, 1.2), (0.1, -0.2), (6, 8)
>>> coords = (np.moveaxis(np.indices(new_shape), 0, -1) + 0.5) / new_shape
>>> coords = (coords - translate) / scale
>>> resampled = resample(array, coords)
>>> resized = resize(array, new_shape, scale=scale, translate=translate)
>>> assert np.allclose(resampled, resized)
def resample_affine( array: Array, /, shape: Iterable[int], matrix: ArrayLike, *, gridtype: str | Gridtype | None = None, src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None, filter: str | Filter | Iterable[str | Filter] = 'lanczos3', prefilter: str | Filter | Iterable[str | Filter] | None = None, precision: DTypeLike = None, dtype: DTypeLike = None, **kwargs: Any) -> Array:
def resample_affine(
    array: _Array,
    /,
    shape: Iterable[int],
    matrix: _ArrayLike,
    *,
    gridtype: str | Gridtype | None = None,
    src_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    dst_gridtype: str | Gridtype | Iterable[str | Gridtype] | None = None,
    filter: str | Filter | Iterable[str | Filter] = _DEFAULT_FILTER,
    prefilter: str | Filter | Iterable[str | Filter] | None = None,
    precision: _DTypeLike = None,
    dtype: _DTypeLike = None,
    **kwargs: Any,
) -> _Array:
  """Resample `array` onto a destination grid of given `shape` under a linear or affine map.

  Each destination grid point (expressed in the unit destination domain) is sent into the unit
  source domain by `matrix`, either as a purely linear map,
    `source_point = matrix @ destination_point`,
  or, when the matrix has one extra column, as an affine map whose last column is an offset,
    `source_point = matrix @ (destination_point, 1.0)`.
  The resulting source coordinates are then handed to `resample`.

  Args:
    array: Regular grid of source sample values, as an array object recognized by `ARRAYLIBS`
      (a `tuple` or `list` is first converted to a numpy array).  The array must have numeric
      type.  Its first `matrix.shape[0]` dimensions are the grid dimensions; the remaining
      dimensions belong to each sample value and are all linearly interpolated.
    shape: Dimensions of the desired destination grid.  The number of destination dimensions may
      differ from the number of source grid dimensions.
    matrix: 2D array with `matrix.shape[0]` rows (source dimensions) and either `len(shape)`
      columns (linear map) or `len(shape) + 1` columns (affine map; the last column is the
      translation).
    gridtype: Placement of samples on all dimensions of both the source and output domain grids,
      specified as either a name in `GRIDTYPES` or a `Gridtype` instance.  It defaults to `'dual'`
      if `gridtype`, `src_gridtype`, and `dst_gridtype` are all kept `None`.
    src_gridtype: Placement of samples in the source domain grid for each dimension.
      Parameters `gridtype` and `src_gridtype` cannot both be set.
    dst_gridtype: Placement of samples in the output domain grid for each dimension.
      Parameters `gridtype` and `dst_gridtype` cannot both be set.
    filter: The reconstruction kernel for each of the `matrix.shape[0]` source dimensions,
      specified as either a filter name in `FILTERS` or a `Filter` instance.
    prefilter: The prefilter kernel for each of the `len(shape)` destination dimensions, specified
      as either a filter name in `FILTERS` or a `Filter` instance.  It is used during downsampling
      (i.e., minification).  If `None`, it inherits the value of `filter`.
    precision: Inexact precision of intermediate computations.  If `None`, it is determined based
      on `array.dtype` and `dtype`.
    dtype: Desired data type of the output array.  If `None`, it is taken to be `array.dtype`.
      If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range
      to the uint range.
    **kwargs: Additional parameters forwarded to the `resample` function.

  Returns:
    An array of the same class as the source `array`, representing a grid with specified `shape`,
    where each grid value is resampled from `array`.  Thus the shape of the returned array is
    `shape + array.shape[matrix.shape[0]:]`.

  Raises:
    ValueError: If `matrix` is not 2D, has more rows than `array` has dimensions, or has a number
      of columns other than `len(shape)` or `len(shape) + 1`.
  """
  if isinstance(array, (tuple, list)):
    array = np.asarray(array)
  shape = tuple(shape)
  dst_ndim = len(shape)
  matrix = np.asarray(matrix)
  if matrix.ndim != 2:
    raise ValueError(f'Array {matrix} is not 2D matrix.')
  src_ndim, num_columns = matrix.shape
  # An extra trailing column holds the translation of an affine map.
  has_offset = num_columns == dst_ndim + 1
  if src_ndim > len(array.shape):
    raise ValueError(
        f'Matrix {matrix} has more rows ({matrix.shape[0]}) than ndim in {array.shape=}.'
    )
  if not (has_offset or num_columns == dst_ndim):
    raise ValueError(
        f'Matrix has {matrix.shape=}, but we expect either {dst_ndim} or {dst_ndim + 1} columns.'
    )

  # Normalize the per-dimension gridtypes and filters.
  src_gridtypes, dst_gridtypes = _get_gridtypes(
      gridtype, src_gridtype, dst_gridtype, src_ndim, dst_ndim
  )
  reconstruction_filters = [
      _get_filter(f) for f in np.broadcast_to(np.array(filter), src_ndim)
  ]
  prefilter_spec = filter if prefilter is None else prefilter
  prefilters = [_get_filter(f) for f in np.broadcast_to(np.array(prefilter_spec), dst_ndim)]

  dtype = _arr_dtype(array) if dtype is None else np.dtype(dtype)
  precision = _get_precision(precision, [_arr_dtype(array), dtype], [])
  weight_precision = _real_precision(precision)

  # Unit-domain positions of the destination samples, one 1D array per destination dimension.
  axis_positions = [
      dst_gridtypes[dim].point_from_index(np.arange(size, dtype=weight_precision), size)
      for dim, size in enumerate(shape)
  ]
  dst_position = np.meshgrid(*axis_positions, indexing='ij')

  # Apply the linear part to every destination point, then move the source-dimension axis last so
  # that `coords[..., :]` is the source point for each destination sample.
  linear_part = matrix[:, :-1] if has_offset else matrix
  coords = np.moveaxis(np.tensordot(linear_part, dst_position, 1), 0, -1)
  if has_offset:
    coords += matrix[:, -1]

  # TODO: Based on grid_shape, shape, linear_matrix, and prefilter, determine a
  # convolution prefilter and apply it to bandlimit 'array', using boundary for padding.

  return resample(
      array,
      coords,
      gridtype=src_gridtypes,
      filter=reconstruction_filters,
      prefilter=prefilters,
      precision=precision,
      dtype=dtype,
      **kwargs,
  )

Resample a source array using an affinely transformed grid of given shape.

The matrix transformation can be linear, source_point = matrix @ destination_point, or it can be affine where the last matrix column is an offset vector, source_point = matrix @ (destination_point, 1.0).

Arguments:
  • array: Regular grid of source sample values, as an array object recognized by ARRAYLIBS. The array must have numeric type. The number of grid dimensions is determined from matrix.shape[0]; the remaining dimensions are for each sample value and are all linearly interpolated.
  • shape: Dimensions of the desired destination grid. The number of destination grid dimensions may be different from that of the source grid.
  • matrix: 2D array for a linear or affine transform from unit-domain destination points (in a space with len(shape) dimensions) into unit-domain source points (in a space with matrix.shape[0] dimensions). If the matrix has len(shape) + 1 columns, the last column is the affine offset (i.e., translation).
  • gridtype: Placement of samples on all dimensions of both the source and output domain grids, specified as either a name in GRIDTYPES or a Gridtype instance. It defaults to 'dual' if gridtype, src_gridtype, and dst_gridtype are all kept None.
  • src_gridtype: Placement of samples in the source domain grid for each dimension. Parameters gridtype and src_gridtype cannot both be set.
  • dst_gridtype: Placement of samples in the output domain grid for each dimension. Parameters gridtype and dst_gridtype cannot both be set.
  • filter: The reconstruction kernel for each dimension in matrix.shape[0], specified as either a filter name in FILTERS or a Filter instance.
  • prefilter: The prefilter kernel for each dimension in len(shape), specified as either a filter name in FILTERS or a Filter instance. It is used during downsampling (i.e., minification). If None, it inherits the value of filter.
  • precision: Inexact precision of intermediate computations. If None, it is determined based on array.dtype and dtype.
  • dtype: Desired data type of the output array. If None, it is taken to be array.dtype. If it is a uint type, the intermediate float values are rescaled from the [0.0, 1.0] range to the uint range.
  • **kwargs: Additional parameters for resample function.
Returns:

An array of the same class as the source array, representing a grid with specified shape, where each grid value is resampled from array. Thus the shape of the returned array is shape + array.shape[matrix.shape[0]:].

def rotation_about_center_in_2d( src_shape: ArrayLike, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0) -> np.ndarray:
def rotation_about_center_in_2d(
    src_shape: _ArrayLike,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
) -> _NDArray:
  """Build the 3x3 homogeneous matrix that maps destination points into the source unit domain.

  The matrix is the product (applied right to left) of: a shift of the domain center to the
  origin, an aspect scaling for the destination shape (times `scale`), a 2D rotation, an aspect
  scaling for the source shape, and a shift back to the domain center.  The aspect scalings are
  what make the result correct for non-square domains.

  Args:
    src_shape: Resolution `(ny, nx)` of the source domain grid.
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the destination domain grid; it defaults to `src_shape`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.

  Returns:
    A 3x3 array whose last row is `[0, 0, 1]`.
  """
  src_shape = np.asarray(src_shape)
  new_shape = src_shape if new_shape is None else np.asarray(new_shape)
  _check_eq(src_shape.shape, (2,))
  _check_eq(new_shape.shape, (2,))

  def shift_by(offset: _NDArray) -> _NDArray:
    # Homogeneous translation: identity with `offset` in the last column.
    result = np.eye(len(offset) + 1)
    result[:-1, -1] = offset
    return result

  def stretch_by(factors: _NDArray) -> _NDArray:
    # Homogeneous per-axis scaling.
    return np.diag((*factors, 1.0))

  center = np.array([0.5, 0.5])
  to_center = shift_by(center)
  src_aspect = stretch_by(min(src_shape) / src_shape)
  cos, sin = np.cos(angle), np.sin(angle)
  rotation = np.array([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])
  dst_aspect = stretch_by(scale * new_shape / min(new_shape))
  from_center = shift_by(-center)

  matrix = to_center @ src_aspect @ rotation @ dst_aspect @ from_center
  assert np.allclose(matrix[-1], [0.0, 0.0, 1.0])
  return matrix

Return the 3x3 matrix mapping destination into a source unit domain.

The returned matrix accounts for the possibly non-square domain shapes.

Arguments:
  • src_shape: Resolution (ny, nx) of the source domain grid.
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the destination domain grid; it defaults to src_shape.
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
def rotate_image_about_center( image: np.ndarray, /, angle: float, *, new_shape: Optional[ArrayLike] = None, scale: float = 1.0, num_rotations: int = 1, **kwargs: Any) -> np.ndarray:
def rotate_image_about_center(
    image: _NDArray,
    /,
    angle: float,
    *,
    new_shape: _ArrayLike | None = None,
    scale: float = 1.0,
    num_rotations: int = 1,
    **kwargs: Any,
) -> _NDArray:
  """Return a copy of `image` rotated about its center.

  Args:
    image: Source grid samples; the first two dimensions are spatial (ny, nx).
    angle: Angle in radians (positive from x to y axis) applied when mapping the source domain
      onto the destination domain.
    new_shape: Resolution `(ny, nx)` of the output grid; it defaults to `image.shape[:2]`.
    scale: Scaling factor applied when mapping the source domain onto the destination domain.
    num_rotations: Number of rotations (each by `angle`).  Successive resamplings are useful in
      analyzing the filtering quality.
    **kwargs: Additional parameters for `resample_affine`.

  Returns:
    The result of applying `resample_affine` `num_rotations` times; if `num_rotations` is zero,
    the input `image` is returned unchanged.
  """
  new_shape = image.shape[:2] if new_shape is None else np.asarray(new_shape)
  for _ in range(num_rotations):
    # The transform depends on the shape of the *current* source image.  After the first pass the
    # image has shape `new_shape`, so the matrix must be rebuilt; reusing the matrix derived from
    # the original shape would apply the wrong aspect correction whenever `new_shape` differs
    # from the original spatial shape.  (When the shapes agree, the matrix is identical.)
    matrix = rotation_about_center_in_2d(
        image.shape[:2], angle, new_shape=new_shape, scale=scale
    )
    # Drop the homogeneous last row; `resample_affine` expects a 2x3 affine matrix here.
    image = resample_affine(image, new_shape, matrix[:-1], **kwargs)
  return image

Return a copy of image rotated about its center.

Arguments:
  • image: Source grid samples; the first two dimensions are spatial (ny, nx).
  • angle: Angle in radians (positive from x to y axis) applied when mapping the source domain onto the destination domain.
  • new_shape: Resolution (ny, nx) of the output grid; it defaults to image.shape[:2].
  • scale: Scaling factor applied when mapping the source domain onto the destination domain.
  • num_rotations: Number of rotations (each by angle). Successive resamplings are useful in analyzing the filtering quality.
  • **kwargs: Additional parameters for resample_affine.
def pil_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str) -> np.ndarray:
def pil_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `PIL.Image.resize` using the same parameters as `resize`.

  Args:
    array: Floating-point input with 1, 2, or 3 dimensions.  A 1D input is resized as a single
      row; a 3D input is resized one channel (last axis) at a time.
    shape: Output spatial shape; one entry for 1D input, otherwise two entries `(ny, nx)`.
    filter: One of 'impulse', 'box', 'triangle', 'hamming1', 'cubic', 'lanczos3'.
    boundary: Must be 'natural'; any other value is rejected.
    cval: Ignored.

  Returns:
    A numpy array with the dtype of the input.

  Raises:
    ValueError: If `boundary` is not 'natural' or `filter` is not a supported name.
  """
  import PIL.Image

  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  assert np.issubdtype(array.dtype, np.floating)
  shape = tuple(shape)
  expected_len = 1 if array.ndim < 2 else 2
  _check_eq(len(shape), expected_len)

  if array.ndim == 1:
    # Treat the vector as a one-row image, then strip the row axis again.
    as_row = pil_image_resize(array[None], (1, *shape), filter=filter)
    return as_row[0]

  if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
  filters = {
      'impulse': PIL.Image.Resampling.NEAREST,
      'box': PIL.Image.Resampling.BOX,
      'triangle': PIL.Image.Resampling.BILINEAR,
      'hamming1': PIL.Image.Resampling.HAMMING,
      'cubic': PIL.Image.Resampling.BICUBIC,
      'lanczos3': PIL.Image.Resampling.LANCZOS,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  pil_resample = filters[filter]
  pil_size = shape[::-1]  # The size is given to PIL in reversed (nx, ny) order.

  def resize_plane(plane: _NDArray) -> _NDArray:
    resized = PIL.Image.fromarray(plane).resize(pil_size, resample=pil_resample)
    return np.array(resized, array.dtype)

  if array.ndim == 2:
    return resize_plane(array)
  # Resize each channel separately and restack them along the last axis.
  return np.dstack([resize_plane(channel) for channel in np.moveaxis(array, -1, 0)])

Invoke PIL.Image.resize using the same parameters as resize.

def cv_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str) -> np.ndarray:
def cv_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `cv.resize` using the same parameters as `resize`.

  Args:
    array: Input with 1, 2, or 3 dimensions.  A 1D input is resized as a single row.
    shape: Output spatial shape; one entry for 1D input, otherwise two entries `(ny, nx)`.
    filter: One of 'impulse', 'triangle', 'trapezoid', 'sharpcubic', 'lanczos4'.
    boundary: Must be 'clamp'; any other value is rejected.
    cval: Ignored.

  Returns:
    The resized array.  For a 3D input, the result is always 3D (a dropped singleton channel axis
    is restored).

  Raises:
    ValueError: If `boundary` is not 'clamp' or `filter` is not a supported name.
  """
  import cv2 as cv

  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  array = np.asarray(array)
  assert 1 <= array.ndim <= 3
  shape = tuple(shape)
  expected_len = 1 if array.ndim < 2 else 2
  _check_eq(len(shape), expected_len)

  if array.ndim == 1:
    # Treat the vector as a one-row image, then strip the row axis again.
    as_row = cv_resize(array[None], (1, *shape), filter=filter)
    return as_row[0]

  filters = {
      'impulse': cv.INTER_NEAREST,  # Or consider cv.INTER_NEAREST_EXACT.
      'triangle': cv.INTER_LINEAR_EXACT,  # Or just cv.INTER_LINEAR.
      'trapezoid': cv.INTER_AREA,
      'sharpcubic': cv.INTER_CUBIC,
      'lanczos4': cv.INTER_LANCZOS4,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')

  # The size is given to OpenCV in reversed (nx, ny) order.
  result = cv.resize(array, shape[::-1], interpolation=filters[filter])
  lost_channel_axis = array.ndim == 3 and result.ndim == 2
  if not lost_channel_axis:
    return result
  assert array.shape[2] == 1
  return result[..., None]  # Add back the last dimension dropped by cv.resize().

Invoke cv.resize using the same parameters as resize.

def scipy_ndimage_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
def scipy_ndimage_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _NDArray:
  """Invoke `scipy.ndimage.map_coordinates` using the same parameters as `resize`.

  Args:
    array: Input array; its first `len(shape)` dimensions are resized and any remaining
      dimensions are kept as is.
    shape: Output shape of the resized (leading) dimensions.
    filter: One of 'box', 'triangle', or 'cardinal2' .. 'cardinal5'; mapped to spline orders
      0, 1, and 2 .. 5 respectively.
    boundary: One of 'reflect', 'wrap', 'clamp', 'border'; mapped to the scipy modes 'reflect',
      'grid-wrap', 'nearest', and 'constant' respectively.
    cval: Constant value passed through to `map_coordinates`.
    scale: Scaling applied to the resized dimensions (scalar or one value per dimension).
    translate: Translation applied to the resized dimensions (scalar or one value per dimension).

  Returns:
    The array returned by `scipy.ndimage.map_coordinates`, of shape
    `shape + array.shape[len(shape):]`.

  Raises:
    ValueError: If `filter` or `boundary` is not a supported name.
  """
  array = np.asarray(array)
  shape = tuple(shape)
  assert 1 <= len(shape) <= array.ndim
  filters = {'box': 0, 'triangle': 1} | {f'cardinal{i}': i for i in range(2, 6)}
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  order = filters[filter]
  boundaries = {'reflect': 'reflect', 'wrap': 'grid-wrap', 'clamp': 'nearest', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')
  mode = boundaries[boundary]

  num_resized = len(shape)
  shape_all = shape + array.shape[num_resized:]
  # The sampling coordinates are fractional, so they must live in a floating-point array.
  # Creating them in an integer `array.dtype` would truncate them on assignment below and
  # silently degrade the interpolation.  Floating-point inputs keep their own dtype.
  coords_dtype = array.dtype if np.issubdtype(array.dtype, np.floating) else np.float64
  coords = np.moveaxis(np.indices(shape_all, coords_dtype), 0, -1)
  # Map output sample centers into the unit domain, undo the translate and scale, and convert
  # back to (fractional) source indices.  Non-resized dimensions keep their integer indices.
  coords[..., :num_resized] = (
      (coords[..., :num_resized] + 0.5) / shape - np.asarray(translate)
  ) / np.asarray(scale) * np.array(array.shape)[:num_resized] - 0.5
  coords = np.moveaxis(coords, -1, 0)
  return scipy.ndimage.map_coordinates(array, coords, order=order, mode=mode, cval=cval)

Invoke scipy.ndimage.map_coordinates using the same parameters as resize.

def skimage_transform_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, boundary: str = 'reflect', cval: float = 0.0) -> np.ndarray:
def skimage_transform_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'reflect',
    cval: float = 0.0,
) -> _NDArray:
  """Invoke `skimage.transform.resize` using the same parameters as `resize`.

  Args:
    array: Input array; its first `len(shape)` dimensions are resized and any remaining
      dimensions keep their size.
    shape: Output shape of the resized (leading) dimensions.
    filter: One of 'box', 'triangle', or 'cardinal2' .. 'cardinal5'; mapped to interpolation
      orders 0, 1, and 2 .. 5 respectively.
    boundary: One of 'reflect', 'wrap', 'clamp', 'border'; mapped to the skimage modes
      'symmetric', 'wrap', 'edge', and 'constant' respectively.
    cval: Constant value passed through to `skimage.transform.resize`.

  Returns:
    The array returned by `skimage.transform.resize`.

  Raises:
    ValueError: If `filter` or `boundary` is not a supported name.
  """
  import skimage.transform

  array = np.asarray(array)
  shape = tuple(shape)
  num_resized = len(shape)
  assert 1 <= num_resized <= array.ndim

  filters = {
      'box': 0,
      'triangle': 1,
      'cardinal2': 2,
      'cardinal3': 3,
      'cardinal4': 4,
      'cardinal5': 5,
  }
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  boundaries = {'reflect': 'symmetric', 'wrap': 'wrap', 'clamp': 'edge', 'border': 'constant'}
  if boundary not in boundaries:
    raise ValueError(f'{boundary=} not in {boundaries=}.')

  # Trailing (non-resized) dimensions keep their original size.
  shape_all = shape + array.shape[num_resized:]
  # Default anti_aliasing=None automatically enables (poor) Gaussian prefilter if downsampling.
  # clip=False is the default behavior in `resampler` if the output type is non-integer.
  resized = skimage.transform.resize(
      array,
      shape_all,
      order=filters[filter],
      mode=boundaries[boundary],
      cval=cval,
      clip=False,
  )  # type: ignore[no-untyped-call]
  return resized

Invoke skimage.transform.resize using the same parameters as resize.

def tf_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, antialias: bool = True) -> TensorflowTensor:
def tf_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    antialias: bool = True,
) -> _TensorflowTensor:
  """Invoke `tf.image.resize` using the same parameters as `resize`.

  Args:
    array: Input with 1, 2, or 3 dimensions.  Lower-dimensional inputs are temporarily expanded
      (a leading row axis for 1D, a trailing channel axis for 2D) and the added axis is removed
      from the result.
    shape: Output spatial shape; one entry for 1D input, otherwise two entries.
    filter: A key of `_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER`.
    boundary: Must be 'natural'; any other value is rejected.
    cval: Ignored.
    antialias: Passed through to `tf.image.resize`.

  Returns:
    The resized tensor, with the same number of dimensions as the input.

  Raises:
    ValueError: If `filter` is not supported or `boundary` is not 'natural'.
  """
  import tensorflow as tf

  if filter not in _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  del cval
  tensor = tf.convert_to_tensor(array)
  del array
  ndim = len(tensor.shape)
  assert 1 <= ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if ndim >= 2 else 1)

  if ndim == 1:
    # Resize as a single-row 2D image, then drop the row axis.
    as_row = tf_image_resize(tensor[None], (1, *shape), filter=filter, antialias=antialias)
    return as_row[0]
  if ndim == 2:
    # Resize as a single-channel 3D image, then drop the channel axis.
    with_channel = tf_image_resize(
        tensor[..., None], shape, filter=filter, antialias=antialias
    )
    return with_channel[..., 0]
  method = _TENSORFLOW_IMAGE_RESIZE_METHOD_FROM_FILTER[filter]
  return tf.image.resize(tensor, shape, method=method, antialias=antialias)

Invoke tf.image.resize using the same parameters as resize.

def torch_nn_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, antialias: bool = False) -> TorchTensor:
def torch_nn_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'clamp',
    cval: float = 0.0,
    antialias: bool = False,
) -> _TorchTensor:
  """Invoke `torch.nn.functional.interpolate` using the same parameters as `resize`.

  Args:
    array: Input with 1, 2, or 3 dimensions.  A 3D input has its last axis treated as channels.
    shape: Output spatial shape; one entry for 1D input, otherwise two entries.
    filter: A key of `_TORCH_INTERPOLATE_MODE_FROM_FILTER`.
    boundary: Must be 'clamp'; any other value is rejected.
    cval: Ignored.
    antialias: Passed through to `torch.nn.functional.interpolate`.

  Returns:
    The resized tensor, with the same number of dimensions as the input.

  Raises:
    ValueError: If `filter` is not supported or `boundary` is not 'clamp'.
  """
  import torch

  if filter not in _TORCH_INTERPOLATE_MODE_FROM_FILTER:
    raise ValueError(f'{filter=} not in {_TORCH_INTERPOLATE_MODE_FROM_FILTER=}.')
  if boundary != 'clamp':
    raise ValueError(f"{boundary=} must equal 'clamp'.")
  del cval
  tensor = torch.as_tensor(array)
  del array
  assert 1 <= tensor.ndim <= 3
  shape = tuple(shape)
  _check_eq(len(shape), 2 if tensor.ndim >= 2 else 1)
  mode = _TORCH_INTERPOLATE_MODE_FROM_FILTER[filter]
  # A 1D input is resized as a single-row 2D image, so its target size gains a leading 1.
  size = (1, *shape) if tensor.ndim == 1 else shape

  def interpolate(batched: _TorchTensor) -> _TorchTensor:
    # `batched` has layout (batch, channel, y, x).
    # For upsampling, BILINEAR antialias is same PSNR and slower,
    #  and BICUBIC antialias is worse PSNR and faster.
    # For downsampling, antialias improves PSNR for both BILINEAR and BICUBIC.
    # Default align_corners=None corresponds to False which is what we desire.
    return torch.nn.functional.interpolate(batched, size, mode=mode, antialias=antialias)

  if tensor.ndim == 1:
    return interpolate(tensor[None, None, None])[0, 0, 0]
  if tensor.ndim == 2:
    return interpolate(tensor[None, None])[0, 0]
  # Move the channel axis to the front for torch, then restore it afterwards.
  return interpolate(tensor.moveaxis(2, 0)[None])[0].moveaxis(0, 2)

Invoke torch.nn.functional.interpolate using the same parameters as resize.

def jax_image_resize( array: ArrayLike, /, shape: Iterable[int], *, filter: str, scale: float | Iterable[float] = 1.0, translate: float | Iterable[float] = 0.0) -> JaxArray:
def jax_image_resize(
    array: _ArrayLike,
    /,
    shape: Iterable[int],
    *,
    filter: str,
    boundary: str = 'natural',
    cval: float = 0.0,
    scale: float | Iterable[float] = 1.0,
    translate: float | Iterable[float] = 0.0,
) -> _JaxArray:
  """Invoke `jax.image.scale_and_translate` using the same parameters as `resize`.

  Args:
    array: Input array; its first `len(shape)` dimensions are the spatial dimensions.
    shape: Output shape of the spatial dimensions.
    filter: One of 'triangle', 'cubic', 'lanczos3', 'lanczos5'.
    boundary: Must be 'natural'; any other value is rejected.
    cval: Must be zero whenever any `scale` differs from 1 or any `translate` differs from 0.
    scale: Scaling applied to the spatial dimensions (scalar or one value per dimension).
    translate: Translation applied to the spatial dimensions (scalar or one value per dimension).

  Returns:
    The array returned by `jax.image.scale_and_translate`.

  Raises:
    ValueError: If `filter` is not supported, `boundary` is not 'natural', or `cval` is nonzero
      while a non-identity `scale` or `translate` is requested.
  """
  import jax.image
  import jax.numpy as jnp

  filters = 'triangle cubic lanczos3 lanczos5'.split()
  if filter not in filters:
    raise ValueError(f'{filter=} not in {filters=}.')
  if boundary != 'natural':
    raise ValueError(f"{boundary=} must equal 'natural'.")
  # When `scale` or `translate` are applied, any region outside the unit domain is assigned value 0.
  # To be consistent, the parameter `cval` must be zero.
  # The comparisons are done elementwise because `scale` and `translate` may be sequences or
  # arrays: comparing a tuple with a float is always "unequal", and the truth value of a
  # multi-element array comparison is ambiguous.
  if cval != 0.0:
    if np.any(np.asarray(scale) != 1.0):
      raise ValueError(f'Non-unity {scale=} requires that {cval=} be zero.')
    if np.any(np.asarray(translate) != 0.0):
      raise ValueError(f'Nonzero {translate=} requires that {cval=} be zero.')
  array2 = jnp.asarray(array)
  del array
  shape = tuple(shape)
  num_spatial = len(shape)
  assert num_spatial <= array2.ndim
  completed_shape = shape + (1,) * (array2.ndim - num_spatial)
  spatial_dims = list(range(num_spatial))
  # Convert the unit-domain scale and translation into the output-pixel units used by jax.
  scale2 = np.broadcast_to(np.array(scale), num_spatial)
  scale2 = scale2 / np.array(array2.shape[:num_spatial]) * np.array(shape)
  translate2 = np.broadcast_to(np.array(translate), num_spatial)
  translate2 = translate2 * np.array(shape)
  return jax.image.scale_and_translate(
      array2, completed_shape, spatial_dims, scale2, translate2, filter
  )

Invoke jax.image.scale_and_translate using the same parameters as resize.