Coverage for src/uqtils/sobol.py: 100%
158 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 03:45 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 03:45 +0000
"""Module for Sobol' sensitivity analysis.

Includes:

- `sobol_sa` - function for global sensitivity analysis
- `ishigami` - Ishigami test function, a standard benchmark for checking Sobol' indices
"""
7import matplotlib.pyplot as plt
8import numpy as np
9import scipy.stats as st
11from uqtils import ax_default
# Public API of this module (names exported by `from ... import *`)
__all__ = ["sobol_sa", "ishigami"]
def sobol_sa(model, sampler, num_samples: int, qoi_idx: list[int] = None, qoi_labels: list[str] = None,
             param_labels: list[str] = None, plot: bool = False, verbose: bool = True, cmap: str = 'viridis',
             compute_s2: bool = False):
    """Perform a global Sobol' sensitivity analysis.

    Estimators (per sample, before dividing by the output variance):

    - First-order: Saltelli (2010) `f(B) * (f(A_B^i) - f(A))`.
    - Total-order: Jansen `(f(A) - f(A_B^i))^2 / 2`.
    - Second-order (optional): closed-index term `f(B_A^i) * f(A_B^j) - f(A) * f(B)`, minus both
      first-order terms.

    Here `A_B^i` is `A` with column `i` taken from `B`, and `B_A^i` is `B` with column `i` taken from `A`.
    The analysis costs `(xdim + 2) * num_samples` model evaluations, plus `xdim * num_samples` more when
    `compute_s2=True`.

    :param model: callable as `y=model(x)`, with `y=(..., ydim)`, `x=(..., xdim)`
    :param sampler: callable as `x=sampler(shape)`, with `x=(*shape, xdim)`
    :param num_samples: number of samples `N`
    :param qoi_idx: list of indices of model output to report results for (default: all outputs)
    :param qoi_labels: list of labels for QoIs, one per entry of `qoi_idx` (default: `QoI {index}`)
    :param param_labels: list of labels for input parameters (default: `x{i}`)
    :param plot: whether to plot bar/pie charts
    :param verbose: whether to print `S1/ST/S2` results to the console
    :param cmap: `str` specifier of `plt.colormap` for bar/pie charts
    :param compute_s2: whether to compute the second order indices
    :return: `(S1, ST)`, or `(S1, S2, ST)` if `compute_s2`. When `plot` is set, `(bar_chart, pie_chart)`
             is appended, each a `(fig, axs)` tuple. The indices are the per-sample estimator terms, not
             their averages: `S1` and `ST` have shape `(N, xdim, ydim)` and `S2` has shape
             `(N, xdim, xdim, ydim)`. Average over `axis=0` for the index estimates; the Monte Carlo
             standard error is `sqrt(var(axis=0) / N)`.
    """
    # Sample matrices (N, xdim); AB[:, i, :] is A with column i replaced by column i of B
    A = sampler((num_samples,))
    B = sampler((num_samples,))
    xdim = A.shape[-1]
    AB = np.tile(np.expand_dims(A, axis=-2), (1, xdim, 1))  # (N, xdim, xdim)
    for i in range(xdim):
        AB[:, i, i] = B[:, i]

    # Evaluate the model; (xdim+2)*N evaluations for S1/ST
    fA = model(A)    # (N, ydim)
    fB = model(B)    # (N, ydim)
    fAB = model(AB)  # (N, xdim, ydim)
    ydim = fA.shape[-1]
    evaluated = [fA, fB, fAB.reshape((-1, ydim))]

    # The B_A matrices are only needed for S2, so their xdim*N model evaluations are skipped otherwise
    fBA = None
    if compute_s2:
        BA = np.tile(np.expand_dims(B, axis=-2), (1, xdim, 1))  # (N, xdim, xdim)
        for i in range(xdim):
            BA[:, i, i] = A[:, i]
        fBA = model(BA)  # (N, xdim, ydim)
        evaluated.append(fBA.reshape((-1, ydim)))

    # Normalize model outputs to zero mean / unit variance for better numerical stability
    Y = np.concatenate(evaluated, axis=0)
    mu, std = np.mean(Y, axis=0), np.std(Y, axis=0)
    fA = (fA - mu) / std
    fB = (fB - mu) / std
    fAB = (fAB - mu) / std
    if compute_s2:
        fBA = (fBA - mu) / std

    # Per-sample estimator terms
    vY = np.var(np.concatenate((fA, fB), axis=0), axis=0)  # (ydim,)
    fA = np.expand_dims(fA, axis=-2)  # (N, 1, ydim)
    fB = np.expand_dims(fB, axis=-2)  # (N, 1, ydim)
    si = fB * (fAB - fA)              # (N, xdim, ydim), first-order variance terms (reused for S2)
    S1 = si / vY                      # (N, xdim, ydim)
    ST = 0.5 * (fA - fAB) ** 2 / vY   # (N, xdim, ydim)

    S2 = S2_est = S2_se = None
    if compute_s2:
        # Closed second-order variance for every (i, j) pair, minus the two first-order contributions
        Vij = np.expand_dims(fBA, axis=2) * np.expand_dims(fAB, axis=1) - \
            np.expand_dims(fA, axis=1) * np.expand_dims(fB, axis=1)  # (N, xdim, xdim, ydim)
        Vi = np.expand_dims(si, axis=2)  # (N, xdim, 1, ydim)
        Vj = np.expand_dims(si, axis=1)  # (N, 1, xdim, ydim)
        # The diagonal (i == j) is not a valid index; results below only report the upper triangle i < j
        S2 = (Vij - Vi - Vj) / vY  # (N, xdim, xdim, ydim)
        S2_est = np.mean(S2, axis=0)
        S2_se = np.sqrt(np.var(S2, axis=0) / num_samples)

    # Monte Carlo estimates (sample means) and standard errors
    S1_est = np.mean(S1, axis=0)
    S1_se = np.sqrt(np.var(S1, axis=0) / num_samples)
    ST_est = np.mean(ST, axis=0)
    ST_se = np.sqrt(np.var(ST, axis=0) / num_samples)

    # Defaults for qoi indices/labels; default QoI labels name the actual output index being reported
    if qoi_idx is None:
        qoi_idx = list(range(ydim))
    if qoi_labels is None:
        qoi_labels = [f'QoI {q}' for q in qoi_idx]
    if param_labels is None:
        param_labels = [f'x{i}' for i in range(xdim)]

    if verbose:
        _print_sobol_results(qoi_idx, qoi_labels, param_labels, S1_est, S1_se, ST_est, ST_se, S2_est, S2_se)

    ret = (S1, S2, ST) if compute_s2 else (S1, ST)
    if plot:
        bar_chart = _plot_sobol_bar_chart(qoi_idx, qoi_labels, param_labels, S1_est, S1_se, ST_est, ST_se, cmap)
        pie_chart = _plot_sobol_pie_chart(qoi_idx, qoi_labels, param_labels, S1_est, S2_est, cmap)
        ret = ret + (bar_chart, pie_chart)
    return ret


def _print_sobol_results(qoi_idx, qoi_labels, param_labels, S1_est, S1_se, ST_est, ST_se, S2_est=None,
                         S2_se=None):
    """Print console tables of averaged S1/ST (and S2 if given) indices with their MC standard errors.

    `S1_est` etc. are `(xdim, ydim)`; `S2_est`/`S2_se` are `(xdim, xdim, ydim)` or `None` to skip S2.
    """
    xdim, ydim = S1_est.shape
    print(f'{"QoI":>10} {"Param":>10} {"S1_mean":>10} {"S1_err":>10} {"ST_mean":>10} {"ST_err":>10}')
    for i, q in enumerate(qoi_idx):
        for j in range(xdim):
            print(f'{qoi_labels[i]:>10} {param_labels[j]:>10} {S1_est[j, q]: 10.3f} {S1_se[j, q]: 10.3f} '
                  f'{ST_est[j, q]: 10.3f} {ST_se[j, q]: 10.3f}')

    S1_total = np.sum(S1_est, axis=0)  # (ydim,)
    S2_total = np.zeros((ydim,))       # (ydim,)
    if S2_est is not None:
        print(f'\n{"QoI":>10} {"2nd-order":>20} {"S2_mean":>10} {"S2_err":>10}')
        for i, q in enumerate(qoi_idx):
            for j in range(xdim):
                for k in range(j + 1, xdim):
                    print(f'{qoi_labels[i]:>10} {"("+param_labels[j]+", "+param_labels[k]+")":>20} '
                          f'{S2_est[j, k, q]: 10.3f} {S2_se[j, k, q]: 10.3f}')
        # Sum the strict upper triangle (j < k) of S2 for each QoI
        rows, cols = np.triu_indices(xdim, k=1)
        S2_total = np.sum(S2_est[rows, cols, :], axis=0)

    print(f'\n{"QoI":>10} {"S1 total":>10} {"S2 total":>10} {"Higher order":>15}')
    for i, q in enumerate(qoi_idx):
        print(f'{qoi_labels[i]:>10} {S1_total[q]: 10.3f} {S2_total[q]: 10.3f} '
              f'{1 - S1_total[q] - S2_total[q]: 15.3f}')


def _plot_sobol_bar_chart(qoi_idx, qoi_labels, param_labels, S1_est, S1_se, ST_est, ST_se, cmap):
    """Plot side-by-side S1/ST bars with 95% confidence error bars, one subplot per QoI; return `(fig, axs)`."""
    c = plt.get_cmap(cmap)
    z = st.norm.ppf(1 - (1 - 0.95) / 2)  # two-sided 95% z-score from N(0,1), assuming CLT at n>30
    x = np.arange(S1_est.shape[0])
    width = 0.2
    fig, axs = plt.subplots(1, len(qoi_idx))
    for i, q in enumerate(qoi_idx):
        ax = axs[i] if len(qoi_idx) > 1 else axs  # subplots() returns a bare Axes for a single column
        ax.bar(x - width / 2, S1_est[:, q], width, color=c(0.1), yerr=S1_se[:, q] * z,
               label=r'$S_1$', capsize=3, linewidth=1, edgecolor=[0, 0, 0])
        ax.bar(x + width / 2, ST_est[:, q], width, color=c(0.9), yerr=ST_se[:, q] * z,
               label=r'$S_{T}$', capsize=3, linewidth=1, edgecolor=[0, 0, 0])
        ax_default(ax, "Model parameters", "Sobol' index", legend=True)
        ax.set_xticks(x, param_labels)
        ax.set_ylim(bottom=0)
        ax.set_title(qoi_labels[i])
    fig.set_size_inches(4 * len(qoi_idx), 4)
    fig.tight_layout()
    return fig, axs


def _plot_sobol_pie_chart(qoi_idx, qoi_labels, param_labels, S1_est, S2_est, cmap):
    """Plot donut charts splitting each QoI into S1, S2 and higher-order shares; return `(fig, axs)`.

    Indices at or below a 5% threshold are lumped into a gray "Other" wedge. Whatever is left of 1 after
    the S1/S2 shares becomes a gray "Higher order" wedge. Pass `S2_est=None` to omit second-order terms.
    """
    c = plt.get_cmap(cmap)
    xdim = S1_est.shape[0]
    thresh = 0.05  # Only show indices with > 5% effect
    radius = 2
    gray = [0.7, 0.7, 0.7, 1]
    fig, axs = plt.subplots(1, len(qoi_idx))
    for i, q in enumerate(qoi_idx):
        ax = axs[i] if len(qoi_idx) > 1 else axs
        values, labels, is_gray = [], [], []
        s12_other = 0
        for j in range(xdim):
            if S1_est[j, q] > thresh:
                values.append(S1_est[j, q])
                labels.append(param_labels[j])
                is_gray.append(False)
            else:
                s12_other += max(S1_est[j, q], 0)

        if S2_est is not None:
            for j in range(xdim):
                for k in range(j + 1, xdim):
                    if S2_est[j, k, q] > thresh:
                        values.append(S2_est[j, k, q])
                        labels.append("(" + param_labels[j] + ", " + param_labels[k] + ")")
                        is_gray.append(False)
                    else:
                        s12_other += max(S2_est[j, k, q], 0)

        values.append(max(s12_other, 0))
        labels.append(r'Other $S_1$, $S_2$')
        is_gray.append(True)
        values.append(max(1 - np.sum(values), 0))
        labels.append(r'Higher order')
        is_gray.append(True)

        # Show percents in labels; small shares get a "<" bound and a minimum 2% wedge so they stay visible
        labels = [f"{label}, {100*val:.1f}%" if val > thresh else f"{label}, <{max(0.5, round(100*val))}%"
                  for label, val in zip(labels, values)]
        values = [val if val > thresh else max(0.02, val) for val in values]
        ordered = sorted(zip(labels, values, is_gray), reverse=True, key=lambda ele: ele[1])
        labels = [ele[0] for ele in ordered]
        values = [ele[1] for ele in ordered]
        is_gray = [ele[2] for ele in ordered]

        # Colormap colors for the real indices; exactly two wedges ("Other", "Higher order") are gray
        colors = c(np.linspace(0, 1, len(values) - 2))
        pie_colors = np.empty((len(values), 4))
        c_idx = 0
        for idx, grayed in enumerate(is_gray):
            if grayed:
                pie_colors[idx, :] = gray
            else:
                pie_colors[idx, :] = colors[c_idx, :]
                c_idx += 1

        wedges, _ = ax.pie(values, colors=pie_colors, radius=radius, startangle=270,
                           shadow=True, counterclock=False, frame=True,
                           wedgeprops=dict(linewidth=1.5, width=0.6*radius, edgecolor='w'),
                           textprops={'color': [0, 0, 0, 1], 'fontsize': 10, 'family': 'serif'})
        kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center", fontsize=9, family='serif',
                  bbox=dict(boxstyle="square,pad=0.3", fc="w", ec="k", lw=0))

        # Put annotations with arrows to each wedge (coordinate system is relative to center of pie)
        for j, wed in enumerate(wedges):
            ang = (wed.theta2 - wed.theta1) / 2. + wed.theta1
            x = radius * np.cos(np.deg2rad(ang))
            y = radius * np.sin(np.deg2rad(ang))
            ax.scatter(x, y, s=10, c='k')
            kw["horizontalalignment"] = "right" if int(np.sign(x)) == -1 else "left"
            kw["arrowprops"].update({"connectionstyle": f"angle,angleA=0,angleB={ang}"})
            y_offset = 0.2 if j == len(labels) - 1 else 0  # nudge the last label to reduce overlap
            ax.annotate(labels[j], xy=(x, y), xytext=((radius+0.2)*np.sign(x), 1.3*y - y_offset), **kw)
        ax.set(aspect="equal")
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.set_title(qoi_labels[i])
    fig.set_size_inches(3*radius*len(qoi_idx), 2.5*radius)
    fig.tight_layout()
    fig.subplots_adjust(left=0.15, right=0.75)
    return fig, axs
def ishigami(x, a=7.0, b=0.1):
    """Evaluate the Ishigami function, a common benchmark for testing Sobol' indices.

    Reference: [Ishigami & Homma (1990)](https://doi.org/10.1109/ISUMA.1990.151285).
    The formula is `y = sin(x1) + a*sin(x2)^2 + b*x3^4*sin(x1)`.

    :param x: array whose last axis holds at least 3 inputs; only the first three are used
    :param a: weight of the `sin(x2)^2` term
    :param b: weight of the `x3^4 * sin(x1)` interaction term
    :return: `dict` with key `'y'` mapping to the output array of shape `(..., 1)`
    """
    # Slicing with ranges (0:1 etc.) keeps a trailing axis of size 1 on every term
    x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3]
    sin_x1 = np.sin(x1)
    y = sin_x1 + a * np.sin(x2) ** 2 + b * x3 ** 4 * sin_x1
    return {'y': y}