Coverage for src/uqtils/sobol.py: 100%

158 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 03:45 +0000

1"""Module for Sobol' sensitivity analysis. 

2 

3Includes: 

4 

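- `sobol_sa` - function for global sensitivity analysis
- `ishigami` - the Ishigami test function, whose Sobol' indices are known analytically (useful for validating `sobol_sa`)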

6""" 

7import matplotlib.pyplot as plt 

8import numpy as np 

9import scipy.stats as st 

10 

11from uqtils import ax_default 

12 

__all__ = ['sobol_sa', 'ishigami']  # public API: the Sobol' analysis driver and the Ishigami test function

14 

15 

def sobol_sa(model, sampler, num_samples: int, qoi_idx: list[int] = None, qoi_labels: list[str] = None,
             param_labels: list[str] = None, plot: bool = False, verbose: bool = True, cmap: str = 'viridis',
             compute_s2: bool = False):
    """Perform a global Sobol' sensitivity analysis.

    First-order indices use the Saltelli (2010) estimator `f(B) * (f(AB_i) - f(A))`, total-order indices use the
    Jansen estimator `0.5 * (f(A) - f(AB_i))^2`, and (optionally) second-order indices use the Saltelli (2002)
    estimator. This costs `(xdim+2)*num_samples` model evaluations, plus another `xdim*num_samples` evaluations
    when `compute_s2=True`.

    :param model: callable as `y=model(x)`, with `y=(..., ydim)`, `x=(..., xdim)`
    :param sampler: callable as `x=sampler(shape)`, with `x=(*shape, xdim)`
    :param num_samples: number of samples
    :param qoi_idx: list of indices of model output to report results for (defaults to all outputs)
    :param qoi_labels: list of labels for plotting QoIs, one per entry of `qoi_idx` (defaults to `QoI {index}`)
    :param param_labels: list of labels for plotting input parameters (defaults to `x{i}`)
    :param plot: whether to plot bar/pie charts
    :param verbose: whether to print `S1/ST/S2` results to the console
    :param cmap: `str` specifier of `plt.colormap` for bar/pie charts
    :param compute_s2: whether to compute the second order indices
    :return: `S1`, `[S2]`, `ST` -- the per-sample first, second, and total order Sobol' index estimates, with shapes
             `(N, xdim, ydim)`, `(N, xdim, xdim, ydim)` and `(N, xdim, ydim)`; average over axis 0 to get the
             indices. `S2` is only returned when `compute_s2=True`. When `plot=True`, the tuple additionally ends
             with the `(fig, axs)` pairs of the bar chart and the pie chart.
    """
    # Get sample matrices (N, xdim)
    A = sampler((num_samples,))
    B = sampler((num_samples,))
    xdim = A.shape[-1]

    # AB[:, i, :] is A with its i-th column taken from B
    AB = np.tile(np.expand_dims(A, axis=-2), (1, xdim, 1))  # (N, xdim, xdim)
    for i in range(xdim):
        AB[:, i, i] = B[:, i]

    # Evaluate the model; (xdim+2)*N evaluations are required for S1/ST
    fA = model(A)    # (N, ydim)
    fB = model(B)    # (N, ydim)
    fAB = model(AB)  # (N, xdim, ydim)
    ydim = fA.shape[-1]
    outputs = [fA, fB, fAB.reshape((-1, ydim))]

    # BA (B with its i-th column taken from A) is only needed for S2, so skip its xdim*N evaluations otherwise
    fBA = None
    if compute_s2:
        BA = np.tile(np.expand_dims(B, axis=-2), (1, xdim, 1))  # (N, xdim, xdim)
        for i in range(xdim):
            BA[:, i, i] = A[:, i]
        fBA = model(BA)  # (N, xdim, ydim)
        outputs.append(fBA.reshape((-1, ydim)))

    # Standardize model outputs (zero mean, unit variance) for better numerical stability
    Y = np.concatenate(outputs, axis=0)
    mu, std = np.mean(Y, axis=0), np.std(Y, axis=0)
    fA = (fA - mu) / std
    fB = (fB - mu) / std
    fAB = (fAB - mu) / std
    if compute_s2:
        fBA = (fBA - mu) / std

    # Compute per-sample sensitivity index estimates
    vY = np.var(np.concatenate((fA, fB), axis=0), axis=0)  # (ydim,)
    fA = np.expand_dims(fA, axis=-2)  # (N, 1, ydim)
    fB = np.expand_dims(fB, axis=-2)  # (N, 1, ydim)
    si = fB * (fAB - fA)              # (N, xdim, ydim), first-order partial variance per sample
    S1 = si / vY                      # (N, xdim, ydim)
    ST = 0.5 * (fA - fAB)**2 / vY     # (N, xdim, ydim)

    # Second-order indices: S2_jk = (V_jk^closed - V_j - V_k) / V(Y)
    S2 = S2_est = S2_se = None
    if compute_s2:
        # Vij[n, j, k] = f(BA_j) * f(AB_k) - f(A) * f(B)
        Vij = np.expand_dims(fBA, axis=2) * np.expand_dims(fAB, axis=1) - \
            np.expand_dims(fA, axis=1) * np.expand_dims(fB, axis=1)  # (N, xdim, xdim, ydim)
        Vi = np.expand_dims(si, axis=2)  # (N, xdim, 1, ydim)
        Vj = np.expand_dims(si, axis=1)  # (N, 1, xdim, ydim)
        S2 = (Vij - Vi - Vj) / vY        # (N, xdim, xdim, ydim)
        S2_est = np.mean(S2, axis=0)
        S2_se = np.sqrt(np.var(S2, axis=0) / num_samples)

    # Get mean values and Monte Carlo standard errors
    S1_est = np.mean(S1, axis=0)  # (xdim, ydim)
    S1_se = np.sqrt(np.var(S1, axis=0) / num_samples)
    ST_est = np.mean(ST, axis=0)  # (xdim, ydim)
    ST_se = np.sqrt(np.var(ST, axis=0) / num_samples)

    # Set defaults for qoi indices/labels
    if qoi_idx is None:
        qoi_idx = list(np.arange(ydim))
    if qoi_labels is None:
        qoi_labels = [f'QoI {q}' for q in qoi_idx]  # label by the actual model output index, not by position
    if param_labels is None:
        param_labels = [f'x{i}' for i in range(xdim)]

    if verbose:
        _print_results(S1_est, S1_se, ST_est, ST_se, S2_est, S2_se, qoi_idx, qoi_labels, param_labels)

    ret = (S1, S2, ST) if compute_s2 else (S1, ST)
    if plot:
        bar_chart = _plot_bar_chart(S1_est, S1_se, ST_est, ST_se, qoi_idx, qoi_labels, param_labels, cmap)
        pie_chart = _plot_pie_chart(S1_est, S2_est, qoi_idx, qoi_labels, param_labels, cmap)
        ret = ret + (bar_chart, pie_chart)
    return ret


def _print_results(S1_est, S1_se, ST_est, ST_se, S2_est, S2_se, qoi_idx, qoi_labels, param_labels):
    """Print tables of S1/ST (and S2, when `S2_est` is not None) estimates plus the S1/S2/higher-order split.

    Estimates/errors are `(xdim, ydim)` arrays (`(xdim, xdim, ydim)` for S2); `qoi_idx` selects the output columns.
    """
    xdim = S1_est.shape[0]
    ydim = S1_est.shape[-1]

    print(f'{"QoI":>10} {"Param":>10} {"S1_mean":>10} {"S1_err":>10} {"ST_mean":>10} {"ST_err":>10}')
    for i, q in enumerate(qoi_idx):
        for j in range(xdim):
            print(f'{qoi_labels[i]:>10} {param_labels[j]:>10} {S1_est[j, q]: 10.3f} {S1_se[j, q]: 10.3f} '
                  f'{ST_est[j, q]: 10.3f} {ST_se[j, q]: 10.3f}')

    S1_total = np.sum(S1_est, axis=0)  # (ydim,)
    S2_total = np.zeros((ydim,))       # (ydim,)
    if S2_est is not None:
        print(f'\n{"QoI":>10} {"2nd-order":>20} {"S2_mean":>10} {"S2_err":>10}')
        for i, q in enumerate(qoi_idx):
            for j in range(xdim):
                for k in range(j+1, xdim):
                    print(f'{qoi_labels[i]:>10} {"("+param_labels[j]+", "+param_labels[k]+")":>20} '
                          f'{S2_est[j, k, q]: 10.3f} {S2_se[j, k, q]: 10.3f}')
        for j in range(xdim):
            for k in range(j+1, xdim):
                S2_total += S2_est[j, k, :]  # sum the upper triangle (each pair counted once)

    # Whatever fraction of the variance S1 and S2 do not explain is attributed to higher-order interactions
    print(f'\n{"QoI":>10} {"S1 total":>10} {"S2 total":>10} {"Higher order":>15}')
    for i, q in enumerate(qoi_idx):
        print(f'{qoi_labels[i]:>10} {S1_total[q]: 10.3f} {S2_total[q]: 10.3f} '
              f'{1 - S1_total[q] - S2_total[q]: 15.3f}')


def _plot_bar_chart(S1_est, S1_se, ST_est, ST_se, qoi_idx, qoi_labels, param_labels, cmap):
    """Plot side-by-side S1/ST bars with 95% confidence intervals, one subplot per QoI; return `(fig, axs)`."""
    c = plt.get_cmap(cmap)
    x = np.arange(S1_est.shape[0])
    width = 0.2
    z = st.norm.ppf(1 - (1-0.95)/2)  # z-score for a two-sided 95% interval, assuming the CLT holds (n > 30)

    fig, axs = plt.subplots(1, len(qoi_idx))
    for i, q in enumerate(qoi_idx):
        ax = axs[i] if len(qoi_idx) > 1 else axs
        ax.bar(x - width / 2, S1_est[:, q], width, color=c(0.1), yerr=S1_se[:, q] * z,
               label=r'$S_1$', capsize=3, linewidth=1, edgecolor=[0, 0, 0])
        ax.bar(x + width / 2, ST_est[:, q], width, color=c(0.9), yerr=ST_se[:, q] * z,
               label=r'$S_{T}$', capsize=3, linewidth=1, edgecolor=[0, 0, 0])
        ax_default(ax, "Model parameters", "Sobol' index", legend=True)
        ax.set_xticks(x, param_labels)
        ax.set_ylim(bottom=0)
        ax.set_title(qoi_labels[i])
    fig.set_size_inches(4*len(qoi_idx), 4)
    fig.tight_layout()
    return fig, axs


def _plot_pie_chart(S1_est, S2_est, qoi_idx, qoi_labels, param_labels, cmap, thresh=0.05, radius=2):
    """Plot donut charts of the S1/S2/higher-order variance split, one subplot per QoI; return `(fig, axs)`.

    Indices above `thresh` get their own wedge. Smaller S1/S2 indices are lumped into a gray "Other" wedge, and
    the remainder of the (unit) variance is shown as a gray "Higher order" wedge. Pass `S2_est=None` to omit S2.
    """
    c = plt.get_cmap(cmap)
    xdim = S1_est.shape[0]

    fig, axs = plt.subplots(1, len(qoi_idx))
    for i, q in enumerate(qoi_idx):
        ax = axs[i] if len(qoi_idx) > 1 else axs
        values = []
        labels = []
        s12_other = 0
        for j in range(xdim):
            if S1_est[j, q] > thresh:
                values.append(S1_est[j, q])
                labels.append(param_labels[j])
            else:
                s12_other += max(S1_est[j, q], 0)

        if S2_est is not None:
            for j in range(xdim):
                for k in range(j+1, xdim):
                    if S2_est[j, k, q] > thresh:
                        values.append(S2_est[j, k, q])
                        labels.append("("+param_labels[j]+", "+param_labels[k]+")")
                    else:
                        s12_other += max(S2_est[j, k, q], 0)

        values.append(max(s12_other, 0))
        labels.append(r'Other $S_1$, $S_2$')
        values.append(max(1 - np.sum(values), 0))
        labels.append(r'Higher order')
        # Track the two lumped wedges explicitly rather than by label prefix, so that a parameter whose label
        # happens to start with "Other"/"Higher" is not grayed out
        is_gray = [False] * (len(values) - 2) + [True, True]

        # Adjust labels to show percents, threshold small values for plotting, and sort by value
        labels = [f"{label}, {100*val:.1f}%" if val > thresh else
                  f"{label}, <{max(0.5, round(100*val))}%" for label, val in zip(labels, values)]
        values = [val if val > thresh else max(0.02, val) for val in values]
        labels, values, is_gray = zip(*sorted(zip(labels, values, is_gray), reverse=True, key=lambda ele: ele[1]))

        # Generate pie chart: colormap colors for the individual indices, gray for the lumped wedges
        colors = c(np.linspace(0, 1, len(values)-2))
        pie_colors = np.empty((len(values), 4))
        c_idx = 0
        for idx, gray in enumerate(is_gray):
            if gray:
                pie_colors[idx, :] = [0.7, 0.7, 0.7, 1]
            else:
                pie_colors[idx, :] = colors[c_idx, :]
                c_idx += 1
        wedges, label_boxes = ax.pie(values, colors=pie_colors, radius=radius, startangle=270,
                                     shadow=True, counterclock=False, frame=True,
                                     wedgeprops=dict(linewidth=1.5, width=0.6*radius, edgecolor='w'),
                                     textprops={'color': [0, 0, 0, 1], 'fontsize': 10, 'family': 'serif'})
        kw = dict(arrowprops=dict(arrowstyle="-"), zorder=0, va="center", fontsize=9, family='serif',
                  bbox=dict(boxstyle="square,pad=0.3", fc="w", ec="k", lw=0))

        # Put annotations with arrows to each wedge (coordinate system is relative to center of pie)
        for j, wed in enumerate(wedges):
            ang = (wed.theta2 - wed.theta1) / 2. + wed.theta1  # angle of the wedge midpoint, in degrees
            x = radius * np.cos(np.deg2rad(ang))
            y = radius * np.sin(np.deg2rad(ang))
            ax.scatter(x, y, s=10, c='k')
            kw["horizontalalignment"] = "right" if int(np.sign(x)) == -1 else "left"
            kw["arrowprops"].update({"connectionstyle": f"angle,angleA=0,angleB={ang}"})
            y_offset = 0.2 if j == len(labels) - 1 else 0  # nudge the last label down (presumably to reduce overlap)
            ax.annotate(labels[j], xy=(x, y), xytext=((radius+0.2)*np.sign(x), 1.3*y - y_offset), **kw)
        ax.set(aspect="equal")
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.set_title(qoi_labels[i])
    fig.set_size_inches(3*radius*len(qoi_idx), 2.5*radius)
    fig.tight_layout()
    fig.subplots_adjust(left=0.15, right=0.75)
    return fig, axs

229 

230 

def ishigami(x, a=7.0, b=0.1):
    """Evaluate the Ishigami function, a classic test case for Sobol' indices.

    Reference: [Ishigami & Homma (1990)](https://doi.org/10.1109/ISUMA.1990.151285).
    `y = sin(x1) + a*sin(x2)^2 + b*x3^4*sin(x1)`

    :param x: `(..., xdim)` input array; only the first three entries of the last axis are used
    :param a: coefficient of the `sin(x2)^2` term
    :param b: coefficient of the `x3^4 * sin(x1)` term
    :return: `dict` with the single key `'y'`, mapping to the `(..., 1)` output array
    """
    # Slice (rather than index) so each input keeps a trailing singleton axis
    x1, x2, x3 = (x[..., k:k + 1] for k in range(3))
    sin_x1 = np.sin(x1)
    y = sin_x1 + a * np.sin(x2) ** 2 + b * x3 ** 4 * sin_x1
    return {'y': y}