Coverage for src/amisc/interpolator.py: 95%

258 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-01-24 04:51 +0000

1"""Provides interpolator classes. Interpolators approximate the input → output mapping of a model given 

2a set of training data. The training data consists of input-output pairs, and the interpolator can be 

3refined with new training data. 

4 

5Includes: 

6 

7- `Interpolator`: Abstract class providing basic structure of an interpolator 

8- `Lagrange`: Concrete implementation for tensor-product barycentric Lagrange interpolation 

9- `InterpolatorState`: Interface for a dataclass that stores the internal state of an interpolator 

10- `LagrangeState`: The internal state for a barycentric Lagrange polynomial interpolator 

11""" 

12from __future__ import annotations 

13 

14import copy 

15import itertools 

16from abc import ABC, abstractmethod 

17from dataclasses import dataclass, field 

18 

19import numpy as np 

20 

21from amisc.serialize import Base64Serializable, Serializable, StringSerializable 

22from amisc.typing import Dataset, MultiIndex 

23 

__all__ = [
    "InterpolatorState",
    "LagrangeState",
    "Interpolator",
    "Lagrange",
]

25 

26 

class InterpolatorState(Serializable, ABC):
    """Abstract interface for the dataclass that holds an interpolator's internal state (e.g. weights and biases).

    Concrete interpolators subclass this (see `LagrangeState`) to carry whatever they need between `refine` calls.
    """

30 

31 

@dataclass
class LagrangeState(InterpolatorState, Base64Serializable):
    """The internal state for a barycentric Lagrange polynomial interpolator.

    :ivar weights: the 1d interpolation grid weights
    :ivar x_grids: the 1d interpolation grids
    """
    weights: dict[str, np.ndarray] = field(default_factory=dict)
    x_grids: dict[str, np.ndarray] = field(default_factory=dict)

    def __eq__(self, other):
        """Compare two states for equality.

        Two states are equal when they hold exactly the same variables and every grid/weight array has the same
        shape and matches to within `np.allclose` tolerance. Arrays of different shapes compare unequal; they are
        not broadcast against each other.

        :param other: the object to compare against
        :returns: `True` if `other` is an equivalent `LagrangeState`, otherwise `False`
        """
        if not isinstance(other, LagrangeState):
            return False

        # Symmetric check: a variable present on only one side makes the states different
        if self.weights.keys() != other.weights.keys() or self.x_grids.keys() != other.x_grids.keys():
            return False

        def _close(a, b):
            # Shape check first so mismatched grids neither raise nor silently broadcast to "equal"
            return np.shape(a) == np.shape(b) and np.allclose(a, b)

        return (all(_close(self.weights[var], other.weights[var]) for var in self.weights) and
                all(_close(self.x_grids[var], other.x_grids[var]) for var in self.x_grids))

51 

52 

53class Interpolator(Serializable, ABC): 

54 """Interface for an interpolator object that approximates a model. An interpolator should: 

55 

56 - `refine` - take an old state and new training data and produce a new "refined" state (e.g. new weights/biases) 

57 - `predict` - interpolate from the training data to a new set of points (i.e. approximate the underlying model) 

58 - `gradient` - compute the grdient/Jacobian at new points (if you want) 

59 - `hessian` - compute the 2nd derivative/Hessian at new points (if you want) 

60 

61 Currently, only the `Lagrange` interpolator is supported and can be constructed from a configuration `dict` 

62 via `Interpolator.from_dict()`. 

63 """ 

64 

65 @abstractmethod 

66 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset], 

67 old_state: InterpolatorState, input_domains: dict[str, tuple]) -> InterpolatorState: 

68 """Refine the interpolator state with new training data. 

69 

70 :param beta: a multi-index specifying the fidelity "levels" of the new interpolator state (starts at (0,... 0)) 

71 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data 

72 :param old_state: the previous state of the interpolator (None if initializing the first state) 

73 :param input_domains: a `dict` mapping input variables to their corresponding domains 

74 :returns: the new "refined" interpolator state 

75 """ 

76 raise NotImplementedError 

77 

78 @abstractmethod 

79 def predict(self, x: dict | Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

80 """Interpolate the output of the model at points `x` using the given state and training data 

81 

82 :param x: the input Dataset `dict` mapping input variables to locations at which to compute the interpolator 

83 :param state: the current state of the interpolator 

84 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

85 :returns: a Dataset `dict` mapping output variables to interpolator outputs 

86 """ 

87 raise NotImplementedError 

88 

89 def __call__(self, *args, **kwargs): 

90 return self.predict(*args, **kwargs) 

91 

92 @abstractmethod 

93 def gradient(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

94 """Evaluate the gradient/Jacobian at points `x` using the interpolator. 

95 

96 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Jacobian 

97 :param state: the current state of the interpolator 

98 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

99 :returns: a Dataset `dict` mapping output variables to Jacobian matrices of shape `(ydim, xdim)` -- for 

100 scalar outputs, the Jacobian is returned as `(xdim,)` 

101 """ 

102 raise NotImplementedError 

103 

104 @abstractmethod 

105 def hessian(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

106 """Evaluate the Hessian at points `x` using the interpolator. 

107 

108 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Hessian 

109 :param state: the current state of the interpolator 

110 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

111 :returns: a Dataset `dict` mapping output variables to Hessian matrices of shape `(xdim, xdim)` 

112 """ 

113 raise NotImplementedError 

114 

115 @classmethod 

116 def from_dict(cls, config: dict) -> Interpolator: 

117 """Create an `Interpolator` object from a `dict` config. Only `method='lagrange'` is supported for now.""" 

118 method = config.pop('method', 'lagrange').lower() 

119 match method: 

120 case 'lagrange': 

121 return Lagrange(**config) 

122 case other: 

123 raise NotImplementedError(f"Unknown interpolator method: {other}") 

124 

125 

@dataclass
class Lagrange(Interpolator, StringSerializable):
    """Implementation of a tensor-product barycentric Lagrange polynomial interpolator. A `LagrangeState` stores
    the 1d interpolation grids and weights for each input dimension. `Lagrange` computes the tensor-product
    of 1d Lagrange polynomials to approximate a multi-variate function.

    Each training output may be a 1d array `(N,)` (scalar output) or a 2d array `(N, k)` (vector output). Row `i`
    of the training outputs corresponds to the `i`-th tensor-product grid point in `itertools.product` order over
    the 1d grids.

    :ivar interval_capacity: tuning knob for Lagrange interpolation (see Berrut and Trefethen 2004)
    """
    interval_capacity: float = 4.0

    @staticmethod
    def _extend_grids(x_grids: dict[str, np.ndarray], x_points: dict[str, np.ndarray]):
        """Extend the 1d `x` grids with any new points from `x_points`, skipping duplicates. This will preserve the
        order of new points in the extended grid without duplication. This will maintain the same order as
        `SparseGrid.x_grids` if that is the underlying training data structure.

        !!! Example
            ```python
            x = {'x': np.array([0, 1, 2])}
            new_x = {'x': np.array([3, 0, 2, 3, 1, 4])}
            extended_x = Lagrange._extend_grids(x, new_x)
            # gives {'x': np.array([0, 1, 2, 3, 4])}
            ```

        !!! Warning
            This will only work for 1d grids; all `x_grids` should be scalar quantities. Field quantities should
            already be passed in as several separate 1d latent coefficients.

        :param x_grids: the current 1d interpolation grids
        :param x_points: the new points to extend the interpolation grids with
        :returns: the extended grids
        """
        extended_grids = copy.deepcopy(x_grids)
        for var, new_pts in x_points.items():
            # Get unique new values that are not already in the grid (maintain their order; keep only first one)
            u = x_grids[var] if var in x_grids else np.array([])
            u, ind = np.unique(new_pts[~np.isin(new_pts, u)], return_index=True)
            u = u[np.argsort(ind)]
            extended_grids[var] = u if var not in x_grids else np.concatenate((x_grids[var], u), axis=0)

        return extended_grids

    def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
               old_state: LagrangeState, input_domains: dict[str, tuple]) -> LagrangeState:
        """Refine the interpolator state with new training data.

        :param beta: the refinement level indices for the interpolator (not used for `Lagrange`)
        :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
        :param old_state: the old interpolator state to refine (None if initializing)
        :param input_domains: a `dict` of each input variable's domain; input keys should match `xtrain` keys
        :returns: the new interpolator state
        """
        xtrain = training_data[0]  # Lagrange only needs the xtrain data to update barycentric weights/grids

        # Initialize the interpolator state
        if old_state is None:
            x_grids = self._extend_grids({}, xtrain)
            weights = {}
            for var, grid in x_grids.items():
                bds = input_domains[var]
                Nx = grid.shape[0]
                C = (bds[1] - bds[0]) / self.interval_capacity  # Interval capacity (see Berrut and Trefethen 2004)
                xj = grid.reshape((Nx, 1))
                xi = xj.reshape((1, Nx))
                dist = (xj - xi) / C
                np.fill_diagonal(dist, 1)  # Ignore product when i==j
                weights[var] = (1.0 / np.prod(dist, axis=1))  # (Nx,)

        # Otherwise, refine the interpolator state by updating the weights incrementally for each appended node
        else:
            x_grids = self._extend_grids(old_state.x_grids, xtrain)
            weights = copy.deepcopy(old_state.weights)
            for var, grid in x_grids.items():
                bds = input_domains[var]
                Nx_old = old_state.x_grids[var].shape[0]
                Nx_new = grid.shape[0]
                if Nx_new > Nx_old:
                    weights[var] = np.pad(weights[var], [(0, Nx_new - Nx_old)], mode='constant',
                                          constant_values=np.nan)
                    C = (bds[1] - bds[0]) / self.interval_capacity
                    for j in range(Nx_old, Nx_new):
                        # Fold the new node into existing weights, then compute the new node's own weight
                        weights[var][:j] *= (C / (grid[:j] - grid[j]))
                        weights[var][j] = np.prod(C / (grid[j] - grid[:j]))

        return LagrangeState(weights=weights, x_grids=x_grids)

    @staticmethod
    def _stack_data(x: Dataset, training_data: tuple[Dataset, Dataset]) -> tuple[np.ndarray, np.ndarray]:
        """Stack evaluation points into an `(..., xdim)` array and training outputs into an `(N, ydim)` array.

        `x` is ordered by the keys of the training inputs `xi`. NOTE(review): this assumes `xi` and
        `state.x_grids` share the same key order — TODO confirm with callers.
        Scalar outputs `(N,)` contribute one column; vector outputs `(N, k)` contribute `k` columns, matching the
        layout expected by `_unpack_outputs`.
        """
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([arr[..., np.newaxis] if np.ndim(arr) < 2 else arr for arr in yi.values()],
                                axis=-1)
        return x_arr, yi_arr

    @staticmethod
    def _ragged_grids(state: LagrangeState, xdim: int) -> tuple[dict[str, int], np.ndarray, np.ndarray]:
        """Pack the 1d grids and weights into `(xdim, max_size)` arrays, NaN-padded where a grid is shorter.

        :returns: `(grid_sizes, x_j, w_j)` where `grid_sizes` maps each variable to its number of nodes
        """
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())
        x_j = np.full((xdim, max_size), np.nan)                # For example:
        w_j = np.full((xdim, max_size), np.nan)                # A= [#####--
        for n, var in enumerate(state.x_grids):                #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]      #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]
        return grid_sizes, x_j, w_j

    @staticmethod
    def _barycentric_terms(x_arr: np.ndarray, x_j: np.ndarray, w_j: np.ndarray):
        """Compute the shared barycentric quantities for points `x_arr` `(..., xdim)` and nodes `x_j`, `w_j`.

        :returns: `(diff, div_zero_idx, quotient, qsum)` -- `diff = x - x_j` with entries at nodes replaced by 1,
                  the boolean mask of those node coincidences, `quotient = w_j / diff` (all `(..., xdim, max_size)`),
                  and the NaN-ignoring sum of `quotient` over nodes `(..., xdim)`
        """
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)  # x coincides with an interpolation node
        diff[div_zero_idx] = 1  # avoid division by zero; callers handle these entries explicitly via the mask
        quotient = w_j / diff  # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)  # (..., xdim)
        return diff, div_zero_idx, quotient, qsum

    @staticmethod
    def _other_pts_mask(div_zero_idx: np.ndarray, j: tuple) -> np.ndarray:
        """Mask of node coincidences, with every row `(..., d, :)` cleared where `x_d` sits on node `j[d]`.

        The remaining `True` entries mark dimensions where `x_d` sits on a node other than `j[d]`.
        """
        xdim = div_zero_idx.shape[-2]
        other_pts = np.copy(div_zero_idx)
        other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
        return other_pts

    @staticmethod
    def _dL_dx(k, j, x_arr, x_j, w_j, diff, qsum, sqsum, div_zero_idx, other_pts, grid_size):
        """First derivative along dimension `k` of the 1d basis polynomial for node `j[k]`, at every point of `x_arr`.

        :param sqsum: NaN-ignoring sum of `w_j / diff**2` over nodes, `(..., xdim)`
        :param other_pts: the `_other_pts_mask` for this `j`, or `None` when no `x` lies on any node
        :param grid_size: number of (non-padded) nodes in dimension `k`
        :returns: the derivative, shape `x_arr.shape[:-1]`
        """
        # Quotient rule on L = q_j / S, with q_j = w_j / (x - x_j) and S = sum_i w_i / (x - x_i)
        dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                  (sqsum[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

        if other_pts is not None:
            # The formula above is invalid where x coincides with a node (diff was set to 1); use the closed-form
            # barycentric differentiation-matrix entries there instead
            p_idx = [idx for idx in range(grid_size) if idx != j[k]]
            w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
            curr_j_idx = div_zero_idx[..., k, j[k]]  # x_k sits on node j[k]
            other_j_idx = np.any(other_pts[..., k, :], axis=-1)  # x_k sits on some other node
            dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                            (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]), axis=-1)
            dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                   (x_arr[other_j_idx, k] - x_j[k, j[k]]))
        return dLJ_dx

    @staticmethod
    def _d2L_dx2(m, j, x_arr, x_j, w_j, diff, qsum, qsum_p, qsum_pp, div_zero_idx, other_pts, grid_size):
        """Second derivative along dimension `m` of the 1d basis polynomial for node `j[m]`, at every point of `x_arr`.

        :param qsum_p: `-nansum(w_j / diff**2)` over nodes, `(..., xdim)`
        :param qsum_pp: `2 * nansum(w_j / diff**3)` over nodes, `(..., xdim)`
        :param other_pts: the `_other_pts_mask` for this `j`, or `None` when no `x` lies on any node
        :param grid_size: number of (non-padded) nodes in dimension `m`
        :returns: the second derivative, shape `x_arr.shape[:-1]`
        """
        front_term = w_j[m, j[m]] / (qsum[..., m] * diff[..., m, j[m]])
        first_term = (-qsum_pp[..., m] / qsum[..., m]) + 2 * (qsum_p[..., m] / qsum[..., m]) ** 2
        second_term = (2 * (qsum_p[..., m] / (qsum[..., m] * diff[..., m, j[m]]))
                       + 2 / diff[..., m, j[m]] ** 2)
        d2LJ_dx2 = front_term * (first_term + second_term)

        if other_pts is None:
            return d2LJ_dx2

        # Set derivatives when x is at the interpolation points (i.e. x==x_j)
        curr_j_idx = div_zero_idx[..., m, j[m]]
        other_j_idx = np.any(other_pts[..., m, :], axis=-1)
        if np.any(curr_j_idx) or np.any(other_j_idx):
            p_idx = [idx for idx in range(grid_size) if idx != j[m]]
            w_j_large = np.broadcast_to(w_j[m, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
            x_j_large = np.broadcast_to(x_j[m, :], x_arr.shape[:-1] + x_j.shape[-1:]).copy()

            # if these points are at the current j interpolation point
            ratio = w_j[m, p_idx] / w_j[m, j[m]]
            dx = x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]
            d2LJ_dx2[curr_j_idx] = (2 * np.nansum(ratio / dx, axis=-1) ** 2 +
                                    2 * np.nansum(ratio / dx ** 2, axis=-1))

            # if these points are at any other interpolation point. Guarded so a single-node grid (empty `p_idx`)
            # does not attempt the ambiguous `reshape((-1, 0))` on empty arrays, which raises.
            if np.any(other_j_idx):
                other_pts_inv = other_pts.copy()
                other_pts_inv[other_j_idx, m, :grid_size] = np.invert(other_pts[other_j_idx, m, :grid_size])
                curr_x_j = x_j_large[other_pts[..., m, :]].reshape((-1, 1))
                other_x_j = x_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                curr_w_j = w_j_large[other_pts[..., m, :]].reshape((-1, 1))
                other_w_j = w_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                curr_div = w_j[m, j[m]] / np.squeeze(curr_w_j, axis=-1)
                curr_diff = np.squeeze(curr_x_j, axis=-1) - x_j[m, j[m]]
                d2LJ_dx2[other_j_idx] = ((-2 * curr_div / curr_diff) * (np.nansum(
                    (other_w_j / curr_w_j) / (curr_x_j - other_x_j), axis=-1) + 1 / curr_diff))

        return d2LJ_dx2

    @staticmethod
    def _unpack_outputs(out: np.ndarray, yi: Dataset, trailing_dims: int = 0) -> Dataset:
        """Split the stacked output axis of `out` back into a Dataset keyed like `yi`.

        :param out: array whose output axis sits just before `trailing_dims` trailing axes
        :param yi: the training outputs, used for the variable names and per-variable widths
        :param trailing_dims: number of derivative axes after the output axis (0 predict, 1 gradient, 2 hessian)
        :returns: `dict` of per-variable slices; the output axis is squeezed away for scalar (1d) variables
        """
        ret = {}
        start_idx = 0
        trailing = (slice(None),) * trailing_dims
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            ret[var] = out[(Ellipsis, slice(start_idx, end_idx)) + trailing]
            if len(arr.shape) == 1:
                ret[var] = np.squeeze(ret[var], axis=-1 - trailing_dims)  # for scalars
            start_idx = end_idx
        return ret

    def predict(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Predict the output of the model at points `x` with barycentric Lagrange interpolation.

        :param x: `dict` mapping each input variable (same keys as the training inputs) to evaluation locations
        :param state: the current `LagrangeState` (1d grids and barycentric weights)
        :param training_data: a tuple of `xi, yi` Datasets for the current state
        :returns: a Dataset mapping each output variable to its interpolated values, `(...,)` for scalar outputs or
                  `(..., k)` for vector outputs
        """
        yi = training_data[1]
        x_arr, yi_arr = self._stack_data(x, training_data)
        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes, x_j, w_j = self._ragged_grids(state, xdim)
        dims = list(range(xdim))

        _, div_zero_idx, quotient, qsum = self._barycentric_terms(x_arr, x_j, w_j)
        check_interp_pts = np.any(div_zero_idx)  # whether we are evaluating directly on some interp pts
        y = np.zeros(x_arr.shape[:-1] + (ydim,))  # (..., ydim)

        # Loop over multi-indices and compute tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for i, j in enumerate(itertools.product(*indices)):
            L_j = quotient[..., dims, j] / qsum  # (..., xdim)

            # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
            if check_interp_pts:
                other_pts = self._other_pts_mask(div_zero_idx, j)
                L_j[div_zero_idx[..., dims, j]] = 1
                L_j[np.any(other_pts, axis=-1)] = 0

            # Add multivariate basis polynomial contribution to interpolation output
            y += np.prod(L_j, axis=-1, keepdims=True) * yi_arr[i, :]

        return self._unpack_outputs(y, yi)

    def gradient(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the gradient/Jacobian at points `x` using the interpolator.

        :param x: `dict` mapping each input variable (same keys as the training inputs) to evaluation locations
        :param state: the current `LagrangeState`
        :param training_data: a tuple of `xi, yi` Datasets for the current state
        :returns: a Dataset mapping each output variable to its Jacobian, `(..., xdim)` for scalar outputs or
                  `(..., k, xdim)` for vector outputs
        """
        yi = training_data[1]
        x_arr, yi_arr = self._stack_data(x, training_data)
        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes, x_j, w_j = self._ragged_grids(state, xdim)

        # Compute values ahead of time that will be needed for the gradient
        diff, div_zero_idx, quotient, qsum = self._barycentric_terms(x_arr, x_j, w_j)
        check_interp_pts = np.any(div_zero_idx)
        sqsum = np.nansum(w_j / diff ** 2, axis=-1)  # (..., xdim)
        jac = np.zeros(x_arr.shape[:-1] + (ydim, xdim))  # (..., ydim, xdim)

        # Loop over multi-indices and compute derivative of tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for k, var in enumerate(grid_sizes):
            dims = [idx for idx in range(xdim) if idx != k]
            for i, j in enumerate(itertools.product(*indices)):
                j_dims = [j[idx] for idx in dims]
                L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-1)

                # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                other_pts = None
                if check_interp_pts:
                    other_pts = self._other_pts_mask(div_zero_idx, j)
                    L_j[div_zero_idx[..., dims, j_dims]] = 1
                    L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                # Partial derivative of the tensor-product basis polynomial with respect to x_k
                dLJ_dx = self._dL_dx(k, j, x_arr, x_j, w_j, diff, qsum, sqsum, div_zero_idx, other_pts,
                                     grid_sizes[var])
                dLJ_dx = np.expand_dims(dLJ_dx, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)

                # Add contribution to the Jacobian
                jac[..., k] += dLJ_dx * yi_arr[i, :]

        return self._unpack_outputs(jac, yi, trailing_dims=1)

    def hessian(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the Hessian at points `x` using the interpolator.

        :param x: `dict` mapping each input variable (same keys as the training inputs) to evaluation locations
        :param state: the current `LagrangeState`
        :param training_data: a tuple of `xi, yi` Datasets for the current state
        :returns: a Dataset mapping each output variable to its Hessian, `(..., xdim, xdim)` for scalar outputs or
                  `(..., k, xdim, xdim)` for vector outputs
        """
        yi = training_data[1]
        x_arr, yi_arr = self._stack_data(x, training_data)
        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes, x_j, w_j = self._ragged_grids(state, xdim)
        grid_size_list = list(grid_sizes.values())

        # Compute values ahead of time that will be needed for the Hessian
        diff, div_zero_idx, quotient, qsum = self._barycentric_terms(x_arr, x_j, w_j)
        check_interp_pts = np.any(div_zero_idx)
        qsum_p = -np.nansum(w_j / diff ** 2, axis=-1)  # (..., xdim)
        qsum_pp = 2 * np.nansum(w_j / diff ** 3, axis=-1)  # (..., xdim)
        sqsum = -qsum_p  # first-derivative helper expects +nansum(w_j / diff**2)

        # Loop over multi-indices and compute 2nd derivative of tensor-product lagrange polynomials
        hess = np.zeros(x_arr.shape[:-1] + (ydim, xdim, xdim))  # (..., ydim, xdim, xdim)
        indices = [range(s) for s in grid_size_list]
        for m in range(xdim):
            for n in range(m, xdim):
                dims = [idx for idx in range(xdim) if idx not in [m, n]]
                for i, j in enumerate(itertools.product(*indices)):
                    j_dims = [j[idx] for idx in dims]
                    L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., len(dims))

                    # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                    other_pts = None
                    if check_interp_pts:
                        other_pts = self._other_pts_mask(div_zero_idx, j)
                        L_j[div_zero_idx[..., dims, j_dims]] = 1
                        L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    if m != n:
                        # Cross-terms: product of the 1d first derivatives along x_m and x_n
                        d2LJ_dx2 = np.ones(x_arr.shape[:-1])
                        for k in [m, n]:
                            d2LJ_dx2 *= self._dL_dx(k, j, x_arr, x_j, w_j, diff, qsum, sqsum, div_zero_idx,
                                                    other_pts, grid_size_list[k])

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)
                        contrib = d2LJ_dx2 * yi_arr[i, :]
                        hess[..., m, n] += contrib
                        hess[..., n, m] += contrib
                    else:
                        # Diagonal terms: 1d second derivative along x_m
                        d2LJ_dx2 = self._d2L_dx2(m, j, x_arr, x_j, w_j, diff, qsum, qsum_p, qsum_pp,
                                                 div_zero_idx, other_pts, grid_size_list[m])
                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]

        return self._unpack_outputs(hess, yi, trailing_dims=2)