import matplotlib.pyplot as plt
import pywt

Plot the scaling and wavelet functions of a wavelet

# Load the 'db4' wavelet and dump its description and filter taps.
w = pywt.Wavelet('db4')
print("Wavelet", w)
print("Filter bank", w.filter_bank)

# Print each of the four filters under its own attribute name, in the same
# order as before: decomposition low/high, then reconstruction low/high.
for filter_name in ("dec_lo", "dec_hi", "rec_lo", "rec_hi"):
    print(filter_name, getattr(w, filter_name))


# Side-by-side stem plots of the decomposition and reconstruction filters.
# Each entry: subplot position, low-pass taps, high-pass taps, the two legend
# labels, and the panel title.
filter_panels = (
    (121, w.dec_lo, w.dec_hi,
     "Decomposition Low Pass", "Decomposition High Pass",
     "Decomposition Filters"),
    (122, w.rec_lo, w.rec_hi,
     "Reconstruction Low Pass", "Reconstruction High Pass",
     "Reconstruction Filters"),
)

plt.figure(figsize=(12, 4))
for position, low_taps, high_taps, low_label, high_label, panel_title in filter_panels:
    plt.subplot(position)
    # The low-pass filter is drawn with red stems so the two filters that
    # share an axes can be told apart.
    plt.stem(low_taps, label=low_label, linefmt="-r")
    plt.stem(high_taps, label=high_label)
    plt.legend()
    plt.title(panel_title)
    plt.grid(True)




# Sample the scaling function (phi) and wavelet function (psi) of `w` at
# refinement level 1; `x` holds the matching sample positions.
phi, psi, x = w.wavefun(level=1)


# NOTE(review): the printed captions say "Coefficients", but phi/psi are the
# arrays returned by wavefun(), not the filter taps printed earlier.
print("Scaling Coefficients", phi)
print("Detail Coefficients", psi)
print("Number of vanishing moments in detail", w.vanishing_moments_psi)
print("Number of vanishing moments in scale", w.vanishing_moments_phi)


# Stem plots of phi (left) and psi (right). Only the values are passed to
# stem(), so the horizontal axis is the sample index rather than `x`.
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.stem(phi, label="Scaling Function")
plt.grid()
plt.legend()
plt.subplot(122)
plt.stem(psi, label="Wavelet Function")
plt.legend()
# Title the figure with the wavelet that was actually loaded. The title used
# to be hard-coded to "coif2" even though `w` is created as 'db4'.
plt.suptitle(f"{w.name} Wavelet")
plt.grid(True)
Wavelet Wavelet db4
  Family name:    Daubechies
  Short name:     db
  Filters length: 8
  Orthogonal:     True
  Biorthogonal:   True
  Symmetry:       asymmetric
  DWT:            True
  CWT:            False
Filter bank ([-0.010597401785069032, 0.0328830116668852, 0.030841381835560764, -0.18703481171909309, -0.027983769416859854, 0.6308807679298589, 0.7148465705529157, 0.2303778133088965], [-0.2303778133088965, 0.7148465705529157, -0.6308807679298589, -0.027983769416859854, 0.18703481171909309, 0.030841381835560764, -0.0328830116668852, -0.010597401785069032], [0.2303778133088965, 0.7148465705529157, 0.6308807679298589, -0.027983769416859854, -0.18703481171909309, 0.030841381835560764, 0.0328830116668852, -0.010597401785069032], [-0.010597401785069032, -0.0328830116668852, 0.030841381835560764, 0.18703481171909309, -0.027983769416859854, -0.6308807679298589, 0.7148465705529157, -0.2303778133088965])
dec_lo [-0.010597401785069032, 0.0328830116668852, 0.030841381835560764, -0.18703481171909309, -0.027983769416859854, 0.6308807679298589, 0.7148465705529157, 0.2303778133088965]
dec_hi [-0.2303778133088965, 0.7148465705529157, -0.6308807679298589, -0.027983769416859854, 0.18703481171909309, 0.030841381835560764, -0.0328830116668852, -0.010597401785069032]
rec_lo [0.2303778133088965, 0.7148465705529157, 0.6308807679298589, -0.027983769416859854, -0.18703481171909309, 0.030841381835560764, 0.0328830116668852, -0.010597401785069032]
rec_hi [-0.010597401785069032, -0.0328830116668852, 0.030841381835560764, 0.18703481171909309, -0.027983769416859854, -0.6308807679298589, 0.7148465705529157, -0.2303778133088965]
Scaling Coefficients [ 0.          0.32580343  1.01094572  0.89220014 -0.03957503 -0.26450717
  0.0436163   0.0465036  -0.01498699  0.          0.          0.
  0.          0.          0.        ]
Detail Coefficients [ 0.         -0.01498699 -0.0465036   0.0436163   0.26450717 -0.03957503
 -0.89220014  1.01094572 -0.32580343  0.          0.          0.
  0.          0.          0.        ]
Number of vanishing moments in detail 4
Number of vanishing moments in scale 0

Visualizing 2D Wavelets

import numpy as np
import pywt
import matplotlib.pyplot as plt


# For every wavelet family, build the four separable 2D products of its
# scaling (phi) and wavelet (psi) functions and show them as images.
#
# The loop variable is named `wavelet_name` so it no longer shadows the
# builtin `type`, and the Wavelet object / 2D arrays get their own names
# instead of all being called `wavelet`.
for wavelet_name in ['haar', 'db4', 'sym4', 'coif2']:
    mother_wavelet = pywt.Wavelet(wavelet_name)
    phi, psi, x = mother_wavelet.wavefun(level=1)

    # 2D grid of sample positions; only needed by the optional 3D surface
    # view that is commented out below.
    xx, yy = np.meshgrid(x, x)

    labels = ["LL", "LH", "HL", "HH"]
    plt.figure(figsize=(12, 3))

    # The row factor varies slowest, so the pairs come out in the same order
    # as `labels`: (phi, phi), (phi, psi), (psi, phi), (psi, psi).
    factor_pairs = [(row_fn, col_fn) for row_fn in (phi, psi) for col_fn in (phi, psi)]
    for i, (row_fn, col_fn) in enumerate(factor_pairs):
        # Separable 2D function: outer product of the two 1D functions.
        basis_2d = np.outer(row_fn, col_fn)

        plt.subplot(1, 4, i + 1)
        plt.imshow(basis_2d, cmap='coolwarm', extent=[x.min(), x.max(), x.min(), x.max()])
        plt.title(labels[i])
        plt.xlabel('X')
        plt.colorbar(shrink=0.7)
        plt.ylabel('Y')
        # Turning the axis off also hides the X/Y labels set above.
        plt.axis('off')

        # Optional 3D surface view (disabled):
        # ax = plt.subplot(2, 4, i+5, projection='3d')
        # surf = ax.plot_surface(xx, yy, basis_2d, cmap='coolwarm')
        # # ax.set_title(labels[i])
        # ax.set_xlabel('X')
        # ax.set_ylabel('Y')
        # ax.set_zlabel('Amplitude')
        # plt.colorbar(surf, shrink=0.5, aspect=5, label='Amplitude')

    plt.suptitle(f"{wavelet_name} Wavelet")
    plt.show()

import dtcwt.coeffs


# Load the 'near_sym_a' biorthogonal filter set and flatten each filter to 1D.
coeefs = dtcwt.coeffs.biort('near_sym_a') # h0, g0, h1, g1
# Unpacking follows the order noted above: phi0 <- h0, phi1 <- g0,
# psi0 <- h1, psi1 <- g1.
phi0, phi1, psi0, psi1 = [c.reshape(-1,) for c in coeefs]


labels = ["LL", "LH", "HL", "HH"]


# First set of separable 2D products, built from phi0 / psi0.
tree_a = []
plt.figure(figsize=(12, 3))
# NOTE(review): `x` is whatever the earlier wavefun() cell left behind, so the
# image extent is not derived from these filters -- confirm this is intended.
image_extent = [x.min(), x.max(), x.min(), x.max()]

# Fix: the row factors used to be [psi0, psi0], which made the panels titled
# "LL" and "LH" duplicates of "HL" and "HH". With [phi0, psi0] the four
# products follow the same (low, low), (low, high), (high, low), (high, high)
# pattern as `labels` and as the phi1 / psi1 loop that builds tree_b.
factor_pairs = [(s, w) for s in [phi0, psi0] for w in [phi0, psi0]]
for i, (s, w) in enumerate(factor_pairs):
    # Calculate 2D wavelet
    wavelet = np.outer(s, w)
    print(wavelet.shape)
    tree_a.append(wavelet)

    plt.subplot(1, 4, i + 1)
    plt.imshow(wavelet, cmap='coolwarm', extent=image_extent)
    plt.title(labels[i])
    plt.xlabel('X')
    plt.colorbar(shrink=0.7)
    plt.ylabel('Y')
    plt.axis('off')


plt.suptitle("DTCWT Real Wavelet")
plt.show()


# Second set of separable 2D products, built from phi1 / psi1 in the order
# (phi1, phi1), (phi1, psi1), (psi1, phi1), (psi1, psi1).
tree_b = [
    np.outer(row_taps, col_taps)
    for row_taps in (phi1, psi1)
    for col_taps in (phi1, psi1)
]

plt.figure(figsize=(12, 3))
image_extent = [x.min(), x.max(), x.min(), x.max()]
for i, wavelet in enumerate(tree_b):
    print(wavelet.shape)

    # One panel per product, titled with the matching sub-band label.
    plt.subplot(1, 4, i + 1)
    plt.imshow(wavelet, cmap='coolwarm', extent=image_extent)
    plt.title(labels[i])
    plt.xlabel('X')
    plt.colorbar(shrink=0.7)
    plt.ylabel('Y')
    plt.axis('off')


plt.suptitle("DTCWT Imag Wavelet")
plt.show()


# Wavelets: sum (top row) and difference (bottom row) of each tree_a entry
# with the tree_b entry at the mirrored index.
plt.figure(figsize=(12, 6))
image_extent = [x.min(), x.max(), x.min(), x.max()]
for i in range(4):
    mirrored = tree_b[3 - i]
    wavelet1 = tree_a[i] + mirrored
    wavelet2 = tree_a[i] - mirrored

    # Panels i+1 (top row) and i+5 (bottom row) sit in the same column of the
    # 2x4 grid and share the same title.
    for slot, combined in ((i + 1, wavelet1), (i + 5, wavelet2)):
        plt.subplot(2, 4, slot)
        plt.imshow(combined, cmap='coolwarm', extent=image_extent)
        plt.title(labels[i])
        plt.xlabel('X')
        plt.colorbar(shrink=0.7)
        plt.ylabel('Y')
        plt.axis('off')


plt.suptitle("DTCWT Wavelets")
plt.show()

# Load the sample image (the file read here is star.png; the variable keeps
# its historical name `mandrill`) and collapse the channel axis to grayscale.
mandrill = plt.imread("../star.png")
# NOTE(review): the mean runs over every channel on axis 2; if the PNG has an
# alpha channel it is averaged in as well -- confirm that is acceptable.
mandrill = mandrill.mean(axis=2)


# Show the sample image
plt.figure(1)
plt.axis('off')
plt.title("Sample Image")
plt.imshow(mandrill, cmap='gray', clim=(0,1))


# Single-level 2D Haar DWT: the approximation band plus a tuple of detail
# bands (its length is printed below).
LL, wavelets_haar = pywt.dwt2(mandrill, 'haar')
print(len(wavelets_haar))


# Show the magnitude of every detail band of this single-level transform.
# The magnitudes are scaled by 10 and clipped to [0, 1] by `clim` so that weak
# detail responses stay visible.
labels = [0, 90, 45]
plt.figure(2, figsize=(12, 4))
plt.suptitle("Haar Bands")
for slice_idx, detail_band in enumerate(wavelets_haar):
    plt.subplot(1, 3, slice_idx + 1)
    plt.axis('off')
    plt.title(f"Angle {labels[slice_idx]}")
    plt.imshow(10 * (np.abs(detail_band)), cmap="gray", clim=(0, 1))

DTCWT Demonstration

import dtcwt
import numpy as np


# Read the sample image and average over its channel axis to get grayscale.
mandrill = plt.imread("../star.png").mean(axis=2)


# Figure 1: the input image.
plt.figure(1)
plt.axis('off')
plt.title("Sample Image")
plt.imshow(mandrill, cmap='gray', clim=(0,1))


# Two-level 2D DTCWT using the library's default wavelet family.
transform = dtcwt.Transform2d()
mandrill_t = transform.forward(mandrill, nlevels=2)


# Level-2 highpass coefficients (index 1, because level 1 is index 0). The
# last axis enumerates the directional sub-bands.
level2_bands = mandrill_t.highpasses[1]
band_count = level2_bands.shape[2]

# NOTE(review): the angle labels assume the sub-bands come in this order --
# confirm against the dtcwt documentation.
labels = [15, 45, 75, 105, 135, 165]

# Figure 2: magnitude of each directional sub-band, scaled by 5 and clipped
# to [0, 1] by `clim`.
plt.figure(2, figsize=(12, 3))
plt.suptitle("DT CWT Bands")
for slice_idx in range(band_count):
    magnitude = np.abs(level2_bands[..., slice_idx])
    plt.subplot(1, 6, slice_idx + 1)
    plt.axis('off')
    plt.title(f"Angle {labels[slice_idx]}")
    plt.imshow(5 * (magnitude), cmap="gray", clim=(0, 1))


# Figure 3: phase of each directional sub-band, colour-mapped over [-pi, pi].
plt.figure(3)
for slice_idx in range(band_count):
    phase = np.angle(level2_bands[..., slice_idx])
    plt.subplot(1, 6, slice_idx + 1)
    plt.imshow(phase, cmap="viridis", clim=(-np.pi, np.pi))

import numpy as np
from scipy.signal import hilbert


# Five-sample input sequence, symmetric about its middle sample.
signal = np.array([-0.05, 0.25, 0.6, 0.25, -0.05])


# hilbert() returns the complex analytic signal; its imaginary part is the
# Hilbert transform of the input.
analytic_signal = hilbert(signal)
hilbert_transform = analytic_signal.imag


# Report the input and its transform.
for caption, values in (
    ("Original Signal: ", signal),
    ("Hilbert Transform: ", hilbert_transform),
):
    print(caption, values)
Original Signal:  [-0.05  0.25  0.6   0.25 -0.05]
Hilbert Transform:  [-1.33803035e-01 -3.56506308e-01  2.22044605e-17  3.56506308e-01
  1.33803035e-01]
# Fix: h0 / g0 / h1 / g1 were used below without ever being defined, so this
# section raised NameError. Load them from the same 'near_sym_a' biorthogonal
# set used earlier in the file, in the order h0, g0, h1, g1, and flatten each
# filter to 1D so it can be passed straight to stem().
h0, g0, h1, g1 = (taps.reshape(-1) for taps in dtcwt.coeffs.biort('near_sym_a'))

# Figure 1: one stem plot per filter, in a 2x2 grid.
plt.figure(figsize=(12, 8))
filter_panels = (
    (1, h0, "Analysis H0"),
    (2, h1, "Analysis H1"),
    (3, g0, "Recon G0"),
    (4, g1, "Recon G1"),
)
for position, taps, panel_title in filter_panels:
    plt.subplot(2, 2, position)
    plt.stem(taps)
    plt.title(panel_title)


plt.show()

# Figure 2: sum of the low-pass and high-pass analysis filters, shown as a
# stem plot and as a line plot.
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
# Pad h0 with one zero on each side so it has the same length as h1 before
# the element-wise sum (the titles describe them as 5-tap and 7-tap filters).
a = np.asarray([0] + list(h0.reshape(-1, )) + [0])
b = h1.reshape(-1)
plt.stem(a + b)
plt.title("Adding 5 tap LP and 7 tap HP")


plt.subplot(2, 2, 2)
plt.plot(a + b)
plt.title("Adding 5 tap LP and 7 tap HP")

import numpy as np
import matplotlib.pyplot as plt
import dtcwt


# Function to plot the wavelets
def plot_dtcwt_wavelets(levels=3):
    """Plot the 1D DTCWT highpass coefficients of a unit impulse, per level.

    A 64-sample signal that is zero except for a single 1 at its midpoint is
    transformed with ``dtcwt.Transform1d``. For every level, the real part of
    the highpass coefficients is drawn in the top row and the imaginary part
    in the bottom row of a ``2 x levels`` grid of axes, then the figure is
    shown.

    Args:
        levels: Number of decomposition levels to compute and plot.
            Must be at least 1.

    Raises:
        ValueError: If ``levels`` is smaller than 1.
    """
    # Fail early with a clear message instead of deep inside the plotting call.
    if levels < 1:
        raise ValueError("levels must be a positive integer")

    # Create a 1D signal
    signal_length = 64
    impulse = np.zeros(signal_length)
    impulse[signal_length // 2] = 1  # Impulse signal

    # Perform Dual-Tree Complex Wavelet Transform
    transform = dtcwt.Transform1d()
    pyramid = transform.forward(impulse, nlevels=levels)

    # squeeze=False keeps `axes` two-dimensional even when levels == 1, so
    # the [row, column] indexing below works for any valid level count.
    _, axes = plt.subplots(2, levels, figsize=(15, 6), squeeze=False)
    for level in range(levels):
        # Extract wavelet coefficients
        wavelet_coeffs = pyramid.highpasses[level]
        positions = np.arange(len(wavelet_coeffs))

        # Plot real part
        axes[0, level].plot(positions, wavelet_coeffs.real, label='Real', color='blue')
        axes[0, level].set_title(f'Level {level + 1} - Real Part')
        axes[0, level].grid(True)

        # Plot imaginary part
        axes[1, level].plot(positions, wavelet_coeffs.imag, label='Imaginary', color='orange')
        axes[1, level].set_title(f'Level {level + 1} - Imaginary Part')
        axes[1, level].grid(True)

    # Add labels
    axes[0, 0].set_ylabel('Amplitude')
    axes[1, 0].set_ylabel('Amplitude')
    # The x label goes on the second bottom panel as before; with a single
    # level that panel does not exist, so fall back to the only column.
    axes[1, min(1, levels - 1)].set_xlabel('Index')

    plt.tight_layout()
    plt.show()


# Call the function to plot wavelets: four decomposition levels, which gives
# a 2x4 grid (real parts on top, imaginary parts below).
plot_dtcwt_wavelets(levels=4)

学术咨询:

担任《Mechanical System and Signal Processing》《中国电机工程学报》等期刊审稿专家,擅长领域:信号滤波/降噪,机器学习/深度学习,时间序列预分析/预测,设备故障诊断/缺陷检测/异常检测。

分割线分割线分割线

基于小波分析的Linear电磁谱降噪(Python)

Python环境下基于最小最大凹面全变分一维信号降噪方法

MATLAB环境下简单的PPG信号(光电容积脉搏波信号)分析方法(滤波降噪分解等)

基于双树复小波和邻域多尺度的非平稳信号降噪方法(MATLAB)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐