Coverage for src/uqtils/plots.py: 78%
199 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 03:45 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 03:45 +0000
"""Module for plotting utilities.

Includes:

- `ax_default` - Nice default plt formatting for x-y data
- `plot_slice` - Plots a grid of 1d slices of a multivariate function
- `ndscatter` - Plots a grid of 1d and 2d marginals in a "corner plot" for n-dimensional data (especially for MCMC)
"""
9from typing import Literal
11import matplotlib
12import matplotlib.pyplot as plt
13import numpy as np
14import scipy.stats as st
15from matplotlib.colors import LinearSegmentedColormap, ListedColormap
16from matplotlib.pyplot import cycler
17from matplotlib.ticker import AutoLocator, FuncFormatter, StrMethodFormatter
19from .mcmc import normal_sample
20from .uq_types import Array
22__all__ = ['ax_default', 'plot_slice', 'ndscatter']
def ax_default(ax: plt.Axes, xlabel='', ylabel='', legend=None, cmap='tab10'):
    """Nice default plt formatting for plotting X-Y data.

    :param ax: the axes to apply these settings to
    :param xlabel: the xlabel to set for `ax`
    :param ylabel: the ylabel to set for `ax`
    :param legend: will display a legend if bool(legend) is truthy, can pass a dict of legend kwargs here (optional).
                   Keys in the dict override the defaults; the dict itself is not modified. Note that an empty
                   dict is falsy and therefore does not show a legend.
    :param cmap: colormap to use for cycling
    :returns: the object returned by `ax.legend` if a legend was requested, otherwise `None`
    """
    default_leg = {'fancybox': True, 'facecolor': 'white', 'framealpha': 1, 'loc': 'best', 'edgecolor': 'k'}
    # Merge into a fresh dict so the caller's `legend` dict is never mutated; user-supplied keys take priority.
    leg_use = {**default_leg, **legend} if isinstance(legend, dict) else default_leg

    ax.set_prop_cycle(_get_cycle(cmap))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='both', which='both', direction='in')
    if legend:
        return ax.legend(**leg_use)
    return None
def _get_cycle(cmap: str | matplotlib.colors.Colormap, num_colors: int = None):
    """Build a `"color"` property cycler from a colormap.

    Colors are either picked by integer index (wrapping around the colormap length) or sampled at evenly
    spaced positions in `[0, 1]`, depending on the kind of colormap.

    :param cmap: a string specifier of a matplotlib colormap (or a colormap instance)
    :param num_colors: the number of colors to cycle through, uses `cmap.N` if not specified
    :returns: a `cycler` over the `"color"` property
    """
    indexed_names = ('Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2', 'Set1', 'Set2', 'Set3',
                     'tab10', 'tab20', 'tab20b', 'tab20c')

    by_index = False
    if isinstance(cmap, str):
        by_index = cmap in indexed_names
        cmap = plt.get_cmap(cmap)

    if num_colors is None:
        num_colors = cmap.N

    # Large or segmented colormaps are sampled over [0, 1]; remaining listed colormaps are indexed directly.
    if cmap.N > 100 or isinstance(cmap, LinearSegmentedColormap):
        by_index = False
    elif isinstance(cmap, ListedColormap):
        by_index = True

    if by_index:
        positions = np.arange(int(num_colors)) % cmap.N
    else:
        positions = np.linspace(0, 1, num_colors)
    return cycler("color", cmap(positions))
def plot_slice(funs, bds: list[tuple], x0: Array = None, x_idx: list[int] = None,
               y_idx: list[int] = None, N: int = 50, random_walk: bool = False, xlabels: list[str] = None,
               ylabels: list[str] = None, cmap='viridis', fun_labels=None):
    """Plot 1d slices of one or more functions over their inputs.

    :param funs: function callable as `y=f(x)`, with `x` as `(..., xdim)` and `y` as `(..., ydim)`, can also be a list
                 of functions to evaluate and plot together.
    :param bds: list of tuples of `(min, max)` specifying the bounds of the inputs
    :param x0: the default values for all inputs; defaults to middle of `bds`
    :param x_idx: list of input indices to take 1d slices of; defaults to the first (up to) three inputs
    :param y_idx: list of output indices to plot 1d slices of; defaults to `[0]`
    :param N: the number of points to take in each 1d slice
    :param random_walk: whether to slice in a random d-dimensional direction or hold all params const while slicing
    :param xlabels: list of labels for the inputs
    :param ylabels: list of labels for the outputs
    :param cmap: the name of the matplotlib colormap to use
    :param fun_labels: the legend labels if plotting multiple functions on each plot
    :returns: `fig, ax` with `num_inputs` by `num_outputs` subplots
    """
    # Fill in defaults for everything the caller left unspecified
    if not isinstance(funs, list):
        funs = [funs]
    if x_idx is None:
        x_idx = list(np.arange(0, min(3, len(bds))))
    if y_idx is None:
        y_idx = [0]
    num_x, num_y = len(x_idx), len(y_idx)
    if xlabels is None:
        xlabels = [f'x{i}' for i in range(num_x)]
    if ylabels is None:
        ylabels = [f'QoI {i}' for i in range(num_y)]
    if fun_labels is None:
        fun_labels = [f'fun {i}' for i in range(len(funs))]
    if x0 is None:
        x0 = [(b[0] + b[1]) / 2 for b in bds]
    x0 = np.atleast_1d(x0)
    xdim = x0.shape[0]
    lb = np.atleast_1d([b[0] for b in bds])
    ub = np.atleast_1d([b[1] for b in bds])
    cmap = plt.get_cmap(cmap)

    # Build the sliced inputs: one (N, xdim) path per requested input index
    xs = np.zeros((N, num_x, xdim))
    for col, idx in enumerate(x_idx):
        if random_walk:
            # Straight line between two random points, pinned to the bounds of the sliced input
            start = np.random.rand(xdim) * (ub - lb) + lb
            start[idx] = lb[idx]
            stop = np.random.rand(xdim) * (ub - lb) + lb
            stop[idx] = ub[idx]
            xs[0, col, :] = start
            for k in range(1, N):
                xs[k, col, :] = xs[k - 1, col, :] + (stop - start) / (N - 1)
        else:
            # Sweep only the sliced input; all others stay at their default value
            for j in range(xdim):
                xs[:, col, j] = np.linspace(lb[idx], ub[idx], N) if j == idx else x0[j]

    # Evaluate every function on the slices, adding a trailing output axis if the function returned none
    ys = []
    for func in funs:
        out = func(xs)
        ys.append(out[..., np.newaxis] if out.shape == (N, num_x) else out)
    c_intervals = np.linspace(0, 1, len(ys))

    # One row of subplots per output and one column per sliced input
    fig, axs = plt.subplots(num_y, num_x, sharex='col', sharey='row')
    for i in range(num_y):
        for j in range(num_x):
            # `plt.subplots` squeezes singleton dimensions, so pick the axes accordingly
            if num_y == 1 and num_x == 1:
                ax = axs
            elif num_y == 1:
                ax = axs[j]
            elif num_x == 1:
                ax = axs[i]
            else:
                ax = axs[i, j]

            x = xs[:, j, x_idx[j]]
            for k, y_all in enumerate(ys):
                ax.plot(x, y_all[:, j, y_idx[i]], ls='-', color=cmap(c_intervals[k]), label=fun_labels[k])

            # Label only the outer edge of the grid; legend only in the top-right subplot
            ylabel = ylabels[i] if j == 0 else ''
            xlabel = xlabels[j] if i == num_y - 1 else ''
            legend = (i == 0 and j == num_x - 1 and len(ys) > 1)
            ax_default(ax, xlabel, ylabel, legend=legend)

    fig.set_size_inches(3 * num_x, 3 * num_y)
    fig.tight_layout()

    return fig, axs
def _default_tick_format(value, pos):
    """Format a tick value with a precision chosen from its magnitude (callback for `FuncFormatter`)."""
    magnitude = abs(value)
    if np.isclose(value, 0):
        return f'{value:.2f}'
    if magnitude > 1000:
        return f'{value:.2E}'
    if magnitude > 100:
        return f'{int(value):d}'
    if magnitude > 1:
        return f'{value:.2f}'
    if magnitude >= 0.01:
        return f'{value:.4f}'
    # Everything smaller (and anything that failed the comparisons above) gets scientific notation,
    # so the formatter always returns a string.
    return f'{value:.2E}'


def _kde_on_grid(points, x_lim, y_lim, num=40):
    """Evaluate a Gaussian kde fitted to `points` (shape `(2, N)`) on a `num` by `num` grid over the given limits."""
    kernel = st.gaussian_kde(points)
    xg, yg = np.meshgrid(np.linspace(x_lim[0], x_lim[1], num), np.linspace(y_lim[0], y_lim[1], num))
    zg = np.reshape(kernel(np.vstack([xg.ravel(), yg.ravel()])), xg.shape)
    return xg, yg, zg


def ndscatter(samples: np.ndarray, labels: list[str] = None, tick_fmts: list[str] = None,
              plot1d: Literal['kde', 'hist'] = None, plot2d: Literal['scatter', 'kde', 'hist', 'hex'] = 'scatter',
              cmap='viridis', bins=20, cmin=0, z: np.ndarray = None, cb_label=None, cb_norm='linear',
              subplot_size=3, cov_overlay=None):
    """Triangle scatter plots of n-dimensional samples.

    !!! Warning
        Best for `dim < 10`. You can shrink the `subplot_size` to assist graphics loading time.

    :param samples: `(N, dim)` samples to plot
    :param labels: list of axis labels of length `dim`
    :param tick_fmts: list of str.format() specifiers for ticks, e.g `['{x: ^10.2f}', ...]`, of length `dim`
    :param plot1d: 'hist' or 'kde' for 1d marginals, defaults to plot2d if None (anything other than 'kde' gives
                   a histogram)
    :param plot2d: 'hist' for 2d hist plot, 'kde' for kernel density estimation, 'hex', or 'scatter' (default)
    :param cmap: the matplotlib string specifier of a colormap
    :param bins: number of bins in each dimension for histogram marginals
    :param cmin: the minimum bin count below which the bins are not displayed
    :param z: `(N,)` a performance metric corresponding to `samples`, used to color code the scatter plot if provided
    :param cb_label: label for color bar (if `z` is provided)
    :param cb_norm: `str` or `plt.colors.Normalize`, normalization method for plotting `z` on scatter plot
    :param subplot_size: size in inches of a single 2d marginal subplot
    :param cov_overlay: `(ndim, ndim)` a covariance matrix to overlay as a Gaussian kde over the samples
    :returns fig, axs: the `plt` Figure and Axes objects, (returns an additional `cb_fig, cb_ax` if `z` is specified,
                       `plot2d` is 'scatter', and at least one 2d scatter subplot was drawn)
    :raises NotImplementedError: if `plot2d` is not one of the supported plot types
    """
    # Reject unknown plot types before any figure is created
    if plot2d not in ('scatter', 'hist', 'kde', 'hex'):
        raise NotImplementedError('This plot type is not known. plot2d=["hist", "kde", "scatter", "hex"]')

    _, dim = samples.shape
    x_min = np.min(samples, axis=0)
    x_max = np.max(samples, axis=0)
    show_colorbar = z is not None
    if labels is None:
        labels = [f"x{i}" for i in range(dim)]
    if z is None:
        z = plt.get_cmap(cmap)([0])  # No metric given: color every point with the first colormap color
    if cb_label is None:
        cb_label = 'Performance metric'
    default_ticks = FuncFormatter(_default_tick_format)

    # Set up triangle plot formatting; squeeze=False keeps `axs` 2d even when dim == 1
    fig, axs = plt.subplots(dim, dim, sharex='col', sharey='row', squeeze=False)
    for i in range(dim):
        for j in range(dim):
            ax = axs[i, j]
            if i == j:  # 1d marginals on diagonal
                # Detach the diagonal from the row's shared y-axis (relies on a private Matplotlib attribute)
                ax._shared_axes['y'].remove(ax)
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['left'].set_visible(False)
                if i == 0:
                    ax.get_yaxis().set_ticks([])
            if j > i:  # Clear the upper triangle
                ax.axis('off')
            if i == dim - 1:  # Bottom row
                ax.set_xlabel(labels[j])
                ax.xaxis.set_major_locator(AutoLocator())
                formatter = StrMethodFormatter(tick_fmts[j]) if tick_fmts is not None else default_ticks
                ax.xaxis.set_major_formatter(formatter)
            if j == 0 and i > 0:  # Left column
                ax.set_ylabel(labels[i])
                ax.yaxis.set_major_locator(AutoLocator())
                formatter = StrMethodFormatter(tick_fmts[i]) if tick_fmts is not None else default_ticks
                ax.yaxis.set_major_formatter(formatter)

    if cov_overlay is not None:
        x_overlay = normal_sample(np.mean(samples, axis=0), cov_overlay, 5000)

    # Plot marginals
    sc = None  # Last scatter artist, needed for the colorbar
    marginal_color = plt.get_cmap(cmap)(0)
    marginal_kind = plot1d if plot1d is not None else plot2d
    for i in range(dim):
        y_lim = (x_min[i], x_max[i])

        # 2d marginals (lower triangle)
        for j in range(i):
            ax = axs[i, j]
            x_lim = (x_min[j], x_max[j])
            ax.set_xlim([x_lim[0], x_lim[1]])
            ax.set_ylim([y_lim[0], y_lim[1]])
            if plot2d == 'scatter':
                sc = ax.scatter(samples[:, j], samples[:, i], s=1.5, c=z, cmap=cmap, norm=cb_norm)
            elif plot2d == 'hist':
                ax.hist2d(samples[:, j], samples[:, i], bins=bins, cmap=cmap, cmin=cmin)
            elif plot2d == 'kde':
                xg, yg, zg = _kde_on_grid(samples[:, [j, i]].T, x_lim, y_lim)
                cs = ax.contourf(xg, yg, zg, 5, cmap=cmap, alpha=0.9, extend='both')
                cs.cmap.set_under('white')
                cs.changed()
                ax.contour(xg, yg, zg, 5, colors=[(0.5, 0.5, 0.5)], linewidths=1.2)
            else:  # 'hex' (plot2d was validated above)
                ax.hexbin(samples[:, j], samples[:, i], gridsize=bins, cmap=cmap, mincnt=cmin)

            if cov_overlay is not None:
                xg, yg, zg = _kde_on_grid(x_overlay[:, [j, i]].T, x_lim, y_lim)
                ax.contourf(xg, yg, zg, 4, cmap='Greys', alpha=0.4)
                ax.contour(xg, yg, zg, 4, colors='k', linewidths=1.5, alpha=0.6)

        # 1d marginal (on diagonal)
        ax = axs[i, i]
        if marginal_kind == 'kde':
            x = np.linspace(y_lim[0], y_lim[1], 500)
            density = st.gaussian_kde(samples[:, i])(x)
            ax.fill_between(x, y1=density, y2=0, lw=0, alpha=0.3, facecolor=marginal_color)
            ax.plot(x, density, ls='-', c=marginal_color, lw=1.5)
        else:
            ax.hist(samples[:, i], edgecolor='black', color=marginal_color, density=True, alpha=0.5,
                    linewidth=1.2, bins=bins)
        if cov_overlay is not None:
            x = np.linspace(y_lim[0], y_lim[1], 500)
            density = st.gaussian_kde(x_overlay[:, i])(x)
            ax.fill_between(x, y1=density, y2=0, lw=0, alpha=0.5, facecolor=[0.5, 0.5, 0.5])
            ax.plot(x, density, ls='-', c='k', lw=1.5, alpha=0.5)
        ax.set_ylim([0, ax.get_ylim()[1]])  # Densities start at zero

    fig.set_size_inches(subplot_size * dim, subplot_size * dim)
    fig.tight_layout()

    # Plot colorbar in standalone figure (only possible when a scatter artist exists, i.e. dim > 1)
    if show_colorbar and plot2d == 'scatter' and sc is not None:
        cb_fig, cb_ax = plt.subplots(figsize=(1.5, 6))
        cb_fig.subplots_adjust(right=0.7)
        cb_fig.colorbar(sc, cax=cb_ax, orientation='vertical', label=cb_label)
        cb_fig.tight_layout()
        return fig, axs, cb_fig, cb_ax

    return fig, axs