pytorch旧版本(1.7之前)中有一个函数torch.rfft(),但是新版本(1.8、1.9)中被移除了,添加了torch.fft.rfft(),但它并不是旧版的替代品。
torch.fft
label_fft1 = torch.rfft(label_img4, signal_ndim=2, normalized=False, onesided=False)
参数说明:
input (Tensor) – the input tensor of at least signal_ndim dimensions(输入数据)
signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3(输入数据的维度)
normalized (bool, optional) – controls whether to return normalized results. Default: False(是否进行归一化处理)
onesided (bool, optional) – controls whether to return half of results to avoid redundancy. Default: True(fft后因共轭对称,返回一半的数据长度)
在1.7版本torch.rfft中,有一个warning,表示在新版中,要“one-side ouput”的话用torch.fft.rfft(),要“two-side ouput”的话用torch.fft.fft()。这里的one/two side,跟旧版的onesided参数对应。
新版pytorch中,各种在新版本中各种fft的解释如下:
fft, which computes a complex FFT over a single dimension, and ifft, its inverse
the more general fftn and ifftn, which support multiple dimensions
The “real” FFT functions, rfft, irfft, rfftn, irfftn, designed to work with signals that are real-valued in their time domains
The “Hermitian” FFT functions, hfft and ihfft, designed to work with signals that are real-valued in their frequency domains
Helper functions, like fftfreq, rfftfreq, fftshift, ifftshift, that make it easier to manipulate signals
上述描述中:
fft和ifft就是计算单维的复数FFT,fftn和ifftn计算多维的复数FFT
rfft、irfft、rfftn、irfftn计算实数FFT
对于ifft,需要注意的是, 新版中要求输入的数据类型为complex,即要求输入的维度不跟旧版一样将复数的实部和虚部存成二维向量(即在最后多出一个值为2的维度)。如果说输入时以二维向量存复数,则需要使用torch.complex()将其转化成complex类型。
import torch
input = torch.randn(1, 3, 64, 64)
### 旧版 ###
# 参数normalized对这篇文章的结论没有影响,加上只是为了跟文章开头同步
# 时域=>频域
output_fft_old = torch.rfft(input, signal_ndim=2, normalized=False, onesided=False)
# 频域=>时域
output_ifft_old = torch.irfft(output_fft_old , signal_ndim=2, normalized=False, onesided=False)
### 新版 ###
# 时域=>频域
output_fft_new = torch.fft.fft2(input, dim=(-2, -1))
output_fft_new_2dim = torch.stack((output_fft_new.real, output_fft_new.imag), -1) # 根据需求将复数形式转成数组形式
# 频域=>时域
output_ifft_new = torch.fft.ifft2(output_fft_new, dim=(-2, -1)) # 输入为复数形式
output_ifft_new = torch.fft.ifft2(torch.complex(output_fft_new_2dim[..., 0], # 输入为数组形式
output_fft_new_2dim[..., 1]), dim=(-2, -1))
# 注意最后输出的结果还是为复数,需要将虚部丢弃
output_ifft_new = output_ifft_new.real
print((output_ifft_new - input).sum()) # 输出应该趋近于0(因为存在数值误差)
注意函数中的参数,新版的dim是要进行FFT的第几维度,一般是最后一维。