# Source code for pysdkit.plot._plot_images

# -*- coding: utf-8 -*-
"""
Created on 2025/02/02 16:47:10
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
"""
import numpy as np
from numpy import fft
from matplotlib import pyplot as plt

from typing import Optional, Tuple


def plot_images(
    img: np.ndarray,
    spectrum: Optional[bool] = False,
    dpi: Optional[int] = 128,
    cmap: Optional[str] = "coolwarm",
    colorbar: Optional[bool] = False,
    save_figure: Optional[bool] = False,
    save_name: Optional[str] = None,
    return_figure: Optional[bool] = False,
) -> Optional[plt.Figure]:
    """
    Visualize univariate and multivariate 2D images.

    This is a general-purpose plotting interface. The input `img` is either a
    univariate image of shape [height, width] or a multivariate image of shape
    [n_vars, height, width]. The `spectrum` flag controls whether the centered
    magnitude spectrum of the 2D FFT is drawn below each image. The `colorbar`
    flag controls whether a color bar is added.

    :param img: The input image(s), with shape [height, width] or
        [n_vars, height, width].
    :param spectrum: bool, whether to also draw the magnitude spectrum
        ``|fftshift(fft2(img))|`` of every image.
    :param dpi: The resolution at which the figure is drawn (and saved).
    :param cmap: The colormap to use, default is `coolwarm`.
    :param colorbar: bool, whether to add a color bar to the drawn image(s).
        For multivariate input, one color bar is shared per row. The panels
        of that row are then drawn with a common color scale so that the bar
        is valid for all of them.
    :param save_figure: Whether to save the figure as an image file.
    :param save_name: The name of the saved image file. If it does not end
        with ".jpg", ".pdf", ".png" or ".bmp" (case-insensitive), ".jpg" is
        appended. Defaults to "plot_image.jpg" for univariate input and
        "plot_images.jpg" for multivariate input.
    :param return_figure: Whether to return the figure object.
    :return: The matplotlib Figure if `return_figure` is true, otherwise None.
    :raises ValueError: If `img` is not 2D, or is 3D with zero channels.
    """
    # asanyarray accepts array-likes and keeps ndarray subclasses (e.g. masked
    # arrays, whose mask imshow honours).
    img = np.asanyarray(img)

    if img.ndim == 2:
        # A univariate image: [height, width]
        fig = _plot_univariate_image(img, spectrum, dpi, cmap, colorbar)
    elif img.ndim == 3 and img.shape[0] > 0:
        # A multivariate image: [n_vars, height, width]
        fig = _plot_multivariate_image(img, spectrum, dpi, cmap, colorbar)
    else:
        raise ValueError(
            "The input shape is wrong, please input your univariate image with shape [height, width] and multivariate image with shape [n_vars, height, width]."
        )

    # Save the figure if requested
    if save_figure:
        fig.savefig(
            _resolve_save_name(save_name, img.ndim), dpi=dpi, bbox_inches="tight"
        )

    # Return the figure if requested
    if return_figure:
        return fig
    return None


def _magnitude_spectrum(img: np.ndarray) -> np.ndarray:
    """Return |fftshift(fft2(img))| computed over the last two axes only."""
    # Restricting fftshift to the spatial axes keeps the channel axis of a
    # [n_vars, height, width] stack untouched.
    return np.abs(fft.fftshift(fft.fft2(img, axes=(-2, -1)), axes=(-2, -1)))


def _plot_univariate_image(
    img: np.ndarray, spectrum: bool, dpi: int, cmap: str, colorbar: bool
) -> plt.Figure:
    """Draw one [height, width] image (and optionally its spectrum) in a column."""
    # Extra width leaves room for the color bar beside each panel.
    width = 5.75 if colorbar else 5
    if spectrum:
        fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(width, 10), dpi=dpi)
        # Spatial domain on top, frequency domain below
        panels = [(axes[0], img), (axes[1], _magnitude_spectrum(img))]
    else:
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(width, 5), dpi=dpi)
        panels = [(axes, img)]

    for ax, data in panels:
        mappable = ax.imshow(data, cmap=cmap)
        ax.set_aspect("equal")
        if colorbar:
            fig.colorbar(mappable, ax=ax, orientation="vertical", fraction=0.05)
    return fig


def _plot_multivariate_image(
    img: np.ndarray, spectrum: bool, dpi: int, cmap: str, colorbar: bool
) -> plt.Figure:
    """Draw each channel of an [n_vars, height, width] image in its own column."""
    n_vars = img.shape[0]
    # Row 0 holds the spatial-domain channels; the optional row 1 holds their spectra.
    rows = [img]
    if spectrum:
        rows.append(_magnitude_spectrum(img))

    # Extra width per column leaves room for the shared color bar.
    width = 5.2 if colorbar else 5
    # squeeze=False keeps `axes` two-dimensional even for a single channel or
    # a single row, so `axes[r, c]` is always valid (n_vars == 1 used to crash).
    fig, axes = plt.subplots(
        nrows=len(rows),
        ncols=n_vars,
        figsize=(width * n_vars, 5 * len(rows)),
        dpi=dpi,
        squeeze=False,
    )

    for r, data in enumerate(rows):
        images = [axes[r, c].imshow(data[c], cmap=cmap) for c in range(n_vars)]
        if colorbar:
            # One color bar serves the whole row. Give every panel the union
            # of their autoscaled limits; otherwise the bar would only describe
            # the last panel drawn.
            lows, highs = zip(*(image.get_clim() for image in images))
            for image in images:
                image.set_clim(min(lows), max(highs))
            fig.colorbar(
                images[-1],
                ax=list(axes[r]),
                orientation="vertical",
                fraction=0.05,
            )
    return fig


def _resolve_save_name(save_name: Optional[str], ndim: int) -> str:
    """Return the output file name, applying the default name and '.jpg' suffix."""
    if save_name is None:
        return "plot_image.jpg" if ndim == 2 else "plot_images.jpg"
    # Check the suffix rather than a substring, so "a.png.bak" is not taken as a PNG.
    if save_name.lower().endswith((".jpg", ".pdf", ".png", ".bmp")):
        return save_name
    return save_name + ".jpg"
def plot_grayscale_image(
    img: np.ndarray,
    figsize: Optional[Tuple] = (5, 5),
    dpi: Optional[int] = 100,
    cmap: Optional[str] = "coolwarm",
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw a single 2D grayscale image on a newly created figure.

    :param img: The 2D numpy array to display.
    :param figsize: Figure size in inches as (width, height), default (5, 5).
    :param dpi: Figure resolution in dots per inch, default 100.
    :param cmap: Name of the colormap handed to ``imshow``, default "coolwarm".
    :return: The (Figure, Axes) pair from matplotlib holding the drawing.
    """
    figure, axes = plt.subplots(figsize=figsize, dpi=dpi)

    # Render the matrix, then force square pixels
    axes.imshow(img, cmap=cmap)
    axes.set_aspect("equal")

    return figure, axes
def plot_grayscale_spectrum(
    img: np.ndarray,
    figsize: Optional[Tuple] = (5, 5),
    dpi: Optional[int] = 100,
    cmap: Optional[str] = "coolwarm",
    log_scale: Optional[bool] = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot the centered magnitude spectrum of a 2D grayscale image.

    The plotted quantity is ``|fftshift(fft2(img))|``, with the
    zero-frequency component moved to the center of the image.

    :param img: The input 2D ndarray matrix from numpy.
    :param figsize: The size of the figure, default (5, 5).
    :param dpi: The resolution used, default is 100.
    :param cmap: The colormap used, default "coolwarm".
    :param log_scale: If true, display ``log1p`` of the magnitude. This keeps
        weak components visible next to a large zero-frequency term. Defaults
        to False, which shows the linear magnitude as before.
    :return: Figure and Axes from matplotlib.
    """
    # Create the figure object
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Magnitude (not power) spectrum |F| of the 2D FFT, with the DC term centered
    spectrum = np.abs(fft.fftshift(fft.fft2(img)))
    if log_scale:
        # log1p is defined at 0, so zero-magnitude bins stay finite
        spectrum = np.log1p(spectrum)

    ax.imshow(spectrum, cmap=cmap)  # Visualize the spectrum
    ax.set_aspect("equal")
    return fig, ax
if __name__ == "__main__":
    from pysdkit.data import test_univariate_image, test_multivariate_image

    # Demo: each input type with and without the spectrum row, always with
    # color bars. Uses the module-level `plt` import (no need to re-import).
    for demo_image in (
        test_univariate_image(),
        test_multivariate_image(case=(5, 6, 7)),
    ):
        for show_spectrum in (True, False):
            plot_images(demo_image, spectrum=show_spectrum, colorbar=True)
            plt.show()