如果用OpenCV-Python进行图像的离散傅里叶变换与逆变换其实还蛮简单的,流程就是上图所示,值得注意的是,如果是多通道的图像,譬如多光谱、高光谱图像,需要对每个通道都进行傅里叶变换,最后再聚合,如果只是RGB,可以用如下方式合成灰度图,只需要对灰度图做处理即可。
img1 = 0.2126 * image1[:,:,0] + 0.7152 * image1[:,:,1] + 0.0722 * image1[:,:,2]
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
# 测试图像
ori=cv.imread(r"F:\ori.jpg")
# numpy 中的 fft 需要输入灰度图,我们需要将图像分割成不同的通道
def getRGBDFT(img):
# cv2默认的图像通道是BGR,需要进行转换
img=cv.cvtColor(img,cv.COLOR_BGR2RGB)
# 分离通道
r,g,b=cv.split(img)
# 对每个通道进行傅里叶变换
f_r,f_g,f_b=_dft(r),_dft(g),_dft(b)
# 组合通道,还是以bgr格式返回
return cv.merge([f_b,f_g,f_r])
def _dft(img):
f=cv.dft(np.float32(img),flags=cv.DFT_COMPLEX_OUTPUT)
# 计算幅度谱
magnitude=cv.magnitude(f[:,:,0],f[:,:,1])
# 对数变换增强对比度
res=np.log(magnitude+1)
# 移动低频分量至中心
return np.fft.fftshift(res)
out=getRGBDFT(ori)
# 选取
plt.subplot(121),plt.imshow(ori[:,:,0],cmap='gray')
plt.title("Ori"),plt.xticks([]),plt.yticks([])
plt.subplot(122),plt.imshow(out[:,:,0],cmap='gray')
plt.title("Magnitude"),plt.xticks([]),plt.yticks([])
plt.show()
我们查看B通道的图像与傅里叶幅度谱:
接下来要进行傅里叶逆变换,代码如下:
def _idft(img):
img=np.fft.ifftshift(img)
img=cv.idft(img)
return cv.magnitude(img[:,:,0],img[:,:,1])
若要在频域上做处理,可以添加掩膜:
def _Mask(img,type=None,d=2,size=4):
row,col=img.shape[:-1]
if type==None:
return np.ones((row,col,d),np.uint8)
if type == "LPF":
mask=np.zeros((row,col,d),np.uint8)
mask[row//size:row//size*(size-1),col//size:col//size*(size-1)]=1
elif type=="HPF":
mask = np.ones((row, col,d), np.uint8)
mask[row//size:row//size*(size-1),col//size:col//size*(size-1)] = 0
else:
mask=np.ones((row,col,d),np.uint8)
mask[row // 2-30:row // 2+30, col // 2-30:col // 2+30] = 0
return mask
def _idft(img,mask=None):
# img=_Mask(img,type)*img
if mask!=None:
img=mask*np.fft.ifftshift(img)
else:
img=np.fft.ifftshift(img)
img=cv.idft(img)
return cv.magnitude(img[:,:,0],img[:,:,1])
彩色图像结果:
如果想要用A图的高频细节替换B图,可以如下处理:
def _normal(img):
a,b=img.max(),img.min()
return np.clip((img-b)/(a-b),0,1)
def swap(ori,aug):
oriFFT=getRGBDFT(ori)
augFFT=getRGBDFT(aug)
HPF=_Mask(ori,"HPF",size=3)
LPF=_Mask(ori,"LPF",size=16)
res=[augFFT[i]*LPF+oriFFT[i]*HPF for i in range(len(oriFFT))]
res=[_normal(_idft(i)) for i in res]
return cv.merge(res)
完整代码如下:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
# 测试图像
ori=cv.imread(r"F:\ori.jpg")
aug=cv.imread(r"F:\129.jpg")
# numpy 中的 fft 需要输入灰度图,我们需要将图像分割成不同的通道
def getRGBDFT(img):
# cv2默认的图像通道是BGR,需要进行转换
img=cv.cvtColor(img,cv.COLOR_BGR2RGB)
# 分离通道
r,g,b=cv.split(img)
# 对每个通道进行傅里叶变换
f_r,f_g,f_b=_dft(r,False),_dft(g,False),_dft(b,False)
# 组合通道,还是以bgr格式返回
# return cv.merge([f_b,f_g,f_r])
return [f_b,f_g,f_r]
def _dft(img,to_show=True):
f=cv.dft(np.float32(img),flags=cv.DFT_COMPLEX_OUTPUT)
# 计算幅度谱
if to_show:
# 对数变换增强对比度
magnitude = cv.magnitude(f[:, :, 0], f[:, :, 1])
f=np.log(magnitude+1)
# 移动低频分量至中心
return np.fft.fftshift(f)
def _idft(img,mask=None):
# img=_Mask(img,type)*img
if mask!=None:
img=mask*np.fft.ifftshift(img)
else:
img=np.fft.ifftshift(img)
img=cv.idft(img)
return cv.magnitude(img[:,:,0],img[:,:,1])
def _Mask(img,type=None,d=2,size=4):
row,col=img.shape[:-1]
if type==None:
return np.ones((row,col,d),np.uint8)
if type == "LPF":
mask=np.zeros((row,col,d),np.uint8)
mask[row//size:row//size*(size-1),col//size:col//size*(size-1)]=1
elif type=="HPF":
mask = np.ones((row, col,d), np.uint8)
mask[row//size:row//size*(size-1),col//size:col//size*(size-1)] = 0
else:
mask=np.ones((row,col,d),np.uint8)
mask[row // 2-30:row // 2+30, col // 2-30:col // 2+30] = 0
return mask
def _normal(img):
a,b=img.max(),img.min()
return np.clip((img-b)/(a-b),0,1)
def getRGBIDFT(img,type):
fft=getRGBDFT(img)
ifft=[_normal(_idft(i,type)) for i in fft]
return cv.merge(ifft)
def swap(ori,aug):
oriFFT=getRGBDFT(ori)
augFFT=getRGBDFT(aug)
HPF=_Mask(ori,"HPF",size=3)
LPF=_Mask(ori,"LPF",size=4)
res=[augFFT[i]*LPF+oriFFT[i]*HPF for i in range(len(oriFFT))]
res=[_normal(_idft(i)) for i in res]
return cv.merge(res)
res=swap(ori,aug)
plt.subplot(131),plt.imshow(ori)
plt.title("Ori"),plt.xticks([]),plt.yticks([])
plt.subplot(132),plt.imshow(aug)
plt.title("Aug"),plt.xticks([]),plt.yticks([])
plt.subplot(133),plt.imshow(res)
plt.title("res"),plt.xticks([]),plt.yticks([])
plt.show()