Visualization of Heatmaps

Module

prosper_nn.utils.visualization.plot_heatmap(heatmap_matrix: Tensor, xlabel: str | None = None, ylabel: str | None = None, xticks: dict = {}, yticks: dict = {}, title: str | None = None, save_at: str | None = None, center: float | None = None, cbar_kws: dict = {}, grid: dict = {'b': False}, vmin: float | None = None, vmax: float | None = None, annot: array | None = None, mask: array | None = None, fmt: str | None = '', figsize: tuple | None = None, square: bool = False)

Plots a 2-D matrix as a heatmap.

Parameters:
  • heatmap_matrix (torch.Tensor) – The 2-D matrix to plot. Each entry is drawn as one cell whose color encodes its value.

  • xlabel (Optional[str]) – Set the label for the x-axis.

  • ylabel (Optional[str]) – Set the label for the y-axis.

  • xticks (dict) – Set the current tick locations and labels of the x-axis.

  • yticks (dict) – Set the current tick locations and labels of the y-axis.

  • title (Optional[str]) – Set a title for the axes.

  • save_at (Optional[str]) – If given, the current figure is saved to this path.

  • center (Optional[float]) – The value at which to center the colormap when plotting divergent data. Using this parameter changes the default colormap if none is specified.

  • cbar_kws (dict) – Keyword arguments for matplotlib.figure.Figure.colorbar().

  • grid (dict) – Configure the grid lines.

  • vmin (Optional[float]) – Lower value to anchor the colormap; otherwise it is inferred from the data and other keyword arguments.

  • vmax (Optional[float]) – Upper value to anchor the colormap; otherwise it is inferred from the data and other keyword arguments.

  • annot (Optional[np.array]) – If True, write the data value in each cell. If an array-like with the same shape as data, then use this to annotate the heatmap instead of the data. Note that DataFrames will match on position, not index.

  • mask (Optional[np.array]) – If passed, data will not be shown in cells where mask is True. Cells with missing values are automatically masked.

  • fmt (Optional[str]) – String formatting code to use when adding annotations.

  • figsize (Optional[tuple]) – Width and height in inches. If not provided, defaults to rcParams[“figure.figsize”] = [6.4, 4.8].

  • square (bool) – If True, set the Axes aspect to “equal” so each cell will be square-shaped.
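The fmt code decides how each annotation is printed. Assuming it is applied like Python's built-in format(value, fmt) (an assumption; this section only calls it a "string formatting code"), the hypothetical helper preview_annotations below previews the labels that a few codes would produce. With the default fmt='', each value appears in its plain str() form, which can be long for floats.

```python
def preview_annotations(values, fmt):
    """Return the annotation text each value would get under the format code `fmt`.

    Hypothetical helper: assumes annotations are rendered as format(value, fmt).
    """
    return [format(value, fmt) for value in values]


cell_values = [0.123456, -1.5, 42.0]
for code in ["", ".2f", ".1e"]:
    print(repr(code), preview_annotations(cell_values, code))
# '' ['0.123456', '-1.5', '42.0']
# '.2f' ['0.12', '-1.50', '42.00']
# '.1e' ['1.2e-01', '-1.5e+00', '4.2e+01']
```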

Return type:

None

Example

import torch
from prosper_nn.utils.visualization import plot_heatmap

X = torch.rand([10, 100])
plot_heatmap(X,
             xlabel={'xlabel': 'xlabel'},
             ylabel={'ylabel': 'ylabel'},
             title={'label': 'Title'})
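For a matrix with both positive and negative entries, such as a correlation matrix, several of the parameters above are typically set together: center, a symmetric vmin/vmax, annot with fmt, and optionally mask. The sketch below collects them in one keyword dictionary. signed_heatmap_kwargs is a hypothetical helper, not part of prosper_nn; it relies only on NumPy, and its keys are taken from the plot_heatmap signature.

```python
import numpy as np


def signed_heatmap_kwargs(matrix, decimals=2, hide_upper=False):
    """Build keyword arguments for plot_heatmap for data with both signs.

    Hypothetical helper (not part of prosper_nn); the keys are parameter names
    from the plot_heatmap signature.
    """
    data = np.asarray(matrix, dtype=float)
    # Largest magnitude (NaNs ignored), so the color range is symmetric around zero.
    limit = float(np.nanmax(np.abs(data)))
    kwargs = {
        "center": 0.0,           # center the colormap at zero for divergent data
        "vmin": -limit,          # equal magnitudes get equally strong colors
        "vmax": limit,
        "annot": data,           # same shape as the data: used as cell labels
        "fmt": f".{decimals}f",  # fixed number of decimals per label
        "square": True,          # square-shaped cells
    }
    if hide_upper:
        # Cells where the mask is True are hidden; k=1 keeps the diagonal visible.
        kwargs["mask"] = np.triu(np.ones(data.shape, dtype=bool), k=1)
    return kwargs
```

As a usage sketch, corr = np.corrcoef(np.random.rand(5, 50)) gives a 5×5 correlation matrix, and plot_heatmap(torch.from_numpy(corr), **signed_heatmap_kwargs(corr, hide_upper=True)) would then show only its lower triangle and diagonal, with colors symmetric around zero.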