"""Module for plotting utilities.

Includes:

- `ax_default` - Nice default plt formatting for x-y data
- `plot_slice` - Plots a grid of 1d slices of a multivariate function
- `ndscatter` - Plots a grid of 1d and 2d marginals in a "corner plot" for n-dimensional data (especially for MCMC)
"""


9from typing import Literal 


11import matplotlib 

12import matplotlib.pyplot as plt 

13import numpy as np 

14import scipy.stats as st 

15from matplotlib.colors import LinearSegmentedColormap, ListedColormap 

16from matplotlib.pyplot import cycler 

17from matplotlib.ticker import AutoLocator, FuncFormatter, StrMethodFormatter 


19from .mcmc import normal_sample 

20from .uq_types import Array 


# Names exported by `from <module> import *`; these three functions form the module's public API
__all__ = ['ax_default', 'plot_slice', 'ndscatter']



def ax_default(ax: plt.Axes, xlabel='', ylabel='', legend=None, cmap='tab10'):
    """Nice default plt formatting for plotting X-Y data.

    Sets a color cycle from `cmap`, the axis labels, inward-facing ticks on both axes, and optionally a legend.

    :param ax: the axes to apply these settings to
    :param xlabel: the xlabel to set for `ax`
    :param ylabel: the ylabel to set for `ax`
    :param legend: will display a legend if bool(legend) is truthy; a dict is treated as legend kwargs that
                   override the defaults (the caller's dict is not modified)
    :param cmap: colormap (name or instance) used to build the color cycle of `ax`
    :returns: the legend created by `ax.legend` if a legend was requested, otherwise `None`
    """
    default_leg = {'fancybox': True, 'facecolor': 'white', 'framealpha': 1, 'loc': 'best', 'edgecolor': 'k'}
    # Merge into a fresh dict: caller-supplied keys win, missing keys fall back to the defaults, and the
    # caller's dict is never mutated in place
    leg_use = {**default_leg, **legend} if isinstance(legend, dict) else default_leg

    ax.set_prop_cycle(_get_cycle(cmap))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='both', which='both', direction='in')
    if legend:
        return ax.legend(**leg_use)
    return None



def _get_cycle(cmap: str | matplotlib.colors.Colormap, num_colors: int = None):
    """Build a matplotlib color cycler sampled from a colormap.

    Small listed colormaps (at most 100 entries) are read entry-by-entry, wrapping around when more colors
    are requested than the map holds. Colormaps with more than 100 entries and `LinearSegmentedColormap`s
    are instead sampled at evenly spaced points on [0, 1].

    :param cmap: a string specifier of a matplotlib colormap (or a colormap instance)
    :param num_colors: the number of colors to cycle through; defaults to the colormap's `N`
    :returns: a `cycler` over the "color" property
    """
    qualitative = {'Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2', 'Set1', 'Set2', 'Set3',
                   'tab10', 'tab20', 'tab20b', 'tab20c'}
    pick_by_index = False
    if isinstance(cmap, str):
        pick_by_index = cmap in qualitative
        cmap = plt.get_cmap(cmap)

    count = cmap.N if num_colors is None else num_colors

    # Large or continuous maps are sampled evenly; small listed maps are indexed directly
    if cmap.N > 100 or isinstance(cmap, LinearSegmentedColormap):
        pick_by_index = False
    elif isinstance(cmap, ListedColormap):
        pick_by_index = True

    if pick_by_index:
        lookup = np.arange(int(count)) % cmap.N
    else:
        lookup = np.linspace(0, 1, count)
    return cycler("color", cmap(lookup))



def plot_slice(funs, bds: list[tuple], x0: Array = None, x_idx: list[int] = None,
               y_idx: list[int] = None, N: int = 50, random_walk: bool = False, xlabels: list[str] = None,
               ylabels: list[str] = None, cmap='viridis', fun_labels=None):
    """Helper function to plot 1d slices of a function(s) over inputs.

    :param funs: function callable as `y=f(x)`, with `x` as `(..., xdim)` and `y` as `(..., ydim)`, can also be a
                 list (or tuple) of functions to evaluate and plot together.
    :param bds: list of tuples of `(min, max)` specifying the bounds of the inputs
    :param x0: the default values for all inputs; defaults to middle of `bds`
    :param x_idx: list of input indices to take 1d slices of; defaults to the first (up to) 3 inputs
    :param y_idx: list of output indices to plot 1d slices of; defaults to `[0]`
    :param N: the number of points to take in each 1d slice
    :param random_walk: whether to slice in a random d-dimensional direction or hold all params const while slicing
    :param xlabels: list of labels for the inputs in `x_idx`; defaults to `x{k}` with `k` the input index
    :param ylabels: list of labels for the outputs in `y_idx`; defaults to `QoI {k}` with `k` the output index
    :param cmap: the name of the matplotlib colormap to use
    :param fun_labels: the legend labels if plotting multiple functions on each plot
    :returns: `fig, axs` with `len(y_idx)` by `len(x_idx)` subplots (squeezed exactly as `plt.subplots` does)
    """
    funs = list(funs) if isinstance(funs, (list, tuple)) else [funs]
    x_idx = list(range(min(3, len(bds)))) if x_idx is None else x_idx
    y_idx = [0] if y_idx is None else y_idx
    # Default labels name the actual input/output index being sliced, not the subplot position
    xlabels = [f'x{k}' for k in x_idx] if xlabels is None else xlabels
    ylabels = [f'QoI {k}' for k in y_idx] if ylabels is None else ylabels
    fun_labels = [f'fun {k}' for k in range(len(funs))] if fun_labels is None else fun_labels
    x0 = np.atleast_1d([(b[0] + b[1]) / 2 for b in bds] if x0 is None else x0)
    xdim = x0.shape[0]
    lb = np.atleast_1d([b[0] for b in bds])
    ub = np.atleast_1d([b[1] for b in bds])
    cmap = plt.get_cmap(cmap)
    nx, ny = len(x_idx), len(y_idx)

    # Construct sliced inputs: xs[:, i, :] is the N-point path used for the i-th sliced input
    xs = np.zeros((N, nx, xdim))
    for i, k in enumerate(x_idx):
        if random_walk:
            # Random straight-line walk across the d-cube that runs input k from its lower to its upper bound
            r0 = np.random.rand(xdim) * (ub - lb) + lb
            r0[k] = lb[k]
            rf = np.random.rand(xdim) * (ub - lb) + lb
            rf[k] = ub[k]
            t = np.linspace(0, 1, N)[:, np.newaxis]
            xs[:, i, :] = r0 + t * (rf - r0)
        else:
            # Hold every input at x0 and sweep only input k across its bounds
            xs[:, i, :] = x0
            xs[:, i, k] = np.linspace(lb[k], ub[k], N)

    # Compute function values; scalar-output functions get a trailing ydim axis of size 1
    ys = []
    for func in funs:
        y = func(xs)
        if y.shape == (N, nx):
            y = y[..., np.newaxis]
        ys.append(y)
    c_intervals = np.linspace(0, 1, len(ys))

    fig, axs = plt.subplots(ny, nx, sharex='col', sharey='row')
    # plt.subplots squeezes singleton dimensions; view the result as a 2d grid for uniform indexing
    ax_grid = np.asarray(axs, dtype=object).reshape(ny, nx)
    for i, out in enumerate(y_idx):
        for j, inp in enumerate(x_idx):
            ax = ax_grid[i, j]
            x = xs[:, j, inp]
            for k, y_all in enumerate(ys):
                ax.plot(x, y_all[:, j, out], ls='-', color=cmap(c_intervals[k]), label=fun_labels[k])
            ylabel = ylabels[i] if j == 0 else ''
            xlabel = xlabels[j] if i == ny - 1 else ''
            # Only the top-right subplot shows a legend, and only when comparing several functions
            legend = i == 0 and j == nx - 1 and len(ys) > 1
            ax_default(ax, xlabel, ylabel, legend=legend)
    fig.set_size_inches(3 * nx, 3 * ny)
    fig.tight_layout()

    return fig, axs



def _format_tick(value, pos):
    """Default tick-label formatter used by `ndscatter` when no `tick_fmts` are given.

    Picks a precision from the magnitude of `value`: scientific notation above 1000 or below 0.01, integers
    above 100, and fixed-point otherwise.

    :param value: the tick value
    :param pos: the tick position (unused; required by the `FuncFormatter` call signature)
    :returns: the tick label string
    """
    magnitude = abs(value)
    if np.isclose(value, 0):
        return f'{value:.2f}'
    if magnitude > 1000:
        return f'{value:.2E}'
    if magnitude > 100:
        return f'{int(value):d}'
    if magnitude > 1:
        return f'{value:.2f}'
    if magnitude >= 0.01:
        return f'{value:.4f}'
    # Catch-all so every value (including exactly +/-0.01 boundaries above) yields a string label
    return f'{value:.2E}'


def _kde_1d(data, lo, hi, num=500):
    """Evaluate a 1d Gaussian KDE of `data` on `num` evenly spaced points over `[lo, hi]`.

    :returns: `x, density`, each of shape `(num,)`
    """
    kernel = st.gaussian_kde(data)
    x = np.linspace(lo, hi, num)
    return x, kernel(x)


def _kde_2d(data, x_bds, y_bds, num=40):
    """Evaluate a 2d Gaussian KDE of `(M, 2)` points `data` on a `num x num` grid spanning `x_bds` by `y_bds`.

    :returns: `xg, yg, zg` meshgrid arrays, each of shape `(num, num)`
    """
    kernel = st.gaussian_kde(data.T)
    xg, yg = np.meshgrid(np.linspace(*x_bds, num), np.linspace(*y_bds, num))
    zg = np.reshape(kernel(np.vstack([xg.ravel(), yg.ravel()])), xg.shape)
    return xg, yg, zg


def ndscatter(samples: np.ndarray, labels: list[str] = None, tick_fmts: list[str] = None,
              plot1d: Literal['kde', 'hist'] = None, plot2d: Literal['scatter', 'kde', 'hist', 'hex'] = 'scatter',
              cmap='viridis', bins=20, cmin=0, z: np.ndarray = None, cb_label=None, cb_norm='linear',
              subplot_size=3, cov_overlay=None):
    """Triangle scatter plots of n-dimensional samples.

    !!! Warning
        Best for `dim < 10`. You can shrink the `subplot_size` to assist graphics loading time.

    :param samples: `(N, dim)` samples to plot
    :param labels: list of axis labels of length `dim`
    :param tick_fmts: list of str.format() specifiers for ticks, e.g `['{x: ^10.2f}', ...]`, of length `dim`
    :param plot1d: 'hist' or 'kde' for 1d marginals, defaults to plot2d if None (anything but 'kde' draws a histogram)
    :param plot2d: 'hist' for 2d hist plot, 'kde' for kernel density estimation, 'hex', or 'scatter' (default)
    :param cmap: the matplotlib string specifier of a colormap
    :param bins: number of bins in each dimension for histogram marginals (also the hexbin grid size)
    :param cmin: the minimum bin count below which the bins are not displayed
    :param z: `(N,)` a performance metric corresponding to `samples`, used to color code the scatter plot if provided
    :param cb_label: label for color bar (if `z` is provided)
    :param cb_norm: `str` or `plt.colors.Normalize`, normalization method for plotting `z` on scatter plot
    :param subplot_size: size in inches of a single 2d marginal subplot
    :param cov_overlay: `(ndim, ndim)` a covariance matrix to overlay as a Gaussian kde over the samples
    :raises NotImplementedError: if `plot2d` is not one of 'scatter', 'hist', 'kde', 'hex'
    :returns fig, axs: the `plt` Figure and the `(dim, dim)` array of Axes; an additional `cb_fig, cb_ax` is
                       returned when `z` is given, `plot2d='scatter'` and there is at least one 2d marginal (dim > 1)
    """
    # Validate up front so an unknown plot type fails before any figure is created
    if plot2d not in ('scatter', 'hist', 'kde', 'hex'):
        raise NotImplementedError('This plot type is not known. plot2d=["hist", "kde", "scatter", "hex"]')

    _, dim = samples.shape
    x_min = np.min(samples, axis=0)
    x_max = np.max(samples, axis=0)
    show_colorbar = z is not None
    if labels is None:
        labels = [f"x{i}" for i in range(dim)]
    if z is None:
        z = plt.get_cmap(cmap)([0])
    if cb_label is None:
        cb_label = 'Performance metric'
    plot1d = plot1d if plot1d is not None else plot2d
    diag_color = plt.get_cmap(cmap)(0)

    def _make_formatter(k):
        """Tick formatter for dimension `k`: the user's `tick_fmts[k]` if given, else `_format_tick`."""
        # A fresh formatter per axis: matplotlib formatters should not be shared between axes
        return StrMethodFormatter(tick_fmts[k]) if tick_fmts is not None else FuncFormatter(_format_tick)

    # Set up triangle plot formatting; squeeze=False keeps `axs` a 2d array even when dim == 1
    fig, axs = plt.subplots(dim, dim, sharex='col', sharey='row', squeeze=False)
    for i in range(dim):
        for j in range(dim):
            ax = axs[i, j]
            if i == j:  # 1d marginals on diagonal
                # Un-share the diagonal's y-axis from its row so the 1d densities get their own y-range.
                # NOTE(review): `_shared_axes` is private matplotlib API (the public shared-axes view is
                # read-only); re-check this on matplotlib upgrades.
                ax._shared_axes['y'].remove(ax)
                for side in ('top', 'right', 'left'):
                    ax.spines[side].set_visible(False)
                if i == 0:
                    ax.get_yaxis().set_ticks([])
            if j > i:  # Clear the upper triangle
                ax.axis('off')
            if i == dim - 1:  # Bottom row
                ax.set_xlabel(labels[j])
                ax.xaxis.set_major_locator(AutoLocator())
                ax.xaxis.set_major_formatter(_make_formatter(j))
            if j == 0 and i > 0:  # Left column
                ax.set_ylabel(labels[i])
                ax.yaxis.set_major_locator(AutoLocator())
                ax.yaxis.set_major_formatter(_make_formatter(i))

    if cov_overlay is not None:
        x_overlay = normal_sample(np.mean(samples, axis=0), cov_overlay, 5000)

    # Plot marginals: only the diagonal (j == i) and lower triangle (j < i) are drawn
    sc = None
    for i in range(dim):
        for j in range(i + 1):
            ax = axs[i, j]
            if i == j:  # 1d marginals (on diagonal)
                if plot1d == 'kde':
                    x, pdf = _kde_1d(samples[:, i], x_min[i], x_max[i])
                    ax.fill_between(x, y1=pdf, y2=0, lw=0, alpha=0.3, facecolor=diag_color)
                    ax.plot(x, pdf, ls='-', c=diag_color, lw=1.5)
                else:
                    ax.hist(samples[:, i], edgecolor='black', color=diag_color, density=True, alpha=0.5,
                            linewidth=1.2, bins=bins)
                if cov_overlay is not None:
                    x, pdf = _kde_1d(x_overlay[:, i], x_min[i], x_max[i])
                    ax.fill_between(x, y1=pdf, y2=0, lw=0, alpha=0.5, facecolor=[0.5, 0.5, 0.5])
                    ax.plot(x, pdf, ls='-', c='k', lw=1.5, alpha=0.5)
                ax.set_ylim([0, ax.get_ylim()[1]])
                continue

            # 2d marginals (lower triangle)
            x_bds, y_bds = (x_min[j], x_max[j]), (x_min[i], x_max[i])
            ax.set_xlim(list(x_bds))
            ax.set_ylim(list(y_bds))
            if plot2d == 'scatter':
                sc = ax.scatter(samples[:, j], samples[:, i], s=1.5, c=z, cmap=cmap, norm=cb_norm)
            elif plot2d == 'hist':
                ax.hist2d(samples[:, j], samples[:, i], bins=bins, cmap=cmap, cmin=cmin)
            elif plot2d == 'kde':
                xg, yg, zg = _kde_2d(samples[:, [j, i]], x_bds, y_bds)
                cs = ax.contourf(xg, yg, zg, 5, cmap=cmap, alpha=0.9, extend='both')
                cs.cmap.set_under('white')
                cs.changed()
                ax.contour(xg, yg, zg, 5, colors=[(0.5, 0.5, 0.5)], linewidths=1.2)
            else:  # 'hex' (plot2d was validated above)
                ax.hexbin(samples[:, j], samples[:, i], gridsize=bins, cmap=cmap, mincnt=cmin)

            if cov_overlay is not None:
                xg, yg, zg = _kde_2d(x_overlay[:, [j, i]], x_bds, y_bds)
                ax.contourf(xg, yg, zg, 4, cmap='Greys', alpha=0.4)
                ax.contour(xg, yg, zg, 4, colors='k', linewidths=1.5, alpha=0.6)

    fig.set_size_inches(subplot_size * dim, subplot_size * dim)
    fig.tight_layout()

    # Plot colorbar in standalone figure (needs a scatter artist, which only exists when dim > 1)
    if show_colorbar and plot2d == 'scatter' and sc is not None:
        cb_fig, cb_ax = plt.subplots(figsize=(1.5, 6))
        cb_fig.subplots_adjust(right=0.7)
        cb_fig.colorbar(sc, cax=cb_ax, orientation='vertical', label=cb_label)
        cb_fig.tight_layout()
        return fig, axs, cb_fig, cb_ax

    return fig, axs