Visualization of Heatmaps
Module
- prosper_nn.utils.visualization.plot_heatmap(heatmap_matrix: Tensor, xlabel: str | None = None, ylabel: str | None = None, xticks: dict = {}, yticks: dict = {}, title: str | None = None, save_at: str | None = None, center: float | None = None, cbar_kws: dict = {}, grid: dict = {'b': False}, vmin: float | None = None, vmax: float | None = None, annot: array | None = None, mask: array | None = None, fmt: str | None = '', figsize: tuple | None = None, square: bool = False)
Plots a heatmap of a two-dimensional matrix.
- Parameters:
heatmap_matrix (torch.Tensor) – A 2D matrix. Each entry is drawn as one cell whose color encodes the entry's value.
xlabel (Optional[str]) – Set the label for the x-axis.
ylabel (Optional[str]) – Set the label for the y-axis.
xticks (dict) – Set the current tick locations and labels of the x-axis.
yticks (dict) – Set the current tick locations and labels of the y-axis.
title (Optional[str]) – Set a title for the axes.
save_at (Optional[str]) – If given, the path at which the current figure is saved.
center (Optional[float]) – The value at which to center the colormap when plotting divergent data. Using this parameter will change the default cmap if none is specified.
cbar_kws (dict) – Keyword arguments for matplotlib.figure.Figure.colorbar().
grid (dict) – Configure the grid lines.
vmin (Optional[float]) – Lower value at which to anchor the colormap; otherwise it is inferred from the data and other keyword arguments.
vmax (Optional[float]) – Upper value at which to anchor the colormap; otherwise it is inferred from the data and other keyword arguments.
annot (Optional[np.array]) – If True, write the data value in each cell. If an array-like with the same shape as the data, use it to annotate the heatmap instead of the data values. Note that DataFrames are matched on position, not index.
mask (Optional[np.array]) – If passed, data will not be shown in cells where mask is True. Cells with missing values are automatically masked.
fmt (Optional[str]) – String formatting code to use when adding annotations.
figsize (Optional[tuple]) – Figure (width, height) in inches. If not provided, defaults to rcParams[“figure.figsize”] = [6.4, 4.8].
square (bool) – If True, set the Axes aspect to “equal” so each cell will be square-shaped.
- Return type:
None
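The mask and annot parameters expect arrays with the same shape as heatmap_matrix. A minimal sketch of how such arrays might be prepared with NumPy, assuming they are forwarded unchanged to the underlying heatmap call as the parameter descriptions suggest. The helpers upper_triangle_mask and string_annotations are hypothetical and not part of prosper_nn. String labels are paired with fmt='' (the default here), because a numeric format code such as '.2f' cannot be applied to a string.

```python
import numpy as np

def upper_triangle_mask(n: int) -> np.ndarray:
    """Hypothetical helper: True strictly above the diagonal, i.e. the cells to hide."""
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def string_annotations(matrix: np.ndarray, decimals: int = 2) -> np.ndarray:
    """Hypothetical helper: one pre-formatted label per cell, same shape as `matrix`."""
    return np.array([[f"{v:.{decimals}f}" for v in row] for row in matrix])

# Correlation matrix of 5 random series. It is symmetric, so the upper
# triangle repeats the lower one and can be masked out.
rng = np.random.default_rng(0)
corr = np.corrcoef(rng.normal(size=(5, 50)))
mask = upper_triangle_mask(corr.shape[0])
annot = string_annotations(corr)
```

These arrays could then be passed as plot_heatmap(torch.from_numpy(corr), mask=mask, annot=annot, fmt='', square=True).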
Example
import torch
from prosper_nn.utils.visualization import plot_heatmap

# 10 x 100 matrix of random values in [0, 1)
X = torch.rand([10, 100])
plot_heatmap(X,
             xlabel='xlabel',
             ylabel='ylabel',
             title='Title')
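The xticks/yticks dictionaries and the color-limit arguments are not covered by the example above. The sketch below shows one way to prepare them. It assumes, based on the parameter descriptions, that the tick dicts are unpacked into matplotlib.pyplot.xticks/yticks (which accept ticks and labels) and that cells are drawn with integer boundaries, as seaborn.heatmap does, so cell centers sit at index + 0.5. The helpers cell_ticks and symmetric_limits are hypothetical and not part of prosper_nn.

```python
import numpy as np

def cell_ticks(labels, step=1):
    """Hypothetical helper: tick dict labelling every `step`-th cell.

    Assumes the dict is unpacked into matplotlib.pyplot.xticks/yticks and
    that cell i spans [i, i + 1], so its center is at i + 0.5.
    """
    idx = np.arange(0, len(labels), step)
    return {'ticks': (idx + 0.5).tolist(), 'labels': [labels[i] for i in idx]}

def symmetric_limits(matrix, center):
    """Hypothetical helper: vmin/vmax equidistant from `center`.

    Keeps the color scale symmetric about `center` for divergent data.
    """
    half_range = float(np.max(np.abs(np.asarray(matrix) - center)))
    return center - half_range, center + half_range

# Same shape as the example above: 10 rows, 100 columns in [0, 1).
X = np.random.default_rng(0).random((10, 100))
xticks = cell_ticks([f"t{i}" for i in range(100)], step=10)
vmin, vmax = symmetric_limits(X, center=0.5)
```

Under these assumptions the call would be plot_heatmap(torch.from_numpy(X), xticks=xticks, center=0.5, vmin=vmin, vmax=vmax).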