# Source code for pysdkit.plot._plot_imfs

# -*- coding: utf-8 -*-
"""
Created on Sat Mar 4 11:59:21 2024
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com

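Functions for visualizing images still need to be written separately,
so they are not grouped together with the signal-visualization functions here.

Visualization of the spectrum obtained after decomposition.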
"""
import numpy as np
import matplotlib.pyplot as plt

from typing import Optional, List

from pysdkit.plot._functions import generate_random_hex_color
from pysdkit.plot._functions import set_themes

COLORS = [
    "#000000",
    "#228B22",
    "#FF8C00",
    "#BA55D3",
    "#4169E1",
    "#FF6347",
    "#20B2AA",
]


def plot_IMFs(
    signal: np.ndarray,
    IMFs: np.ndarray,
    max_imfs: Optional[int] = -1,
    view: Optional[str] = "2d",
    colors: Optional[List] = None,
    save_figure: Optional[bool] = False,
    return_figure: Optional[bool] = False,
    dpi: Optional[int] = 64,
    spine_width: float = 2,
    labelpad: float = 10,
    save_name: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Visualizes the numpy array of intrinsic mode functions derived from the decomposition of a signal.
    Can be used as a generic interface for plotting.

    This function only dispatches: depending on ``view`` and on the number of
    dimensions of ``signal`` it forwards its arguments to ``plot_2D_IMFs``,
    ``plot_multi_IMFs``, ``plot_3D_IMFs`` or ``plot_multi_3D_IMFs``.
    The 2D plane visualization is more intuitive, while the 3D visualization can
    better reflect the size relationship between the decomposed modes.

    :param signal: The input original signal, univariate with shape [seq_len, ]
                   or multivariate with shape [n_vars, seq_len]
    :param IMFs: The intrinsic mode functions obtained after signal decomposition
    :param max_imfs: The number of decomposition modes to be plotted
    :param view: The view of the figure, choice ["2d", "3d"] (case-insensitive)
    :param colors: List of color strings for plotting
    :param save_figure: Whether to save the figure as an image
    :param return_figure: Whether to return the figure object
    :param dpi: The resolution of the saved image
    :param spine_width: The width of the visible axes spines (used by the 2D views only)
    :param labelpad: Controls the filling distance of the y-axis coordinate
                     (used by the univariate 2D view only)
    :param save_name: The name of the saved image file
    :return: Whatever the selected plotting function returns
             (the figure object when ``return_figure`` is true, otherwise None)
    :raises ValueError: If ``view`` is not "2d"/"3d" or ``signal`` is neither 1-D nor 2-D
    """
    # Validate the requested view first so that a bad value (including None or a
    # non-string) always yields the documented ValueError rather than an
    # AttributeError raised by ``.lower()``.
    if not isinstance(view, str) or view.lower() not in ("2d", "3d"):
        raise ValueError("View must be either `2d` or `3d`")
    view = view.lower()

    # Accept any array-like input; ``np.asarray`` returns ndarrays unchanged.
    signal = np.asarray(signal)
    IMFs = np.asarray(IMFs)

    # Judge whether the input signal is univariate or multivariate by its dimension
    if signal.ndim not in (1, 2):
        # The input data is in the wrong format
        raise ValueError(
            "The shape of the input signal must be the univariate with [seq_len, ] or multivariate with [n_vars, seq_len]"
        )
    is_univariate = signal.ndim == 1

    # Keyword arguments shared by all four plotting back-ends
    common = dict(
        signal=signal,
        IMFs=IMFs,
        max_imfs=max_imfs,
        colors=colors,
        save_figure=save_figure,
        return_figure=return_figure,
        dpi=dpi,
        save_name=save_name,
    )

    if view == "2d":
        # Visualization on a 2D plane
        if is_univariate:
            return plot_2D_IMFs(spine_width=spine_width, labelpad=labelpad, **common)
        return plot_multi_IMFs(spine_width=spine_width, **common)

    # Visualization in 3D space
    if is_univariate:
        return plot_3D_IMFs(**common)
    return plot_multi_3D_IMFs(**common)
# File extensions that are accepted as-is when saving a figure
_KNOWN_IMAGE_SUFFIXES = (".jpg", ".pdf", ".png", ".bmp")


def _count_rows(IMFs: np.ndarray, max_imfs: Optional[int]) -> int:
    """Return the number of plotted rows: one for the signal plus the selected IMFs."""
    n_imfs = IMFs.shape[0]
    # ``-1`` is the documented "plot everything" value; None and other negative
    # values are treated the same way instead of producing an invalid row count.
    if max_imfs is None or max_imfs < 0:
        return n_imfs + 1
    return min(int(max_imfs), n_imfs) + 1


def _resolve_colors(colors: Optional[List], n_rows: int) -> List:
    """Return a private list of at least ``n_rows`` colors, padded with random ones."""
    # Work on a copy so that neither the module-level COLORS palette nor the
    # caller's list is modified by the padding below.
    palette = list(COLORS if colors is None else colors)
    while len(palette) < n_rows:
        palette.append(generate_random_hex_color())
    return palette


def _stack_multivariate(signal: np.ndarray, IMFs: np.ndarray, n_rows: int) -> np.ndarray:
    """Stack a [n_vars, seq_len] signal and its IMFs into [n_rows, seq_len, n_vars]."""
    n_vars, seq_len = signal.shape
    stacked = np.zeros(shape=(n_rows, seq_len, n_vars))
    stacked[0, :, :] = signal.transpose(1, 0)
    # Only keep the IMFs that will actually be drawn, otherwise the assignment
    # fails to broadcast whenever ``max_imfs`` is smaller than the number of IMFs.
    stacked[1:, :, :] = IMFs[: n_rows - 1]
    return stacked


def _keep_only_left_spine(axis, spine_width: float) -> None:
    """Hide every spine of ``axis`` except the left one, which gets ``spine_width``."""
    for spine_name, spine in axis.spines.items():
        if spine_name == "left":
            spine.set_visible(True)
            spine.set_linewidth(spine_width)
        else:
            spine.set_visible(False)


def _show_bottom_spine(axis, offset: float, spine_width: float) -> None:
    """Make the bottom spine of ``axis`` visible and shift it to ``offset`` in axes coordinates."""
    bottom = axis.spines["bottom"]
    bottom.set_position(("axes", offset))
    bottom.set_visible(True)
    bottom.set_linewidth(spine_width)


def _save_figure(fig, save_name: Optional[str], default_name: str, dpi: Optional[int]) -> None:
    """Save ``fig`` under ``save_name``, appending ``.jpg`` when no known extension is present."""
    if save_name is None:
        target = default_name
    else:
        target = str(save_name)
        # Test the suffix rather than a substring so that a name such as
        # "my.png_results" is not mistaken for a file that already has an extension.
        if not target.lower().endswith(_KNOWN_IMAGE_SUFFIXES):
            target = target + ".jpg"
    fig.savefig(target, dpi=dpi, bbox_inches="tight")


def plot_2D_IMFs(
    signal: np.ndarray,
    IMFs: np.ndarray,
    max_imfs: Optional[int] = -1,
    colors: Optional[List] = None,
    save_figure: Optional[bool] = False,
    return_figure: Optional[bool] = False,
    dpi: Optional[int] = 64,
    fontsize: float = 14,
    spine_width: float = 2,
    labelpad: float = 10,
    save_name: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Visualizes the numpy array of intrinsic mode functions derived from the decomposition of a signal.
    Draws the original signal and each IMF in its own row of a single-column figure.

    :param signal: The input original signal
    :param IMFs: The intrinsic mode functions obtained after signal decomposition,
                 indexed as ``IMFs[k]`` for the k-th mode (``IMFs.shape[1]`` is the signal length)
    :param max_imfs: The number of decomposition modes to be plotted
                     (-1, None or any negative value plots all of them)
    :param colors: List of color strings for plotting; the list is not modified
    :param save_figure: Whether to save the figure as an image
    :param return_figure: Whether to return the figure object
    :param dpi: The resolution of the saved image
    :param fontsize: The font size of the axis labels
    :param spine_width: The width of the visible axes spines
    :param labelpad: Controls the filling distance of the y-axis coordinate
    :param save_name: The name of the saved image file (defaults to "plot_imfs.jpg")
    :return: The figure object for the plot if ``return_figure`` is true, otherwise None
    """
    # Set the matplotlib configs
    set_themes(choice="plot_imfs")

    # Determine the number of rows
    n_rows = _count_rows(IMFs, max_imfs)

    # The length of the signal and the edge padding of the x-axis
    length = IMFs.shape[1]
    padding = int(length / 50)

    # Create the figure and axes; ``squeeze=False`` keeps a 2-D axes array even
    # when only a single row is drawn, so indexing below never fails.
    fig, grid = plt.subplots(
        nrows=n_rows, ncols=1, figsize=(10, 2 * n_rows - 1), dpi=256, squeeze=False
    )
    axes = grid[:, 0]
    fig.tight_layout()

    # Set the colors for plotting
    palette = _resolve_colors(colors, n_rows)

    for i, axis in enumerate(axes):
        # Plot a horizontal gray line as the x-axis
        axis.axhline(y=0, color="gray", linestyle="-", alpha=0.6, linewidth=spine_width)

        if i == 0:
            # Plot the original signal with the first color
            axis.plot(signal, color=palette[i])
            axis.set_ylabel("Signal", fontsize=fontsize, labelpad=labelpad)
        else:
            # Plot the intrinsic mode functions in the other colors
            axis.plot(IMFs[i - 1], color=palette[i])
            axis.set_ylabel(f"IMF-{i - 1}", fontsize=fontsize, labelpad=labelpad)

        axis.set_xlim(-padding, length + padding)

        # Keep only the left spine visible
        _keep_only_left_spine(axis, spine_width)

        # Hide x-axis ticks for all but the last plot
        if i != n_rows - 1:
            axis.set_xticks([])

    # Open the bottom spine of the last axes and set its position
    _show_bottom_spine(axes[-1], offset=-0.2, spine_width=spine_width)

    # Save the figure if requested
    if save_figure:
        _save_figure(fig, save_name, "plot_imfs.jpg", dpi)

    # Return the figure if requested
    if return_figure:
        return fig
    return None


def plot_3D_IMFs(
    signal: np.ndarray,
    IMFs: np.ndarray,
    max_imfs: Optional[int] = -1,
    colors: Optional[List] = None,
    save_figure: Optional[bool] = False,
    return_figure: Optional[bool] = False,
    dpi: Optional[int] = 64,
    fontsize: float = 5,
    save_name: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Visualizes the numpy array of intrinsic mode functions derived from the decomposition of a signal.
    Draws the original signal and each IMF as parallel curves in one 3D axes.

    :param signal: The input original signal
    :param IMFs: The intrinsic mode functions obtained after signal decomposition
    :param max_imfs: The number of decomposition modes to be plotted
                     (-1, None or any negative value plots all of them)
    :param colors: List of color strings for plotting; the list is not modified
    :param save_figure: Whether to save the figure as an image
    :param return_figure: Whether to return the figure object
    :param dpi: The resolution of the saved image
    :param fontsize: The font size of the axis labels. NOTE: accepted for interface
                     compatibility but not used here; tick labels use a fixed size of 8.
    :param save_name: The name of the saved image file (defaults to "plot_IMFs_3D.jpg")
    :return: The figure object for the plot if ``return_figure`` is true, otherwise None
    """
    # Determine the number of rows
    n_rows = _count_rows(IMFs, max_imfs)

    # The length of the signal
    length = IMFs.shape[1]

    # Set the colors for plotting
    palette = _resolve_colors(colors, n_rows)

    # Create the figure and axes
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection="3d")

    # Stack all the curves: the input signal first, then the IMFs
    signals = np.vstack([signal, IMFs])

    # Create the x and y axes for 3D plotting; x is flipped so the signal gets
    # the largest x position and later IMFs get smaller ones.
    x = np.flip(np.arange(n_rows))
    y = np.arange(length)

    # Set the x string labels
    x_label = ["Signal"] + [f"IMF-{num + 1}" for num in range(n_rows - 1)]

    # Traverse each curve and draw it at its own x position
    for i in range(n_rows):
        ax.plot(np.ones(length) * x[i], y, signals[i, :], color=palette[i], lw=0.75)

    # Set the x axes ticks and labels
    ax.set_xticks(x)
    ax.set_xticklabels(x_label)

    # Set the tick params including fontsize and colors
    ax.tick_params(axis="both", which="major", labelsize=8, colors="black")

    # Save the figure if requested
    if save_figure:
        _save_figure(fig, save_name, "plot_IMFs_3D.jpg", dpi)

    # Return the figure if requested
    if return_figure:
        return fig
    return None


def plot_multi_IMFs(
    signal: np.ndarray,
    IMFs: np.ndarray,
    max_imfs: Optional[int] = -1,
    colors: Optional[List] = None,
    save_figure: Optional[bool] = False,
    return_figure: Optional[bool] = False,
    dpi: Optional[int] = 64,
    fontsize: float = 8,
    spine_width: float = 2,
    save_name: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Plotting a multivariate signal and its decomposed intrinsic mode functions.
    Each column of the figure shows one variable; the first row is the original
    signal and the following rows are the IMFs.

    :param signal: The input original signal with shape [n_vars, seq_len]
    :param IMFs: The intrinsic mode functions obtained after signal decomposition;
                 they are written into an array of shape [n_rows - 1, seq_len, n_vars]
    :param max_imfs: The number of decomposition modes to be plotted
                     (-1, None or any negative value plots all of them)
    :param colors: List of color strings for plotting; the list is not modified
    :param save_figure: Whether to save the figure as an image
    :param return_figure: Whether to return the figure object
    :param dpi: The resolution of the saved image
    :param fontsize: The font size of the axis labels
    :param spine_width: The width of the visible axes spines
    :param save_name: The name of the saved image file (defaults to "plot_imfs.jpg")
    :return: The figure object for the plot if ``return_figure`` is true, otherwise None
    """
    # Set the matplotlib configs
    set_themes(choice="plot_imfs")

    # Get the number of variables and the signal length of the multivariate signal
    n_vars, seq_len = signal.shape

    # Edge padding
    padding = int(seq_len / 50)

    # Determine the number of rows
    n_rows = _count_rows(IMFs, max_imfs)

    # Reconstruct the input signal together with the IMFs to be drawn
    signals = _stack_multivariate(signal, IMFs, n_rows)

    # Set the colors for plotting
    palette = _resolve_colors(colors, n_rows)

    # Create the figure and axes for multi-plotting; ``squeeze=False`` keeps the
    # axes array 2-D even for a single row or a single variable.
    fig, ax = plt.subplots(
        nrows=n_rows,
        ncols=n_vars,
        figsize=(3 * n_vars, 1.5 * n_rows - 1),
        dpi=256,
        squeeze=False,
    )
    fig.tight_layout()

    # Start drawing the signal
    for row in range(n_rows):
        # Iterate over each row: the original signal and then every IMF
        for col in range(n_vars):
            # Traverse each column to draw each variable and its decomposition result
            axis = ax[row, col]
            axis.axhline(
                y=0, color="gray", linestyle="-", alpha=0.6, linewidth=spine_width
            )

            # Plotting the signal
            axis.plot(signals[row, :, col], color=palette[row], lw=0.75)

            # Set the range of the signal
            axis.set_xlim(-padding, seq_len + padding)

            # Adjust the size of the axis scale
            axis.tick_params(axis="both", which="major", labelsize=8)

            # Keep only the left spine visible
            _keep_only_left_spine(axis, spine_width)

            # Hide x-axis ticks for all but the last row
            if row != n_rows - 1:
                axis.set_xticks([])

    # Open the bottom spine of the last row and set its position
    for col in range(n_vars):
        _show_bottom_spine(ax[-1, col], offset=-0.1, spine_width=spine_width)

    # Setting the y label on the first column
    for row in range(n_rows):
        if row == 0:
            ax[row, 0].set_ylabel("Signal", fontsize=fontsize)
        else:
            ax[row, 0].set_ylabel(f"IMF-{row - 1}", fontsize=fontsize)

    # Setting the title of each variable
    for col in range(n_vars):
        ax[0, col].set_title(f"Var-{col}", fontsize=fontsize + 1)

    # Save the figure if requested
    if save_figure:
        _save_figure(fig, save_name, "plot_imfs.jpg", dpi)

    # Return the figure if requested
    if return_figure:
        return fig
    return None


def plot_multi_3D_IMFs(
    signal: np.ndarray,
    IMFs: np.ndarray,
    max_imfs: Optional[int] = -1,
    colors: Optional[List] = None,
    save_figure: Optional[bool] = False,
    return_figure: Optional[bool] = False,
    dpi: Optional[int] = 128,
    save_name: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Plot the results of multivariate signal decomposition in 3D.
    One 3D axes is created per variable, laid out side by side in a single row.

    :param signal: The input original signal with shape [n_vars, seq_len]
    :param IMFs: The intrinsic mode functions obtained after signal decomposition;
                 they are written into an array of shape [n_rows - 1, seq_len, n_vars]
    :param max_imfs: The number of decomposition modes to be plotted
                     (-1, None or any negative value plots all of them)
    :param colors: List of color strings for plotting; the list is not modified
    :param save_figure: Whether to save the figure as an image
    :param return_figure: Whether to return the figure object
    :param dpi: The resolution of the saved image
    :param save_name: The name of the saved image file (defaults to "plot_imfs.jpg")
    :return: The figure object for the plot if ``return_figure`` is true, otherwise None
    """
    # Determine the number of rows
    n_rows = _count_rows(IMFs, max_imfs)

    # Get the number of variables and the signal length of the multivariate signal
    n_vars, seq_len = signal.shape

    # Reconstruct the input signal together with the IMFs to be drawn
    signals = _stack_multivariate(signal, IMFs, n_rows)

    # Set the colors for plotting
    palette = _resolve_colors(colors, n_rows)

    # Create the figure and axes
    fig = plt.figure(figsize=(5 * n_vars, 6), dpi=200)

    # One 3D axes per variable. The (rows, cols, index) form is used because the
    # three-digit shorthand only works while every number is below 10.
    axes = [
        fig.add_subplot(1, n_vars, i, projection="3d") for i in range(1, n_vars + 1)
    ]

    # Create the x and y axes for 3D plotting
    x = np.flip(np.arange(n_rows))
    y = np.arange(seq_len)

    # Set the x string labels
    x_label = ["Signal"] + [f"IMF-{num + 1}" for num in range(n_rows - 1)]

    # Traverse every variable and draw its curves
    for col in range(n_vars):
        # Traverse the original signal and the decomposed intrinsic mode functions
        for row in range(n_rows):
            axes[col].plot(
                np.ones(seq_len) * x[row],
                y,
                signals[row, :, col],
                color=palette[row],
                lw=0.75,
            )

        # Adjust the size of the axis scale
        axes[col].tick_params(axis="both", which="major", labelsize=8)

        # Set the x axes ticks and labels
        axes[col].set_xticks(x)
        axes[col].set_xticklabels(x_label)

    # Setting the title of each variable
    for col in range(n_vars):
        axes[col].set_title(f"Var-{col}", fontsize=15)

    # Save the figure if requested
    if save_figure:
        _save_figure(fig, save_name, "plot_imfs.jpg", dpi)

    # Return the figure if requested
    if return_figure:
        return fig
    return None


if __name__ == "__main__":
    # Small manual demo: decompose a multivariate test signal and show it in 3D
    from pysdkit.data._generator import test_multivariate_1D_1
    from pysdkit import EWT, MVMD
    from matplotlib import pyplot as plt

    time, signal = test_multivariate_1D_1()
    print(signal.shape)

    mvmd = MVMD(alpha=100, K=2, tau=0.0)
    IMFs = mvmd(signal=signal)
    print(IMFs.shape)

    plot_multi_3D_IMFs(signal=signal, IMFs=IMFs)
    plt.show()