181 lines
6.5 KiB
Python
181 lines
6.5 KiB
Python
from scipy._lib._array_api import (
|
|
array_namespace, is_numpy, xp_unsupported_param_msg, is_complex
|
|
)
|
|
from . import _pocketfft
|
|
import numpy as np
|
|
|
|
|
|
def _validate_fft_args(workers, plan, norm):
|
|
if workers is not None:
|
|
raise ValueError(xp_unsupported_param_msg("workers"))
|
|
if plan is not None:
|
|
raise ValueError(xp_unsupported_param_msg("plan"))
|
|
if norm is None:
|
|
norm = 'backward'
|
|
return norm
|
|
|
|
|
|
# pocketfft is used whenever SCIPY_ARRAY_API is not set,
|
|
# or x is a NumPy array or array-like.
|
|
# When SCIPY_ARRAY_API is set, we try to use xp.fft for CuPy arrays,
|
|
# PyTorch arrays and other array API standard supporting objects.
|
|
# If xp.fft does not exist, we attempt to convert to a NumPy array, use pocketfft, and convert the result back.
|
|
|
|
def _execute_1D(func_str, pocketfft_func, x, n, axis, norm, overwrite_x, workers, plan):
|
|
xp = array_namespace(x)
|
|
|
|
if is_numpy(xp):
|
|
x = np.asarray(x)
|
|
return pocketfft_func(x, n=n, axis=axis, norm=norm,
|
|
overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
norm = _validate_fft_args(workers, plan, norm)
|
|
if hasattr(xp, 'fft'):
|
|
xp_func = getattr(xp.fft, func_str)
|
|
return xp_func(x, n=n, axis=axis, norm=norm)
|
|
|
|
x = np.asarray(x)
|
|
y = pocketfft_func(x, n=n, axis=axis, norm=norm)
|
|
return xp.asarray(y)
|
|
|
|
|
|
def _execute_nD(func_str, pocketfft_func, x, s, axes, norm, overwrite_x, workers, plan):
|
|
xp = array_namespace(x)
|
|
|
|
if is_numpy(xp):
|
|
x = np.asarray(x)
|
|
return pocketfft_func(x, s=s, axes=axes, norm=norm,
|
|
overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
norm = _validate_fft_args(workers, plan, norm)
|
|
if hasattr(xp, 'fft'):
|
|
xp_func = getattr(xp.fft, func_str)
|
|
return xp_func(x, s=s, axes=axes, norm=norm)
|
|
|
|
x = np.asarray(x)
|
|
y = pocketfft_func(x, s=s, axes=axes, norm=norm)
|
|
return xp.asarray(y)
|
|
|
|
|
|
def fft(x, n=None, axis=-1, norm=None,
        overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``fft`` for `x` through ``_execute_1D``.

    ``_pocketfft.fft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('fft', _pocketfft.fft, x, **options)
|
|
|
|
|
|
def ifft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``ifft`` for `x` through ``_execute_1D``.

    ``_pocketfft.ifft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('ifft', _pocketfft.ifft, x, **options)
|
|
|
|
|
|
def rfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``rfft`` for `x` through ``_execute_1D``.

    ``_pocketfft.rfft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('rfft', _pocketfft.rfft, x, **options)
|
|
|
|
|
|
def irfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``irfft`` for `x` through ``_execute_1D``.

    ``_pocketfft.irfft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('irfft', _pocketfft.irfft, x, **options)
|
|
|
|
|
|
def hfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``hfft`` for `x` through ``_execute_1D``.

    ``_pocketfft.hfft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('hfft', _pocketfft.hfft, x, **options)
|
|
|
|
|
|
def ihfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``ihfft`` for `x` through ``_execute_1D``.

    ``_pocketfft.ihfft`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(n=n, axis=axis, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_1D('ihfft', _pocketfft.ihfft, x, **options)
|
|
|
|
|
|
def fftn(x, s=None, axes=None, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``fftn`` for `x` through ``_execute_nD``.

    ``_pocketfft.fftn`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(s=s, axes=axes, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_nD('fftn', _pocketfft.fftn, x, **options)
|
|
|
|
|
|
|
|
def ifftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``ifftn`` for `x` through ``_execute_nD``.

    ``_pocketfft.ifftn`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(s=s, axes=axes, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_nD('ifftn', _pocketfft.ifftn, x, **options)
|
|
|
|
|
|
def fft2(x, s=None, axes=(-2, -1), norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``fftn``; `axes` defaults to ``(-2, -1)``."""
    return fftn(x, s=s, axes=axes, norm=norm,
                overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
|
|
def ifft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``ifftn``; `axes` defaults to ``(-2, -1)``."""
    return ifftn(x, s=s, axes=axes, norm=norm,
                 overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
|
|
def rfftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``rfftn`` for `x` through ``_execute_nD``.

    ``_pocketfft.rfftn`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(s=s, axes=axes, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_nD('rfftn', _pocketfft.rfftn, x, **options)
|
|
|
|
|
|
def rfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``rfftn``; `axes` defaults to ``(-2, -1)``."""
    return rfftn(x, s=s, axes=axes, norm=norm,
                 overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
|
|
def irfftn(x, s=None, axes=None, norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``irfftn`` for `x` through ``_execute_nD``.

    ``_pocketfft.irfftn`` is supplied as the pocketfft implementation; every
    option is forwarded unchanged.
    """
    options = dict(s=s, axes=axes, norm=norm,
                   overwrite_x=overwrite_x, workers=workers, plan=plan)
    return _execute_nD('irfftn', _pocketfft.irfftn, x, **options)
|
|
|
|
|
|
def irfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``irfftn``; `axes` defaults to ``(-2, -1)``."""
    return irfftn(x, s=s, axes=axes, norm=norm,
                  overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
|
|
def _swap_direction(norm):
|
|
if norm in (None, 'backward'):
|
|
norm = 'forward'
|
|
elif norm == 'forward':
|
|
norm = 'backward'
|
|
elif norm != 'ortho':
|
|
raise ValueError('Invalid norm value %s; should be "backward", '
|
|
'"ortho", or "forward".' % norm)
|
|
return norm
|
|
|
|
|
|
def hfftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``hfftn`` for `x` according to its array namespace.

    NumPy inputs are converted with ``np.asarray`` and passed to
    ``_pocketfft.hfftn``.  For any other namespace, `x` is conjugated when
    ``is_complex`` reports it as complex, and the result is handed to
    ``irfftn`` with the normalization direction swapped by
    ``_swap_direction``.
    """
    xp = array_namespace(x)

    if not is_numpy(xp):
        operand = xp.conj(x) if is_complex(x, xp) else x
        swapped_norm = _swap_direction(norm)
        return irfftn(operand, s, axes, swapped_norm,
                      overwrite_x, workers, plan=plan)

    return _pocketfft.hfftn(np.asarray(x), s, axes, norm,
                            overwrite_x, workers, plan=plan)
|
|
|
|
|
|
def hfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``hfftn``; `axes` defaults to ``(-2, -1)``."""
    return hfftn(x, s=s, axes=axes, norm=norm,
                 overwrite_x=overwrite_x, workers=workers, plan=plan)
|
|
|
|
|
|
def ihfftn(x, s=None, axes=None, norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Dispatch ``ihfftn`` for `x` according to its array namespace.

    NumPy inputs are converted with ``np.asarray`` and passed to
    ``_pocketfft.ihfftn``.  For any other namespace, the result is the
    conjugate (``xp.conj``) of ``rfftn`` called with the normalization
    direction swapped by ``_swap_direction``.
    """
    xp = array_namespace(x)

    if not is_numpy(xp):
        # Look up ``conj`` before running the transform, as the nested call
        # expression this replaces did.
        conjugate = xp.conj
        spectrum = rfftn(x, s, axes, _swap_direction(norm),
                         overwrite_x, workers, plan=plan)
        return conjugate(spectrum)

    return _pocketfft.ihfftn(np.asarray(x), s, axes, norm,
                             overwrite_x, workers, plan=plan)
|
|
|
|
def ihfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Delegate to ``ihfftn``; `axes` defaults to ``(-2, -1)``."""
    return ihfftn(x, s=s, axes=axes, norm=norm,
                  overwrite_x=overwrite_x, workers=workers, plan=plan)
|