Skip to content

images

Image visualization tools.

cmat(arr: np.ndarray, labels: Iterable[str] | None = None, annot: bool = True, cmap: str = 'gist_heat_r', cbar: bool = False, fmt: str = '0.0%', dark_color: str = '#222222', light_color: str = '#dddddd', grid_color: str = cast(str, c.gray[9]), theta: float = 0.5, label_fontsize: float = 10.0, fontsize: float = 10.0, vmin: float = 0.0, vmax: float = 1.0, **kwargs: Any) -> tuple[AxesImage, Axes]

Plot a confusion matrix.

Source code in src/jetplot/images.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
@plotwrapper
def cmat(
    arr: np.ndarray,
    labels: Iterable[str] | None = None,
    annot: bool = True,
    cmap: str = "gist_heat_r",
    cbar: bool = False,
    fmt: str = "0.0%",
    dark_color: str = "#222222",
    light_color: str = "#dddddd",
    grid_color: str = cast(str, c.gray[9]),
    theta: float = 0.5,
    label_fontsize: float = 10.0,
    fontsize: float = 10.0,
    vmin: float = 0.0,
    vmax: float = 1.0,
    **kwargs: Any,
) -> tuple[AxesImage, Axes]:
    """Plot a confusion matrix.

    Args:
      arr: 2-D array of cell values (unpacked as ``num_rows, num_cols``).
      labels: Optional tick labels used for both the x and y axes. Any iterable
        is accepted; it is materialized once so it can label both axes.
      annot: If True, write each cell's formatted value at the cell center.
      cmap: Colormap name forwarded to ``imv``.
      cbar: Whether ``imv`` should draw a colorbar.
      fmt: Format spec applied to each annotated value (e.g. ``"0.0%"``).
      dark_color: Annotation text color for values ``<= theta``.
      light_color: Annotation text color for values ``> theta``.
      grid_color: Color of the grid lines drawn between cells.
      theta: Threshold that selects between ``dark_color`` and ``light_color``.
      label_fontsize: Font size of the tick labels.
      fontsize: Font size of the cell annotations.
      vmin: Lower color limit forwarded to ``imv``.
      vmax: Upper color limit forwarded to ``imv``.
      **kwargs: Must contain ``ax``, the axes to draw on (presumably injected
        by ``plotwrapper`` -- confirm); any other keys are ignored.

    Returns:
      ``(image, ax)``: the object returned by ``imv`` and the axes drawn on.
    """
    num_rows, num_cols = arr.shape

    ax = kwargs.pop("ax")
    cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)

    if annot:
        xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")

        for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True):  # pyrefly: ignore
            # Pick a text color that contrasts with the cell, split at theta.
            color = dark_color if value <= theta else light_color
            ax.text(
                x,
                y,
                format(value, fmt),
                ha="center",
                va="center",
                color=color,
                fontsize=fontsize,
            )

    if labels is not None:
        # Materialize once: a one-shot iterator would otherwise be exhausted
        # by the x-axis call and leave the y-axis without labels.
        tick_labels = list(labels)
        ax.set_xticks(np.arange(num_cols))
        ax.set_xticklabels(tick_labels, rotation=90, fontsize=label_fontsize)
        ax.set_yticks(np.arange(num_rows))
        ax.set_yticklabels(tick_labels, fontsize=label_fontsize)

    # Minor ticks sit on the cell boundaries (half a cell off the centers) so
    # the minor grid below outlines every cell.
    ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist()))
    ax.yaxis.set_minor_locator(FixedLocator((np.arange(num_rows) - 0.5).tolist()))

    ax.grid(
        visible=True,
        which="minor",
        axis="both",
        linewidth=1.0,
        color=grid_color,
        linestyle="-",
        alpha=1.0,
    )

    return cb, ax

fsurface(func: Callable[..., np.ndarray], xrng: tuple[float, float] | None = None, yrng: tuple[float, float] | None = None, n: int = 100, nargs: int = 2, **kwargs: Any) -> None

Plot a 2-D function as filled contours.

Source code in src/jetplot/images.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@plotwrapper
def fsurface(
    func: Callable[..., np.ndarray],
    xrng: tuple[float, float] | None = None,
    yrng: tuple[float, float] | None = None,
    n: int = 100,
    nargs: int = 2,
    **kwargs: Any,
) -> None:
    """Draw filled contours of a two-variable function over a rectangular grid.

    Args:
      func: Function evaluated on all grid points in a single call; its output
        must be reshapeable to the ``(n, n)`` grid shape.
      xrng: ``(low, high)`` sampling range along x; ``(-1, 1)`` when None.
      yrng: ``(low, high)`` sampling range along y; reuses ``xrng`` when None.
      n: Number of evenly spaced samples along each axis.
      nargs: Calling convention for ``func``. With 1 it receives one array of
        shape ``(2, n * n)`` whose rows are the x and y coordinates; with 2 it
        receives the flattened x and y coordinates as two separate arguments.
      **kwargs: Must provide ``ax``, the axes to draw on; other keys are unused.

    Raises:
      ValueError: If ``nargs`` is neither 1 nor 2.
    """
    if xrng is None:
        xrng = (-1, 1)
    if yrng is None:
        yrng = xrng

    grid_x, grid_y = np.meshgrid(
        np.linspace(xrng[0], xrng[1], n),
        np.linspace(yrng[0], yrng[1], n),
    )

    if nargs not in (1, 2):
        raise ValueError(f"Invalid value for nargs ({nargs})")

    flat_coords = (grid_x.ravel(), grid_y.ravel())
    call_args = (np.vstack(flat_coords),) if nargs == 1 else flat_coords

    surface = func(*call_args).reshape(grid_x.shape)
    kwargs["ax"].contourf(grid_x, grid_y, surface)

img(data: np.ndarray, mode: str = 'div', cmap: str | None = None, aspect: str = 'equal', vmin: float | None = None, vmax: float | None = None, cbar: bool = True, interpolation: str = 'none', **kwargs: Any) -> AxesImage

Visualize a matrix as an image.

Parameters:

Name Type Description Default
data

array_like, The array to visualize.

required
mode str

string, One of 'div' for a diverging image, 'seq' for sequential, 'cov' for covariance matrices, or 'corr' for correlation matrices (default: 'div').

'div'
cmap str | None

string, Colormap to use.

None
aspect str

string, Either 'equal' or 'auto'

'equal'
Source code in src/jetplot/images.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@plotwrapper
def img(
    data: np.ndarray,
    mode: str = "div",
    cmap: str | None = None,
    aspect: str = "equal",
    vmin: float | None = None,
    vmax: float | None = None,
    cbar: bool = True,
    interpolation: str = "none",
    **kwargs: Any,
) -> AxesImage:
    """Visualize a matrix as an image.

    Args:
      data: array_like, The array to visualize. Singleton dimensions are
        squeezed out before plotting.
      mode: string, One of 'div' for a diverging image, 'seq' for
        sequential, 'cov' for covariance matrices, or 'corr' for
        correlation matrices (default: 'div'). 'div' makes the default
        color limits symmetric about zero; 'seq' uses the data min and max.
        'cov' and 'corr' always use fixed limits ([0, 1] and [-1, 1]), a
        fixed colormap and a colorbar, overriding vmin, vmax, cmap and cbar.
      cmap: string, Colormap to use. When None, 'seismic' is used for 'div'
        and 'viridis' for 'seq'.
      aspect: string, Either 'equal' or 'auto'.
      vmin: Lower color limit; derived from the data when None ('div'/'seq').
      vmax: Upper color limit; derived from the data when None ('div'/'seq').
      cbar: Whether to draw a colorbar next to the axes.
      interpolation: Interpolation passed to ``imshow``.
      **kwargs: Must contain ``ax``, the axes to draw on (presumably injected
        by ``plotwrapper`` -- confirm).

    Returns:
      The ``AxesImage`` created by ``imshow``.

    Raises:
      ValueError: If ``mode`` is not one of the recognized values.
    """
    # Work with a copy of the input. asanyarray accepts plain sequences (the
    # documented array_like) while keeping ndarray subclasses such as masked
    # arrays intact.
    image = np.squeeze(np.asanyarray(data).copy())

    # image bounds
    img_min = np.min(image)
    img_max = np.max(image)
    abs_max = np.max(np.abs(image))

    if mode == "div":
        if vmin is None:
            vmin = -abs_max
        if vmax is None:
            vmax = abs_max
        if cmap is None:
            cmap = "seismic"
    elif mode == "seq":
        if vmin is None:
            vmin = img_min
        if vmax is None:
            vmax = img_max
        if cmap is None:
            cmap = "viridis"
    elif mode == "cov":
        vmin, vmax, cmap, cbar = 0, 1, "viridis", True
    elif mode == "corr":
        vmin, vmax, cmap, cbar = -1, 1, "seismic", True
    else:
        raise ValueError("Unrecognized mode: '" + mode + "'")

    # make the image
    ax = kwargs["ax"]
    im = ax.imshow(
        image, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect
    )

    # Attach the colorbar to the figure that owns the target axes rather than
    # pyplot's "current" figure, which may be a different (or a new) figure.
    if cbar:
        ax.figure.colorbar(im, ax=ax)

    # clear ticks
    noticks(ax=ax)

    return im