Coverage for src/amisc/interpolator.py: 92% (394 statements)
1"""Provides interpolator classes. Interpolators approximate the input → output mapping of a model given
2a set of training data. The training data consists of input-output pairs, and the interpolator can be
3refined with new training data.
5Includes:
7- `Interpolator`: Abstract class providing basic structure of an interpolator
8- `Lagrange`: Concrete implementation for tensor-product barycentric Lagrange interpolation
9- `Linear`: Concrete implementation for linear regression using `sklearn`
10- `GPR`: Concrete implementation for Gaussian process regression using `sklearn`
11- `InterpolatorState`: Interface for a dataclass that stores the internal state of an interpolator
12- `LagrangeState`: The internal state for a barycentric Lagrange polynomial interpolator
13- `LinearState`: The internal state for a linear interpolator (using sklearn)
14- `GPRState`: The internal state for a Gaussian process regression interpolator (using sklearn)
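
!!! Example
    A minimal usage sketch (the quadratic test model, variable names, and domain below are illustrative
    assumptions, not library defaults): build an interpolator from a config `dict`, train it on a few points,
    then evaluate at new locations.

    ```python
    import numpy as np

    from amisc.interpolator import Interpolator

    interp = Interpolator.from_dict({'method': 'lagrange'})
    xtrain = {'x': np.linspace(0, 1, 5)}   # input training data
    ytrain = {'y': xtrain['x'] ** 2}       # output training data
    state = interp.refine((0,), (xtrain, ytrain), None, input_domains={'x': (0, 1)})
    y_hat = interp.predict({'x': np.array([0.3, 0.7])}, state, (xtrain, ytrain))
    ```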
15"""
16from __future__ import annotations
18import copy
19import itertools
20from abc import ABC, abstractmethod
21from dataclasses import dataclass, field
23import numpy as np
24from sklearn import linear_model, preprocessing
25from sklearn.gaussian_process import GaussianProcessRegressor, kernels
26from sklearn.pipeline import Pipeline
27from sklearn.preprocessing import PolynomialFeatures
29from amisc.serialize import Base64Serializable, Serializable, StringSerializable
30from amisc.typing import Dataset, MultiIndex
32__all__ = ["InterpolatorState", "LagrangeState", "LinearState", "GPRState", "Interpolator", "Lagrange", "Linear", "GPR"]


class InterpolatorState(Serializable, ABC):
    """Interface for a dataclass that stores the internal state of an interpolator (e.g. weights and biases)."""
    pass


@dataclass
class LagrangeState(InterpolatorState, Base64Serializable):
    """The internal state for a barycentric Lagrange polynomial interpolator.

    :ivar weights: the 1d interpolation grid weights
    :ivar x_grids: the 1d interpolation grids
    """
    weights: dict[str, np.ndarray] = field(default_factory=dict)
    x_grids: dict[str, np.ndarray] = field(default_factory=dict)

    def __eq__(self, other):
        if isinstance(other, LagrangeState):
            try:
                return all([np.allclose(self.weights[var], other.weights[var]) for var in self.weights]) and \
                    all([np.allclose(self.x_grids[var], other.x_grids[var]) for var in self.x_grids])
            except IndexError:
                return False
        else:
            return False


@dataclass
class LinearState(InterpolatorState, Base64Serializable):
    """The internal state for a linear interpolator (using sklearn).

    :ivar x_vars: the input variables in order
    :ivar y_vars: the output variables in order
    :ivar regressor: the sklearn regressor object, a pipeline that consists of a `PolynomialFeatures` and a model from
        `sklearn.linear_model`, i.e. Ridge, Lasso, etc.
    """
    x_vars: list[str] = field(default_factory=list)
    y_vars: list[str] = field(default_factory=list)
    regressor: Pipeline = None

    def __eq__(self, other):
        if isinstance(other, LinearState):
            return (self.x_vars == other.x_vars and self.y_vars == other.y_vars and
                    np.allclose(self.regressor['poly'].powers_, other.regressor['poly'].powers_) and
                    np.allclose(self.regressor['linear'].coef_, other.regressor['linear'].coef_) and
                    np.allclose(self.regressor['linear'].intercept_, other.regressor['linear'].intercept_))
        else:
            return False


@dataclass
class GPRState(InterpolatorState, Base64Serializable):
    """The internal state for a Gaussian process regression interpolator (using sklearn).

    :ivar x_vars: the input variables in order
    :ivar y_vars: the output variables in order
    :ivar regressor: the sklearn regressor object, a pipeline that consists of a preprocessing scaler and a
        `GaussianProcessRegressor`.
    """
    x_vars: list[str] = field(default_factory=list)
    y_vars: list[str] = field(default_factory=list)
    regressor: Pipeline = None

    def __eq__(self, other):
        if isinstance(other, GPRState):
            return (self.x_vars == other.x_vars and self.y_vars == other.y_vars and
                    len(self.regressor.steps) == len(other.regressor.steps) and
                    self.regressor['gpr'].alpha == other.regressor['gpr'].alpha and
                    self.regressor['gpr'].kernel_ == other.regressor['gpr'].kernel_)
        else:
            return False


class Interpolator(Serializable, ABC):
    """Interface for an interpolator object that approximates a model. An interpolator should:

    - `refine` - take an old state and new training data and produce a new "refined" state (e.g. new weights/biases)
    - `predict` - interpolate from the training data to a new set of points (i.e. approximate the underlying model)
    - `gradient` - compute the gradient/Jacobian at new points (optional)
    - `hessian` - compute the 2nd derivative/Hessian at new points (optional)

    Currently, `Lagrange`, `Linear`, and `GPR` interpolators are supported and can be constructed from a configuration
    `dict` via `Interpolator.from_dict()`.
    """

    @abstractmethod
    def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
               old_state: InterpolatorState, input_domains: dict[str, tuple]) -> InterpolatorState:
        """Refine the interpolator state with new training data.

        :param beta: a multi-index specifying the fidelity "levels" of the new interpolator state (starts at (0, ..., 0))
        :param training_data: a tuple of `xi, yi` Datasets for the input/output training data
        :param old_state: the previous state of the interpolator (None if initializing the first state)
        :param input_domains: a `dict` mapping input variables to their corresponding domains
        :returns: the new "refined" interpolator state
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, x: dict | Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
        """Interpolate the output of the model at points `x` using the given state and training data.

        :param x: the input Dataset `dict` mapping input variables to locations at which to compute the interpolator
        :param state: the current state of the interpolator
        :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
        :returns: a Dataset `dict` mapping output variables to interpolator outputs
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.predict(*args, **kwargs)

    @abstractmethod
    def gradient(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
        """Evaluate the gradient/Jacobian at points `x` using the interpolator.

        :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Jacobian
        :param state: the current state of the interpolator
        :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
        :returns: a Dataset `dict` mapping output variables to Jacobian matrices of shape `(ydim, xdim)` -- for
            scalar outputs, the Jacobian is returned as `(xdim,)`
        """
        raise NotImplementedError

    @abstractmethod
    def hessian(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
        """Evaluate the Hessian at points `x` using the interpolator.

        :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Hessian
        :param state: the current state of the interpolator
        :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
        :returns: a Dataset `dict` mapping output variables to Hessian matrices of shape `(xdim, xdim)`
        """
        raise NotImplementedError

    @classmethod
    def from_dict(cls, config: dict) -> Interpolator:
        """Create an `Interpolator` object from a `dict` config. Available methods are `lagrange`, `linear`, and `gpr`.
        If the method is not one of these, an attempt is made to look it up by name in the `amisc.interpolator` module.

        :param config: a `dict` containing the configuration for the interpolator, with the `method` key specifying the
            name of the interpolator method to use, and the rest of the keys are options for the method
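
        !!! Example
            A sketch of constructing interpolators from config dicts (the option values shown here are
            illustrative, not defaults):

            ```python
            interp = Interpolator.from_dict({'method': 'gpr', 'kernel': 'Matern'})
            interp = Interpolator.from_dict({'method': 'linear', 'regressor': 'Lasso',
                                             'polynomial_opts': {'degree': 2}})
            ```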
176 """
177 method = config.pop('method', 'lagrange')
178 match method.lower():
179 case 'lagrange':
180 return Lagrange(**config)
181 case 'linear':
182 return Linear(**config)
183 case 'gpr':
184 return GPR(**config)
185 case _:
186 import amisc.interpolator
188 if hasattr(amisc.interpolator, method):
189 return getattr(amisc.interpolator, method)(**config)
191 raise NotImplementedError(f"Unknown interpolator method: {method}")


@dataclass
class Lagrange(Interpolator, StringSerializable):
    """Implementation of a tensor-product barycentric Lagrange polynomial interpolator. A `LagrangeState` stores
    the 1d interpolation grids and weights for each input dimension. `Lagrange` computes the tensor-product
    of 1d Lagrange polynomials to approximate a multi-variate function.

    :ivar interval_capacity: tuning knob for Lagrange interpolation (see Berrut and Trefethen 2004)
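
    !!! Note
        As a reference for the implementation (following Berrut and Trefethen 2004), the 1d interpolant in each
        dimension is evaluated in the second (true) barycentric form
        `p(x) = sum_j [w_j / (x - x_j)] * y_j / sum_j [w_j / (x - x_j)]`,
        and `predict` multiplies the 1d basis terms across dimensions to form the tensor-product interpolant.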
201 """
202 interval_capacity: float = 4.0

    @staticmethod
    def _extend_grids(x_grids: dict[str, np.ndarray], x_points: dict[str, np.ndarray]):
        """Extend the 1d `x` grids with any new points from `x_points`, skipping duplicates. This will preserve the
        order of new points in the extended grid without duplication. This will maintain the same order as
        `SparseGrid.x_grids` if that is the underlying training data structure.

        !!! Example
            ```python
            x = {'x': np.array([0, 1, 2])}
            new_x = {'x': np.array([3, 0, 2, 3, 1, 4])}
            extended_x = Lagrange._extend_grids(x, new_x)
            # gives {'x': np.array([0, 1, 2, 3, 4])}
            ```

        !!! Warning
            This will only work for 1d grids; all `x_grids` should be scalar quantities. Field quantities should
            already be passed in as several separate 1d latent coefficients.

        :param x_grids: the current 1d interpolation grids
        :param x_points: the new points to extend the interpolation grids with
        :returns: the extended grids
        """
        extended_grids = copy.deepcopy(x_grids)
        for var, new_pts in x_points.items():
            # Get unique new values that are not already in the grid (maintain their order; keep only first one)
            u = x_grids[var] if var in x_grids else np.array([])
            u, ind = np.unique(new_pts[~np.isin(new_pts, u)], return_index=True)
            u = u[np.argsort(ind)]
            extended_grids[var] = u if var not in x_grids else np.concatenate((x_grids[var], u), axis=0)

        return extended_grids

    def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
               old_state: LagrangeState, input_domains: dict[str, tuple]) -> LagrangeState:
        """Refine the interpolator state with new training data.

        :param beta: the refinement level indices for the interpolator (not used for `Lagrange`)
        :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
        :param old_state: the old interpolator state to refine (None if initializing)
        :param input_domains: a `dict` of each input variable's domain; input keys should match `xtrain` keys
        :returns: the new interpolator state
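
        !!! Note
            For a 1d grid `{x_i}` on a domain of length `L`, the weights computed here are
            `w_j = prod_{i != j} C / (x_j - x_i)` with `C = L / interval_capacity`; this capacity scaling follows
            Berrut and Trefethen (2004) and guards against overflow/underflow for larger grids. When refining,
            the old weights are reused and only the factors from the newly added grid points are applied.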
245 """
246 xtrain, ytrain = training_data # Lagrange only really needs the xtrain data to update barycentric weights/grids
248 # Initialize the interpolator state
249 if old_state is None:
250 x_grids = self._extend_grids({}, xtrain)
251 weights = {}
252 for var, grid in x_grids.items():
253 bds = input_domains[var]
254 Nx = grid.shape[0]
255 C = (bds[1] - bds[0]) / self.interval_capacity # Interval capacity (see Berrut and Trefethen 2004)
256 xj = grid.reshape((Nx, 1))
257 xi = xj.reshape((1, Nx))
258 dist = (xj - xi) / C
259 np.fill_diagonal(dist, 1) # Ignore product when i==j
260 weights[var] = (1.0 / np.prod(dist, axis=1)) # (Nx,)
262 # Otherwise, refine the interpolator state
263 else:
264 x_grids = self._extend_grids(old_state.x_grids, xtrain)
265 weights = copy.deepcopy(old_state.weights)
266 for var, grid in x_grids.items():
267 bds = input_domains[var]
268 Nx_old = old_state.x_grids[var].shape[0]
269 Nx_new = grid.shape[0]
270 if Nx_new > Nx_old:
271 weights[var] = np.pad(weights[var], [(0, Nx_new - Nx_old)], mode='constant', constant_values=np.nan)
272 C = (bds[1] - bds[0]) / self.interval_capacity
273 for j in range(Nx_old, Nx_new):
274 weights[var][:j] *= (C / (grid[:j] - grid[j]))
275 weights[var][j] = np.prod(C / (grid[j] - grid[:j]))
277 return LagrangeState(weights=weights, x_grids=x_grids)

    def predict(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Predict the output of the model at points `x` with barycentric Lagrange interpolation."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        # Inputs `x` may come in unordered, but they should get realigned with the internal `x_grids` state
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())
        dims = list(range(xdim))

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0  # whether we are evaluating directly on some interp pts
        diff[div_zero_idx] = 1
        quotient = w_j / diff                # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)  # (..., xdim)
        y = np.zeros(x_arr.shape[:-1] + (ydim,))  # (..., ydim)

        # Loop over multi-indices and compute tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for i, j in enumerate(itertools.product(*indices)):
            L_j = quotient[..., dims, j] / qsum  # (..., xdim)

            # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
            if check_interp_pts:
                other_pts = np.copy(div_zero_idx)
                other_pts[div_zero_idx[..., dims, j]] = False
                L_j[div_zero_idx[..., dims, j]] = 1
                L_j[np.any(other_pts, axis=-1)] = 0

            # Add multivariate basis polynomial contribution to interpolation output
            y += np.prod(L_j, axis=-1, keepdims=True) * yi_arr[i, :]

        # Unpack the outputs back into a Dataset
        y_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            y_ret[var] = y[..., start_idx:end_idx]
            if len(arr.shape) == 1:
                y_ret[var] = np.squeeze(y_ret[var], axis=-1)  # for scalars
            start_idx = end_idx
        return y_ret

    def gradient(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the gradient/Jacobian at points `x` using the interpolator."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the gradient
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                        # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)          # (..., xdim)
        sqsum = np.nansum(w_j / diff ** 2, axis=-1)  # (..., xdim)
        jac = np.zeros(x_arr.shape[:-1] + (ydim, xdim))  # (..., ydim, xdim)

        # Loop over multi-indices and compute derivative of tensor-product lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for k, var in enumerate(grid_sizes):
            dims = [idx for idx in range(xdim) if idx != k]
            for i, j in enumerate(itertools.product(*indices)):
                j_dims = [j[idx] for idx in dims]
                L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-1)

                # Partial derivative of L_j with respect to x_k
                dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                          (sqsum[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                if check_interp_pts:
                    other_pts = np.copy(div_zero_idx)
                    other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                    L_j[div_zero_idx[..., dims, j_dims]] = 1
                    L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                    p_idx = [idx for idx in range(grid_sizes[var]) if idx != j[k]]
                    w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                    curr_j_idx = div_zero_idx[..., k, j[k]]
                    other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                    dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                    (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]), axis=-1)
                    dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                           (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                dLJ_dx = np.expand_dims(dLJ_dx, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)

                # Add contribution to the Jacobian
                jac[..., k] += dLJ_dx * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (array of length xdim for each y_var giving partial derivatives)
        jac_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            jac_ret[var] = jac[..., start_idx:end_idx, :]  # (..., ydim, xdim)
            if len(arr.shape) == 1:
                jac_ret[var] = np.squeeze(jac_ret[var], axis=-2)  # for scalars: (..., xdim) partial derivatives
            start_idx = end_idx
        return jac_ret

    def hessian(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the Hessian at points `x` using the interpolator."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        grid_size_list = list(grid_sizes.values())
        max_size = max(grid_size_list)

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A= [#####--
        for n, var in enumerate(state.x_grids):            #     #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #     ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the gradient
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                               # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)                 # (..., xdim)
        qsum_p = -np.nansum(w_j / diff ** 2, axis=-1)       # (..., xdim)
        qsum_pp = 2 * np.nansum(w_j / diff ** 3, axis=-1)   # (..., xdim)

        # Loop over multi-indices and compute 2nd derivative of tensor-product lagrange polynomials
        hess = np.zeros(x_arr.shape[:-1] + (ydim, xdim, xdim))  # (..., ydim, xdim, xdim)
        indices = [range(s) for s in grid_size_list]
        for m in range(xdim):
            for n in range(m, xdim):
                dims = [idx for idx in range(xdim) if idx not in [m, n]]
                for i, j in enumerate(itertools.product(*indices)):
                    j_dims = [j[idx] for idx in dims]
                    L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-2)

                    # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                    if check_interp_pts:
                        other_pts = np.copy(div_zero_idx)
                        other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                        L_j[div_zero_idx[..., dims, j_dims]] = 1
                        L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Cross-terms in Hessian
                    if m != n:
                        # Partial derivative of L_j with respect to x_m and x_n
                        d2LJ_dx2 = np.ones(x_arr.shape[:-1])
                        for k in [m, n]:
                            dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                                      (-qsum_p[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                            # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                            if check_interp_pts:
                                p_idx = [idx for idx in range(grid_size_list[k]) if idx != j[k]]
                                w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                curr_j_idx = div_zero_idx[..., k, j[k]]
                                other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                                dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                                (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]),
                                                                axis=-1)
                                dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                                       (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                            d2LJ_dx2 *= dLJ_dx

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]
                        hess[..., n, m] += d2LJ_dx2 * yi_arr[i, :]

                    # Diagonal terms in Hessian:
                    else:
                        front_term = w_j[m, j[m]] / (qsum[..., m] * diff[..., m, j[m]])
                        first_term = (-qsum_pp[..., m] / qsum[..., m]) + 2 * (qsum_p[..., m] / qsum[..., m]) ** 2
                        second_term = (2 * (qsum_p[..., m] / (qsum[..., m] * diff[..., m, j[m]]))
                                       + 2 / diff[..., m, j[m]] ** 2)
                        d2LJ_dx2 = front_term * (first_term + second_term)

                        # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                        if check_interp_pts:
                            curr_j_idx = div_zero_idx[..., m, j[m]]
                            other_j_idx = np.any(other_pts[..., m, :], axis=-1)
                            if np.any(curr_j_idx) or np.any(other_j_idx):
                                p_idx = [idx for idx in range(grid_size_list[m]) if idx != j[m]]
                                w_j_large = np.broadcast_to(w_j[m, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                x_j_large = np.broadcast_to(x_j[m, :], x_arr.shape[:-1] + x_j.shape[-1:]).copy()

                                # if these points are at the current j interpolation point
                                d2LJ_dx2[curr_j_idx] = (2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]),  # noqa: E501
                                                                      axis=-1) ** 2 +
                                                        2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]) ** 2,  # noqa: E501
                                                                      axis=-1))

                                # if these points are at any other interpolation point
                                other_pts_inv = other_pts.copy()
                                other_pts_inv[other_j_idx, m, :grid_size_list[m]] = np.invert(
                                    other_pts[other_j_idx, m, :grid_size_list[m]])
                                curr_x_j = x_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_x_j = x_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_w_j = w_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_w_j = w_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_div = w_j[m, j[m]] / np.squeeze(curr_w_j, axis=-1)
                                curr_diff = np.squeeze(curr_x_j, axis=-1) - x_j[m, j[m]]
                                d2LJ_dx2[other_j_idx] = ((-2 * curr_div / curr_diff) * (np.nansum(
                                    (other_w_j / curr_w_j) / (curr_x_j - other_x_j), axis=-1) + 1 / curr_diff))

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (matrix (xdim, xdim) for each y_var giving 2nd partial derivatives)
        hess_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            hess_ret[var] = hess[..., start_idx:end_idx, :, :]  # (..., ydim, xdim, xdim)
            if len(arr.shape) == 1:
                hess_ret[var] = np.squeeze(hess_ret[var], axis=-3)  # for scalars: (..., xdim, xdim) partial derivatives
            start_idx = end_idx
        return hess_ret


@dataclass
class Linear(Interpolator, StringSerializable):
    """Implementation of linear regression using `sklearn`. The `Linear` interpolator uses a pipeline of
    `PolynomialFeatures` and a linear model from `sklearn.linear_model` to approximate the input-output mapping
    with a linear combination of polynomial features. Defaults to Ridge regression (L2 regularization) with
    polynomials of degree 1 (i.e. normal linear regression).

    :ivar regressor: the scikit-learn linear model to use (e.g. 'Ridge', 'Lasso', 'ElasticNet', etc.)
    :ivar scaler: the scikit-learn preprocessing scaler to use (e.g. 'MinMaxScaler', 'StandardScaler', etc.). If None,
        no scaling is applied (default).
    :ivar regressor_opts: options to pass to the regressor constructor
        (see the [scikit-learn](https://scikit-learn.org/stable/) documentation)
    :ivar scaler_opts: options to pass to the scaler constructor
    :ivar polynomial_opts: options to pass to the `PolynomialFeatures` constructor (e.g. 'degree', 'include_bias')
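
    !!! Example
        A configuration sketch (the choices below are illustrative; other estimator/scaler names from
        `sklearn.linear_model` and `sklearn.preprocessing` can be substituted):

        ```python
        interp = Linear(regressor='Lasso', scaler='StandardScaler', polynomial_opts={'degree': 2})
        ```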
551 """
552 regressor: str = 'Ridge'
553 scaler: str = None
554 regressor_opts: dict = field(default_factory=dict)
555 scaler_opts: dict = field(default_factory=dict)
556 polynomial_opts: dict = field(default_factory=lambda: {'degree': 1, 'include_bias': False})
558 def __post_init__(self):
559 try:
560 getattr(linear_model, self.regressor)
561 except AttributeError:
562 raise ImportError(f"Regressor '{self.regressor}' not found in sklearn.linear_model")
564 if self.scaler is not None:
565 try:
566 getattr(preprocessing, self.scaler)
567 except AttributeError:
568 raise ImportError(f"Scaler '{self.scaler}' not found in sklearn.preprocessing")

    def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
               old_state: LinearState, input_domains: dict[str, tuple]) -> InterpolatorState:
        """Train a new linear regression model.

        :param beta: if not empty, then the first element is the number of degrees to add to the polynomial features.
            For example, if `beta=(1,)`, then the polynomial degree will be increased by 1. If the degree
            is already set to 1 in `polynomial_opts` (default), then the new degree will be 2.
        :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
        :param old_state: the old linear state to refine (only used to get the order of input/output variables)
        :param input_domains: (not used for `Linear`)
        :returns: the new linear state
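
        !!! Example
            A sketch of the degree bump with default options (the cubic toy data here is illustrative):

            ```python
            import numpy as np

            xtrain = {'x': np.random.rand(20)}
            ytrain = {'y': xtrain['x'] ** 3}
            state = Linear().refine((2,), (xtrain, ytrain), None, {})
            # default degree 1 + beta[0] = 3, i.e. state.regressor['poly'].degree == 3
            ```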
581 """
582 polynomial_opts = self.polynomial_opts.copy()
583 degree = polynomial_opts.pop('degree', 1)
584 if beta != ():
585 degree += beta[0]
587 pipe = []
588 if self.scaler is not None:
589 pipe.append(('scaler', getattr(preprocessing, self.scaler)(**self.scaler_opts)))
590 pipe.extend([('poly', PolynomialFeatures(degree=degree, **polynomial_opts)),
591 ('linear', getattr(linear_model, self.regressor)(**self.regressor_opts))])
592 regressor = Pipeline(pipe)
594 xtrain, ytrain = training_data
596 # Get order of variables for inputs and outputs
597 if old_state is not None:
598 x_vars = old_state.x_vars
599 y_vars = old_state.y_vars
600 else:
601 x_vars = list(xtrain.keys())
602 y_vars = list(ytrain.keys())
604 # Convert to (N, xdim) and (N, ydim) arrays
605 x_arr = np.concatenate([xtrain[var][..., np.newaxis] for var in x_vars], axis=-1)
606 y_arr = np.concatenate([ytrain[var][..., np.newaxis] for var in y_vars], axis=-1)
608 regressor.fit(x_arr, y_arr)
610 return LinearState(regressor=regressor, x_vars=x_vars, y_vars=y_vars)

    def predict(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
        """Predict the output of the model at points `x` using the linear regressor provided in `state`.

        :param x: the input Dataset `dict` mapping input variables to prediction locations
        :param state: the state containing the linear regressor to use
        :param training_data: not used for `Linear` (since the regressor is already trained in `state`)
        """
        # Convert to (N, xdim) array for sklearn
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in state.x_vars], axis=-1)
        loop_shape = x_arr.shape[:-1]
        x_arr = x_arr.reshape((-1, x_arr.shape[-1]))

        y_arr = state.regressor.predict(x_arr)
        y_arr = y_arr.reshape(loop_shape + (len(state.y_vars),))  # (..., ydim)

        # Unpack the outputs back into a Dataset
        return {var: y_arr[..., i] for i, var in enumerate(state.y_vars)}

    def gradient(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
        raise NotImplementedError

    def hessian(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
        raise NotImplementedError


@dataclass
class GPR(Interpolator, StringSerializable):
    """Implementation of Gaussian process regression using `sklearn`. The `GPR` interpolator uses a pipeline
    of a scaler and a `GaussianProcessRegressor` to approximate the input-output mapping.

    :ivar scaler: the scikit-learn preprocessing scaler to use (e.g. 'MinMaxScaler', 'StandardScaler', etc.). If None,
        no scaling is applied (default).
    :ivar kernel: the kernel to use for building the covariance matrix (e.g. 'RBF', 'Matern', 'PairwiseKernel', etc.).
        If a string is provided, then the specified kernel is used with the given `kernel_opts`.
        If a list is provided, then kernel operators ('Sum', 'Product', 'Exponentiation') can be used to
        combine multiple kernels. The first element of the list should be the kernel or operator name, and
        the remaining elements should be the arguments. Dicts are accepted as **kwargs. For example,
        `['Sum', ['RBF', {'length_scale': 1.0}], ['Matern', {'length_scale': 1.0}]]` will create a sum of
        an RBF and a Matern kernel with the specified length scales.
    :ivar scaler_opts: options to pass to the scaler constructor
    :ivar kernel_opts: options to pass to the kernel constructor (ignored if kernel is a list, where opts are already
        specified for combinations of kernels)
    :ivar regressor_opts: options to pass to the `GaussianProcessRegressor` constructor
        (see the [scikit-learn](https://scikit-learn.org/stable/) documentation)
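
    !!! Example
        A configuration sketch using the list format for a composite kernel (the kernel and scaler choices
        below are illustrative):

        ```python
        interp = GPR(scaler='MinMaxScaler',
                     kernel=['Sum', ['RBF', {'length_scale': 1.0}], ['WhiteKernel', {'noise_level': 1e-3}]])
        ```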
656 """
657 scaler: str = None
658 kernel: str | list = 'RBF'
659 scaler_opts: dict = field(default_factory=dict)
660 kernel_opts: dict = field(default_factory=dict)
661 regressor_opts: dict = field(default_factory=lambda: {'n_restarts_optimizer': 5})

    def _construct_kernel(self, kernel_list):
        """Build a scikit-learn kernel from a list of kernels (e.g. RBF, Matern, etc.) and kernel operators
        (Sum, Product, Exponentiation).

        !!! Example
            `['Sum', ['RBF'], ['Matern', {'length_scale': 1.0}]]` will become `RBF() + Matern(length_scale=1.0)`

        :param kernel_list: list of kernel/operator names and arguments. Kwarg options can be passed as dicts.
        :returns: the scikit-learn kernel object
        """
        # Base case for single items (just return as is)
        if not isinstance(kernel_list, list):
            return kernel_list

        name = kernel_list[0]
        args = [self._construct_kernel(ele) for ele in kernel_list[1:]]

        # Base case for passing a single dict of kwargs
        if len(args) == 1 and isinstance(args[0], dict):
            return getattr(kernels, name)(**args[0])

        # Base case for passing a list of args
        return getattr(kernels, name)(*args)

    def _validate_kernel(self, kernel_list):
        """Make sure all requested kernels are available in scikit-learn."""
        if not isinstance(kernel_list, list):
            return

        name = kernel_list[0]

        if not hasattr(kernels, name):
            raise ImportError(f"Kernel '{name}' not found in sklearn.gaussian_process.kernels")

        for ele in kernel_list[1:]:
            self._validate_kernel(ele)

    def __post_init__(self):
        self._validate_kernel(self.kernel if isinstance(self.kernel, list) else [self.kernel, self.kernel_opts])

        if self.scaler is not None:
            try:
                getattr(preprocessing, self.scaler)
            except AttributeError:
                raise ImportError(f"Scaler '{self.scaler}' not found in sklearn.preprocessing")

    def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
               old_state: GPRState, input_domains: dict[str, tuple]) -> InterpolatorState:
        """Train a new Gaussian process regression model.

        :param beta: refinement level indices (not used for `GPR`)
        :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
        :param old_state: the old regressor state to refine (only used to get the order of input/output variables)
        :param input_domains: (not used for `GPR`)
        :returns: the new GPR state
        """
        gp_kernel = self._construct_kernel(self.kernel if isinstance(self.kernel, list)
                                           else [self.kernel, self.kernel_opts])
        gp = GaussianProcessRegressor(kernel=gp_kernel, **self.regressor_opts)
        pipe = []
        if self.scaler is not None:
            pipe.append(('scaler', getattr(preprocessing, self.scaler)(**self.scaler_opts)))
        pipe.append(('gpr', gp))
        regressor = Pipeline(pipe)

        xtrain, ytrain = training_data

        # Get order of variables for inputs and outputs
        if old_state is not None:
            x_vars = old_state.x_vars
            y_vars = old_state.y_vars
        else:
            x_vars = list(xtrain.keys())
            y_vars = list(ytrain.keys())

        # Convert to (N, xdim) and (N, ydim) arrays
        x_arr = np.concatenate([xtrain[var][..., np.newaxis] for var in x_vars], axis=-1)
        y_arr = np.concatenate([ytrain[var][..., np.newaxis] for var in y_vars], axis=-1)

        regressor.fit(x_arr, y_arr)

        return GPRState(regressor=regressor, x_vars=x_vars, y_vars=y_vars)

    def predict(self, x: Dataset, state: GPRState, training_data: tuple[Dataset, Dataset]):
        """Predict the output of the model at points `x` using the Gaussian process regressor provided in `state`.

        :param x: the input Dataset `dict` mapping input variables to prediction locations
        :param state: the state containing the Gaussian process regressor to use
        :param training_data: not used for `GPR` (since the regressor is already trained in `state`)
        """
        # Convert to (N, xdim) array for sklearn
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in state.x_vars], axis=-1)
        loop_shape = x_arr.shape[:-1]
        x_arr = x_arr.reshape((-1, x_arr.shape[-1]))

        y_arr = state.regressor.predict(x_arr)
        y_arr = y_arr.reshape(loop_shape + (len(state.y_vars),))  # (..., ydim)

        # Unpack the outputs back into a Dataset
        return {var: y_arr[..., i] for i, var in enumerate(state.y_vars)}

    def gradient(self, x: Dataset, state: GPRState, training_data: tuple[Dataset, Dataset]):
        raise NotImplementedError

    def hessian(self, x: Dataset, state: GPRState, training_data: tuple[Dataset, Dataset]):
        raise NotImplementedError