1"""Provides interpolator classes. Interpolators approximate the input → output mapping of a model given 

2a set of training data. The training data consists of input-output pairs, and the interpolator can be 

3refined with new training data. 

4 

5Includes: 

6 

7- `Interpolator`: Abstract class providing basic structure of an interpolator 

8- `Lagrange`: Concrete implementation for tensor-product barycentric Lagrange interpolation 

9- `Linear`: Concrete implementation for linear regression using `sklearn` 

10- `InterpolatorState`: Interface for a dataclass that stores the internal state of an interpolator 

11- `LagrangeState`: The internal state for a barycentric Lagrange polynomial interpolator 

12- `LinearState`: The internal state for a linear interpolator (using sklearn) 

13""" 

14from __future__ import annotations 

15 

16import copy 

17import itertools 

18from abc import ABC, abstractmethod 

19from dataclasses import dataclass, field 

20 

21import numpy as np 

22from sklearn import linear_model, preprocessing 

23from sklearn.pipeline import Pipeline 

24from sklearn.preprocessing import PolynomialFeatures 

25 

26from amisc.serialize import Base64Serializable, Serializable, StringSerializable 

27from amisc.typing import Dataset, MultiIndex 

28 

29__all__ = ["InterpolatorState", "LagrangeState", "LinearState", "Interpolator", "Lagrange", "Linear"] 

30 

31 

32class InterpolatorState(Serializable, ABC): 

33 """Interface for a dataclass that stores the internal state of an interpolator (e.g. weights and biases).""" 

34 pass 

35 

36 

37@dataclass 

38class LagrangeState(InterpolatorState, Base64Serializable): 

39 """The internal state for a barycentric Lagrange polynomial interpolator. 

40 

41 :ivar weights: the 1d interpolation grid weights 

42 :ivar x_grids: the 1d interpolation grids 

43 """ 

44 weights: dict[str, np.ndarray] = field(default_factory=dict) 

45 x_grids: dict[str, np.ndarray] = field(default_factory=dict) 

46 

47 def __eq__(self, other): 

48 if isinstance(other, LagrangeState): 

49 try: 

50 return all([np.allclose(self.weights[var], other.weights[var]) for var in self.weights]) and \ 

51 all([np.allclose(self.x_grids[var], other.x_grids[var]) for var in self.x_grids]) 

52 except IndexError: 

53 return False 

54 else: 

55 return False 

56 

57 

58@dataclass 

59class LinearState(InterpolatorState, Base64Serializable): 

60 """The internal state for a linear interpolator (using sklearn). 

61 

62 :ivar x_vars: the input variables in order 

63 :ivar y_vars: the output variables in order 

64 :ivar regressor: the sklearn regressor object, a pipeline that consists of a `PolynomialFeatures` and a model from 

65 `sklearn.linear_model`, i.e. Ridge, Lasso, etc. 

66 """ 

67 x_vars: list[str] = field(default_factory=list) 

68 y_vars: list[str] = field(default_factory=list) 

69 regressor: Pipeline = None 

70 

71 def __eq__(self, other): 

72 if isinstance(other, LinearState): 

73 return (self.x_vars == other.x_vars and self.y_vars == other.y_vars and 

74 np.allclose(self.regressor['poly'].powers_, other.regressor['poly'].powers_) and 

75 np.allclose(self.regressor['linear'].coef_, other.regressor['linear'].coef_) and 

76 np.allclose(self.regressor['linear'].intercept_, other.regressor['linear'].intercept_)) 

77 else: 

78 return False 

79 

80 

81class Interpolator(Serializable, ABC): 

82 """Interface for an interpolator object that approximates a model. An interpolator should: 

83 

84 - `refine` - take an old state and new training data and produce a new "refined" state (e.g. new weights/biases) 

85 - `predict` - interpolate from the training data to a new set of points (i.e. approximate the underlying model) 

86 - `gradient` - compute the grdient/Jacobian at new points (if you want) 

87 - `hessian` - compute the 2nd derivative/Hessian at new points (if you want) 

88 

89 Currently, `Lagrange` and `Linear` interpolators are supported and can be constructed from a configuration `dict` 

90 via `Interpolator.from_dict()`. 
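
    !!! Example
        A minimal sketch of config-based construction; any keys besides `method` are passed through to the
        chosen interpolator's constructor (they map to each interpolator's dataclass fields):

        ```python
        interp = Interpolator.from_dict({'method': 'lagrange', 'interval_capacity': 4.0})
        interp = Interpolator.from_dict({'method': 'linear', 'regressor': 'Lasso'})
        ```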

91 """ 

92 

93 @abstractmethod 

94 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset], 

95 old_state: InterpolatorState, input_domains: dict[str, tuple]) -> InterpolatorState: 

96 """Refine the interpolator state with new training data. 

97 

98 :param beta: a multi-index specifying the fidelity "levels" of the new interpolator state (starts at (0,... 0)) 

99 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data 

100 :param old_state: the previous state of the interpolator (None if initializing the first state) 

101 :param input_domains: a `dict` mapping input variables to their corresponding domains 

102 :returns: the new "refined" interpolator state 

103 """ 

104 raise NotImplementedError 

105 

106 @abstractmethod 

107 def predict(self, x: dict | Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

108 """Interpolate the output of the model at points `x` using the given state and training data 

109 

110 :param x: the input Dataset `dict` mapping input variables to locations at which to compute the interpolator 

111 :param state: the current state of the interpolator 

112 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

113 :returns: a Dataset `dict` mapping output variables to interpolator outputs 

114 """ 

115 raise NotImplementedError 

116 

117 def __call__(self, *args, **kwargs): 

118 return self.predict(*args, **kwargs) 

119 

120 @abstractmethod 

121 def gradient(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

122 """Evaluate the gradient/Jacobian at points `x` using the interpolator. 

123 

124 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Jacobian 

125 :param state: the current state of the interpolator 

126 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

127 :returns: a Dataset `dict` mapping output variables to Jacobian matrices of shape `(ydim, xdim)` -- for 

128 scalar outputs, the Jacobian is returned as `(xdim,)` 

129 """ 

130 raise NotImplementedError 

131 

132 @abstractmethod 

133 def hessian(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset: 

134 """Evaluate the Hessian at points `x` using the interpolator. 

135 

136 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Hessian 

137 :param state: the current state of the interpolator 

138 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state 

139 :returns: a Dataset `dict` mapping output variables to Hessian matrices of shape `(xdim, xdim)` 

140 """ 

141 raise NotImplementedError 

142 

143 @classmethod 

144 def from_dict(cls, config: dict) -> Interpolator: 

145 """Create an `Interpolator` object from a `dict` config. Only `method='lagrange'` is supported for now.""" 

146 method = config.pop('method', 'lagrange').lower() 

147 match method: 

148 case 'lagrange': 

149 return Lagrange(**config) 

150 case 'linear': 

151 return Linear(**config) 

152 case other: 

153 raise NotImplementedError(f"Unknown interpolator method: {other}") 

154 

155 

156@dataclass 

157class Lagrange(Interpolator, StringSerializable): 

158 """Implementation of a tensor-product barycentric Lagrange polynomial interpolator. A `LagrangeState` stores 

159 the 1d interpolation grids and weights for each input dimension. `Lagrange` computes the tensor-product 

160 of 1d Lagrange polynomials to approximate a multi-variate function. 

161 

162 :ivar interval_capacity: tuning knob for Lagrange interpolation (see Berrut and Trefethen 2004) 
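
    !!! Example
        A minimal single-input sketch of building and evaluating the interpolator (the data here are purely
        illustrative):

        ```python
        import numpy as np

        interp = Lagrange()
        xi = {'x': np.linspace(0, 1, 5)}       # training inputs
        yi = {'y': np.sin(np.pi * xi['x'])}    # training outputs
        state = interp.refine((), (xi, yi), None, {'x': (0, 1)})  # beta is unused by Lagrange
        y_pred = interp.predict({'x': np.array([0.25, 0.75])}, state, (xi, yi))
        ```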

163 """ 

164 interval_capacity: float = 4.0 

165 

166 @staticmethod 

167 def _extend_grids(x_grids: dict[str, np.ndarray], x_points: dict[str, np.ndarray]): 

168 """Extend the 1d `x` grids with any new points from `x_points`, skipping duplicates. This will preserve the 

169 order of new points in the extended grid without duplication. This will maintain the same order as 

170 `SparseGrid.x_grids` if that is the underlying training data structure. 

171 

172 !!! Example 

173 ```python 

174 x = {'x': np.array([0, 1, 2])} 

175 new_x = {'x': np.array([3, 0, 2, 3, 1, 4])} 

176 extended_x = Lagrange._extend_grids(x, new_x) 

177 # gives {'x': np.array([0, 1, 2, 3, 4])} 

178 ``` 

179 

180 !!! Warning 

181 This will only work for 1d grids; all `x_grids` should be scalar quantities. Field quantities should 

182 already be passed in as several separate 1d latent coefficients. 

183 

184 :param x_grids: the current 1d interpolation grids 

185 :param x_points: the new points to extend the interpolation grids with 

186 :returns: the extended grids 

187 """ 

188 extended_grids = copy.deepcopy(x_grids) 

189 for var, new_pts in x_points.items(): 

190 # Get unique new values that are not already in the grid (maintain their order; keep only first one) 

191 u = x_grids[var] if var in x_grids else np.array([]) 

192 u, ind = np.unique(new_pts[~np.isin(new_pts, u)], return_index=True) 

193 u = u[np.argsort(ind)] 

194 extended_grids[var] = u if var not in x_grids else np.concatenate((x_grids[var], u), axis=0) 

195 

196 return extended_grids 

197 

198 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset], 

199 old_state: LagrangeState, input_domains: dict[str, tuple]) -> LagrangeState: 

200 """Refine the interpolator state with new training data. 

201 

202 :param beta: the refinement level indices for the interpolator (not used for `Lagrange`) 

203 :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`) 

204 :param old_state: the old interpolator state to refine (None if initializing) 

205 :param input_domains: a `dict` of each input variable's domain; input keys should match `xtrain` keys 

206 :returns: the new interpolator state 
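
        !!! Example
            A rough sketch of incremental refinement (illustrative data); passing the previous state reuses the
            existing grids and only incrementally updates the barycentric weights for the newly added points:

            ```python
            import numpy as np

            interp, domains = Lagrange(), {'x': (0, 1)}
            xi = {'x': np.linspace(0, 1, 5)}
            yi = {'y': xi['x'] ** 2}
            state = interp.refine((), (xi, yi), None, domains)    # initialize a first state
            xi = {'x': np.concatenate([xi['x'], np.array([0.05, 0.95])])}
            yi = {'y': xi['x'] ** 2}
            state = interp.refine((), (xi, yi), state, domains)   # refine with the extended data
            ```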

207 """ 

208 xtrain, ytrain = training_data # Lagrange only really needs the xtrain data to update barycentric weights/grids 

209 

210 # Initialize the interpolator state 

211 if old_state is None: 

212 x_grids = self._extend_grids({}, xtrain) 

213 weights = {} 

214 for var, grid in x_grids.items(): 

215 bds = input_domains[var] 

216 Nx = grid.shape[0] 

217 C = (bds[1] - bds[0]) / self.interval_capacity # Interval capacity (see Berrut and Trefethen 2004) 

218 xj = grid.reshape((Nx, 1)) 

219 xi = xj.reshape((1, Nx)) 

220 dist = (xj - xi) / C 

221 np.fill_diagonal(dist, 1) # Ignore product when i==j 

222 weights[var] = (1.0 / np.prod(dist, axis=1)) # (Nx,) 

223 

224 # Otherwise, refine the interpolator state 

225 else: 

226 x_grids = self._extend_grids(old_state.x_grids, xtrain) 

227 weights = copy.deepcopy(old_state.weights) 

228 for var, grid in x_grids.items(): 

229 bds = input_domains[var] 

230 Nx_old = old_state.x_grids[var].shape[0] 

231 Nx_new = grid.shape[0] 

232 if Nx_new > Nx_old: 

233 weights[var] = np.pad(weights[var], [(0, Nx_new - Nx_old)], mode='constant', constant_values=np.nan) 

234 C = (bds[1] - bds[0]) / self.interval_capacity 
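                    # Incrementally update the barycentric weights for each newly appended node x_new
                    # (Berrut and Trefethen 2004): w_i *= C / (x_i - x_new) for the existing nodes, then
                    # w_new = prod_i [C / (x_new - x_i)] over all previous nodes i.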

                    for j in range(Nx_old, Nx_new):
                        weights[var][:j] *= (C / (grid[:j] - grid[j]))
                        weights[var][j] = np.prod(C / (grid[j] - grid[:j]))

        return LagrangeState(weights=weights, x_grids=x_grids)

    def predict(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Predict the output of the model at points `x` with barycentric Lagrange interpolation."""

        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        # Inputs `x` may come in unordered, but they should get realigned with the internal `x_grids` state
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())
        dims = list(range(xdim))

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0  # whether we are evaluating directly on some interp pts
        diff[div_zero_idx] = 1
        quotient = w_j / diff                    # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)      # (..., xdim)
        y = np.zeros(x_arr.shape[:-1] + (ydim,))  # (..., ydim)

        # Loop over multi-indices and compute tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for i, j in enumerate(itertools.product(*indices)):
            L_j = quotient[..., dims, j] / qsum  # (..., xdim)

            # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
            if check_interp_pts:
                other_pts = np.copy(div_zero_idx)
                other_pts[div_zero_idx[..., dims, j]] = False
                L_j[div_zero_idx[..., dims, j]] = 1
                L_j[np.any(other_pts, axis=-1)] = 0

            # Add multivariate basis polynomial contribution to interpolation output
            y += np.prod(L_j, axis=-1, keepdims=True) * yi_arr[i, :]

        # Unpack the outputs back into a Dataset
        y_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            y_ret[var] = y[..., start_idx:end_idx]
            if len(arr.shape) == 1:
                y_ret[var] = np.squeeze(y_ret[var], axis=-1)  # for scalars
            start_idx = end_idx
        return y_ret

    def gradient(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the gradient/Jacobian at points `x` using the interpolator."""

        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the gradient
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                            # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)              # (..., xdim)
        sqsum = np.nansum(w_j / diff ** 2, axis=-1)      # (..., xdim)
        jac = np.zeros(x_arr.shape[:-1] + (ydim, xdim))  # (..., ydim, xdim)

        # Loop over multi-indices and compute derivative of tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for k, var in enumerate(grid_sizes):
            dims = [idx for idx in range(xdim) if idx != k]
            for i, j in enumerate(itertools.product(*indices)):
                j_dims = [j[idx] for idx in dims]
                L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-1)

                # Partial derivative of L_j with respect to x_k
                dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                          (sqsum[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                if check_interp_pts:
                    other_pts = np.copy(div_zero_idx)
                    other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                    L_j[div_zero_idx[..., dims, j_dims]] = 1
                    L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                    p_idx = [idx for idx in range(grid_sizes[var]) if idx != j[k]]
                    w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                    curr_j_idx = div_zero_idx[..., k, j[k]]
                    other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                    dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                    (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]), axis=-1)
                    dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                           (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                dLJ_dx = np.expand_dims(dLJ_dx, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)

                # Add contribution to the Jacobian
                jac[..., k] += dLJ_dx * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (array of length xdim for each y_var giving partial derivatives)
        jac_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            jac_ret[var] = jac[..., start_idx:end_idx, :]  # (..., ydim, xdim)
            if len(arr.shape) == 1:
                jac_ret[var] = np.squeeze(jac_ret[var], axis=-2)  # for scalars: (..., xdim) partial derivatives
            start_idx = end_idx
        return jac_ret

    def hessian(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the Hessian at points `x` using the interpolator."""

        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        grid_size_list = list(grid_sizes.values())
        max_size = max(grid_size_list)

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the gradient
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                              # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)                # (..., xdim)
        qsum_p = -np.nansum(w_j / diff ** 2, axis=-1)      # (..., xdim)
        qsum_pp = 2 * np.nansum(w_j / diff ** 3, axis=-1)  # (..., xdim)

        # Loop over multi-indices and compute 2nd derivative of tensor-product lagrange polynomials
        hess = np.zeros(x_arr.shape[:-1] + (ydim, xdim, xdim))  # (..., ydim, xdim, xdim)
        indices = [range(s) for s in grid_size_list]
        for m in range(xdim):
            for n in range(m, xdim):
                dims = [idx for idx in range(xdim) if idx not in [m, n]]
                for i, j in enumerate(itertools.product(*indices)):
                    j_dims = [j[idx] for idx in dims]
                    L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-2)

                    # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                    if check_interp_pts:
                        other_pts = np.copy(div_zero_idx)
                        other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                        L_j[div_zero_idx[..., dims, j_dims]] = 1
                        L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Cross-terms in Hessian
                    if m != n:
                        # Partial derivative of L_j with respect to x_m and x_n
                        d2LJ_dx2 = np.ones(x_arr.shape[:-1])
                        for k in [m, n]:
                            dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                                      (-qsum_p[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                            # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                            if check_interp_pts:
                                p_idx = [idx for idx in range(grid_size_list[k]) if idx != j[k]]
                                w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                curr_j_idx = div_zero_idx[..., k, j[k]]
                                other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                                dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                                (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]),
                                                                axis=-1)
                                dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                                       (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                            d2LJ_dx2 *= dLJ_dx

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]
                        hess[..., n, m] += d2LJ_dx2 * yi_arr[i, :]

                    # Diagonal terms in Hessian:
                    else:
                        front_term = w_j[m, j[m]] / (qsum[..., m] * diff[..., m, j[m]])
                        first_term = (-qsum_pp[..., m] / qsum[..., m]) + 2 * (qsum_p[..., m] / qsum[..., m]) ** 2
                        second_term = (2 * (qsum_p[..., m] / (qsum[..., m] * diff[..., m, j[m]]))
                                       + 2 / diff[..., m, j[m]] ** 2)
                        d2LJ_dx2 = front_term * (first_term + second_term)

                        # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                        if check_interp_pts:
                            curr_j_idx = div_zero_idx[..., m, j[m]]
                            other_j_idx = np.any(other_pts[..., m, :], axis=-1)
                            if np.any(curr_j_idx) or np.any(other_j_idx):
                                p_idx = [idx for idx in range(grid_size_list[m]) if idx != j[m]]
                                w_j_large = np.broadcast_to(w_j[m, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                x_j_large = np.broadcast_to(x_j[m, :], x_arr.shape[:-1] + x_j.shape[-1:]).copy()

                                # if these points are at the current j interpolation point
                                d2LJ_dx2[curr_j_idx] = (2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]),  # noqa: E501
                                                                      axis=-1) ** 2 +
                                                        2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]) ** 2,  # noqa: E501
                                                                      axis=-1))

                                # if these points are at any other interpolation point
                                other_pts_inv = other_pts.copy()
                                other_pts_inv[other_j_idx, m, :grid_size_list[m]] = np.invert(
                                    other_pts[other_j_idx, m, :grid_size_list[m]])  # noqa: E501
                                curr_x_j = x_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_x_j = x_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_w_j = w_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_w_j = w_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_div = w_j[m, j[m]] / np.squeeze(curr_w_j, axis=-1)
                                curr_diff = np.squeeze(curr_x_j, axis=-1) - x_j[m, j[m]]
                                d2LJ_dx2[other_j_idx] = ((-2 * curr_div / curr_diff) * (np.nansum(
                                    (other_w_j / curr_w_j) / (curr_x_j - other_x_j), axis=-1) + 1 / curr_diff))

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (matrix (xdim, xdim) for each y_var giving 2nd partial derivatives)
        hess_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            hess_ret[var] = hess[..., start_idx:end_idx, :, :]  # (..., ydim, xdim, xdim)
            if len(arr.shape) == 1:
                hess_ret[var] = np.squeeze(hess_ret[var], axis=-3)  # for scalars: (..., xdim, xdim) partial derivatives
            start_idx = end_idx
        return hess_ret


@dataclass
class Linear(Interpolator, StringSerializable):
    """Implementation of linear regression using `sklearn`. The `Linear` interpolator uses a pipeline of
    `PolynomialFeatures` and a linear model from `sklearn.linear_model` to approximate the input-output mapping
    with a linear combination of polynomial features. Defaults to Ridge regression (L2 regularization) with
    polynomials of degree 1 (i.e. normal linear regression).

    :ivar regressor: the scikit-learn linear model to use (e.g. 'Ridge', 'Lasso', 'ElasticNet', etc.).
    :ivar scaler: the scikit-learn preprocessing scaler to use (e.g. 'MinMaxScaler', 'StandardScaler', etc.). If None,
        no scaling is applied (default).
    :ivar regressor_opts: options to pass to the regressor constructor
        (see [scikit-learn](https://scikit-learn.org/stable/) documentation).
    :ivar scaler_opts: options to pass to the scaler constructor
    :ivar polynomial_opts: options to pass to the `PolynomialFeatures` constructor (e.g. 'degree', 'include_bias').
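
    !!! Example
        A minimal sketch of fitting and evaluating a quadratic model (illustrative data only):

        ```python
        import numpy as np

        interp = Linear(regressor='Ridge', polynomial_opts={'degree': 2, 'include_bias': False})
        xi = {'x': np.linspace(0, 1, 20)}
        yi = {'y': 2 * xi['x'] ** 2 + 1}
        state = interp.refine((), (xi, yi), None, None)    # beta=() keeps the configured degree
        y_pred = interp.predict({'x': np.array([0.5])}, state, (xi, yi))
        ```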

513 """ 

514 regressor: str = 'Ridge' 

515 scaler: str = None 

516 regressor_opts: dict = field(default_factory=dict) 

517 scaler_opts: dict = field(default_factory=dict) 

518 polynomial_opts: dict = field(default_factory=lambda: {'degree': 1, 'include_bias': False}) 

519 

520 def __post_init__(self): 

521 try: 

522 getattr(linear_model, self.regressor) 

523 except AttributeError: 

524 raise ImportError(f"Regressor '{self.regressor}' not found in sklearn.linear_model") 

525 

526 if self.scaler is not None: 

527 try: 

528 getattr(preprocessing, self.scaler) 

529 except AttributeError: 

530 raise ImportError(f"Scaler '{self.scaler}' not found in sklearn.preprocessing") 

531 

532 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset], 

533 old_state: LinearState, input_domains: dict[str, tuple]) -> InterpolatorState: 

534 """Train a new linear regression model. 

535 

536 :param beta: if not empty, then the first element is the number of degrees to add to the polynomial features. 

537 For example, if `beta=(1,)`, then the polynomial degree will be increased by 1. If the degree 

538 is already set to 1 in `polynomial_opts` (default), then the new degree will be 2. 

539 :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`) 

540 :param old_state: the old linear state to refine (only used to get the order of input/output variables) 

541 :param input_domains: (not used for `Linear`) 

542 :returns: the new linear state 
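
        !!! Example
            A small sketch of `beta` raising the polynomial degree (illustrative data only):

            ```python
            import numpy as np

            xtrain = {'x': np.linspace(0, 1, 10)}
            ytrain = {'y': xtrain['x'] ** 2}
            state = Linear().refine((1,), (xtrain, ytrain), None, None)  # default degree 1 + 1 -> degree-2 features
            ```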

543 """ 

544 polynomial_opts = self.polynomial_opts.copy() 

545 degree = polynomial_opts.pop('degree', 1) 

546 if beta != (): 

547 degree += beta[0] 

548 

549 pipe = [] 

550 if self.scaler is not None: 

551 pipe.append(('scaler', getattr(preprocessing, self.scaler)(**self.scaler_opts))) 

552 pipe.extend([('poly', PolynomialFeatures(degree=degree, **polynomial_opts)), 

553 ('linear', getattr(linear_model, self.regressor)(**self.regressor_opts))]) 

554 regressor = Pipeline(pipe) 

555 

556 xtrain, ytrain = training_data 

557 

558 # Get order of variables for inputs and outputs 

559 if old_state is not None: 

560 x_vars = old_state.x_vars 

561 y_vars = old_state.y_vars 

562 else: 

563 x_vars = list(xtrain.keys()) 

564 y_vars = list(ytrain.keys()) 

565 

566 # Convert to (N, xdim) and (N, ydim) arrays 

567 x_arr = np.concatenate([xtrain[var][..., np.newaxis] for var in x_vars], axis=-1) 

568 y_arr = np.concatenate([ytrain[var][..., np.newaxis] for var in y_vars], axis=-1) 

569 

570 regressor.fit(x_arr, y_arr) 

571 

572 return LinearState(regressor=regressor, x_vars=x_vars, y_vars=y_vars) 

573 

574 def predict(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]): 

575 """Predict the output of the model at points `x` using the linear regressor provided in `state`. 

576 

577 :param x: the input Dataset `dict` mapping input variables to prediction locations 

578 :param state: the state containing the linear regressor to use 

579 :param training_data: not used for `Linear` (since the regressor is already trained in `state`) 

580 """ 

581 # Convert to (N, xdim) array for sklearn 

582 x_arr = np.concatenate([x[var][..., np.newaxis] for var in state.x_vars], axis=-1) 

583 loop_shape = x_arr.shape[:-1] 

584 x_arr = x_arr.reshape((-1, x_arr.shape[-1])) 

585 

586 y_arr = state.regressor.predict(x_arr) 

587 y_arr = y_arr.reshape(loop_shape + (len(state.y_vars),)) # (..., ydim) 

588 

589 # Unpack the outputs back into a Dataset 

590 return {var: y_arr[..., i] for i, var in enumerate(state.y_vars)} 

591 

592 def gradient(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]): 

593 raise NotImplementedError 

594 

595 def hessian(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]): 

596 raise NotImplementedError