Coverage for src/amisc/interpolator.py: 94%
320 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-03-10 15:12 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-03-10 15:12 +0000
1"""Provides interpolator classes. Interpolators approximate the input → output mapping of a model given
2a set of training data. The training data consists of input-output pairs, and the interpolator can be
3refined with new training data.
5Includes:
7- `Interpolator`: Abstract class providing basic structure of an interpolator
8- `Lagrange`: Concrete implementation for tensor-product barycentric Lagrange interpolation
9- `Linear`: Concrete implementation for linear regression using `sklearn`
10- `InterpolatorState`: Interface for a dataclass that stores the internal state of an interpolator
11- `LagrangeState`: The internal state for a barycentric Lagrange polynomial interpolator
12- `LinearState`: The internal state for a linear interpolator (using sklearn)
13"""
14from __future__ import annotations
16import copy
17import itertools
18from abc import ABC, abstractmethod
19from dataclasses import dataclass, field
21import numpy as np
22from sklearn import linear_model, preprocessing
23from sklearn.pipeline import Pipeline
24from sklearn.preprocessing import PolynomialFeatures
26from amisc.serialize import Base64Serializable, Serializable, StringSerializable
27from amisc.typing import Dataset, MultiIndex
29__all__ = ["InterpolatorState", "LagrangeState", "LinearState", "Interpolator", "Lagrange", "Linear"]
32class InterpolatorState(Serializable, ABC):
33 """Interface for a dataclass that stores the internal state of an interpolator (e.g. weights and biases)."""
34 pass
37@dataclass
38class LagrangeState(InterpolatorState, Base64Serializable):
39 """The internal state for a barycentric Lagrange polynomial interpolator.
41 :ivar weights: the 1d interpolation grid weights
42 :ivar x_grids: the 1d interpolation grids
43 """
44 weights: dict[str, np.ndarray] = field(default_factory=dict)
45 x_grids: dict[str, np.ndarray] = field(default_factory=dict)
47 def __eq__(self, other):
48 if isinstance(other, LagrangeState):
49 try:
50 return all([np.allclose(self.weights[var], other.weights[var]) for var in self.weights]) and \
51 all([np.allclose(self.x_grids[var], other.x_grids[var]) for var in self.x_grids])
52 except IndexError:
53 return False
54 else:
55 return False
58@dataclass
59class LinearState(InterpolatorState, Base64Serializable):
60 """The internal state for a linear interpolator (using sklearn).
62 :ivar x_vars: the input variables in order
63 :ivar y_vars: the output variables in order
64 :ivar regressor: the sklearn regressor object, a pipeline that consists of a `PolynomialFeatures` and a model from
65 `sklearn.linear_model`, i.e. Ridge, Lasso, etc.
66 """
67 x_vars: list[str] = field(default_factory=list)
68 y_vars: list[str] = field(default_factory=list)
69 regressor: Pipeline = None
71 def __eq__(self, other):
72 if isinstance(other, LinearState):
73 return (self.x_vars == other.x_vars and self.y_vars == other.y_vars and
74 np.allclose(self.regressor['poly'].powers_, other.regressor['poly'].powers_) and
75 np.allclose(self.regressor['linear'].coef_, other.regressor['linear'].coef_) and
76 np.allclose(self.regressor['linear'].intercept_, other.regressor['linear'].intercept_))
77 else:
78 return False
81class Interpolator(Serializable, ABC):
82 """Interface for an interpolator object that approximates a model. An interpolator should:
84 - `refine` - take an old state and new training data and produce a new "refined" state (e.g. new weights/biases)
85 - `predict` - interpolate from the training data to a new set of points (i.e. approximate the underlying model)
86 - `gradient` - compute the grdient/Jacobian at new points (if you want)
87 - `hessian` - compute the 2nd derivative/Hessian at new points (if you want)
89 Currently, `Lagrange` and `Linear` interpolators are supported and can be constructed from a configuration `dict`
90 via `Interpolator.from_dict()`.
91 """
93 @abstractmethod
94 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
95 old_state: InterpolatorState, input_domains: dict[str, tuple]) -> InterpolatorState:
96 """Refine the interpolator state with new training data.
98 :param beta: a multi-index specifying the fidelity "levels" of the new interpolator state (starts at (0,... 0))
99 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data
100 :param old_state: the previous state of the interpolator (None if initializing the first state)
101 :param input_domains: a `dict` mapping input variables to their corresponding domains
102 :returns: the new "refined" interpolator state
103 """
104 raise NotImplementedError
106 @abstractmethod
107 def predict(self, x: dict | Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
108 """Interpolate the output of the model at points `x` using the given state and training data
110 :param x: the input Dataset `dict` mapping input variables to locations at which to compute the interpolator
111 :param state: the current state of the interpolator
112 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
113 :returns: a Dataset `dict` mapping output variables to interpolator outputs
114 """
115 raise NotImplementedError
117 def __call__(self, *args, **kwargs):
118 return self.predict(*args, **kwargs)
120 @abstractmethod
121 def gradient(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
122 """Evaluate the gradient/Jacobian at points `x` using the interpolator.
124 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Jacobian
125 :param state: the current state of the interpolator
126 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
127 :returns: a Dataset `dict` mapping output variables to Jacobian matrices of shape `(ydim, xdim)` -- for
128 scalar outputs, the Jacobian is returned as `(xdim,)`
129 """
130 raise NotImplementedError
132 @abstractmethod
133 def hessian(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
134 """Evaluate the Hessian at points `x` using the interpolator.
136 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Hessian
137 :param state: the current state of the interpolator
138 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
139 :returns: a Dataset `dict` mapping output variables to Hessian matrices of shape `(xdim, xdim)`
140 """
141 raise NotImplementedError
143 @classmethod
144 def from_dict(cls, config: dict) -> Interpolator:
145 """Create an `Interpolator` object from a `dict` config. Only `method='lagrange'` is supported for now."""
146 method = config.pop('method', 'lagrange').lower()
147 match method:
148 case 'lagrange':
149 return Lagrange(**config)
150 case 'linear':
151 return Linear(**config)
152 case other:
153 raise NotImplementedError(f"Unknown interpolator method: {other}")
156@dataclass
157class Lagrange(Interpolator, StringSerializable):
158 """Implementation of a tensor-product barycentric Lagrange polynomial interpolator. A `LagrangeState` stores
159 the 1d interpolation grids and weights for each input dimension. `Lagrange` computes the tensor-product
160 of 1d Lagrange polynomials to approximate a multi-variate function.
162 :ivar interval_capacity: tuning knob for Lagrange interpolation (see Berrut and Trefethen 2004)
163 """
164 interval_capacity: float = 4.0
166 @staticmethod
167 def _extend_grids(x_grids: dict[str, np.ndarray], x_points: dict[str, np.ndarray]):
168 """Extend the 1d `x` grids with any new points from `x_points`, skipping duplicates. This will preserve the
169 order of new points in the extended grid without duplication. This will maintain the same order as
170 `SparseGrid.x_grids` if that is the underlying training data structure.
172 !!! Example
173 ```python
174 x = {'x': np.array([0, 1, 2])}
175 new_x = {'x': np.array([3, 0, 2, 3, 1, 4])}
176 extended_x = Lagrange._extend_grids(x, new_x)
177 # gives {'x': np.array([0, 1, 2, 3, 4])}
178 ```
180 !!! Warning
181 This will only work for 1d grids; all `x_grids` should be scalar quantities. Field quantities should
182 already be passed in as several separate 1d latent coefficients.
184 :param x_grids: the current 1d interpolation grids
185 :param x_points: the new points to extend the interpolation grids with
186 :returns: the extended grids
187 """
188 extended_grids = copy.deepcopy(x_grids)
189 for var, new_pts in x_points.items():
190 # Get unique new values that are not already in the grid (maintain their order; keep only first one)
191 u = x_grids[var] if var in x_grids else np.array([])
192 u, ind = np.unique(new_pts[~np.isin(new_pts, u)], return_index=True)
193 u = u[np.argsort(ind)]
194 extended_grids[var] = u if var not in x_grids else np.concatenate((x_grids[var], u), axis=0)
196 return extended_grids
198 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
199 old_state: LagrangeState, input_domains: dict[str, tuple]) -> LagrangeState:
200 """Refine the interpolator state with new training data.
202 :param beta: the refinement level indices for the interpolator (not used for `Lagrange`)
203 :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
204 :param old_state: the old interpolator state to refine (None if initializing)
205 :param input_domains: a `dict` of each input variable's domain; input keys should match `xtrain` keys
206 :returns: the new interpolator state
207 """
208 xtrain, ytrain = training_data # Lagrange only really needs the xtrain data to update barycentric weights/grids
210 # Initialize the interpolator state
211 if old_state is None:
212 x_grids = self._extend_grids({}, xtrain)
213 weights = {}
214 for var, grid in x_grids.items():
215 bds = input_domains[var]
216 Nx = grid.shape[0]
217 C = (bds[1] - bds[0]) / self.interval_capacity # Interval capacity (see Berrut and Trefethen 2004)
218 xj = grid.reshape((Nx, 1))
219 xi = xj.reshape((1, Nx))
220 dist = (xj - xi) / C
221 np.fill_diagonal(dist, 1) # Ignore product when i==j
222 weights[var] = (1.0 / np.prod(dist, axis=1)) # (Nx,)
224 # Otherwise, refine the interpolator state
225 else:
226 x_grids = self._extend_grids(old_state.x_grids, xtrain)
227 weights = copy.deepcopy(old_state.weights)
228 for var, grid in x_grids.items():
229 bds = input_domains[var]
230 Nx_old = old_state.x_grids[var].shape[0]
231 Nx_new = grid.shape[0]
232 if Nx_new > Nx_old:
233 weights[var] = np.pad(weights[var], [(0, Nx_new - Nx_old)], mode='constant', constant_values=np.nan)
234 C = (bds[1] - bds[0]) / self.interval_capacity
235 for j in range(Nx_old, Nx_new):
236 weights[var][:j] *= (C / (grid[:j] - grid[j]))
237 weights[var][j] = np.prod(C / (grid[j] - grid[:j]))
239 return LagrangeState(weights=weights, x_grids=x_grids)
241 def predict(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
242 """Predict the output of the model at points `x` with barycentric Lagrange interpolation."""
243 # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
244 # Inputs `x` may come in unordered, but they should get realigned with the internal `x_grids` state
245 xi, yi = training_data
246 x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
247 yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)
249 xdim = x_arr.shape[-1]
250 ydim = yi_arr.shape[-1]
251 grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
252 max_size = max(grid_sizes.values())
253 dims = list(range(xdim))
255 # Create ragged edge matrix of interpolation pts and weights
256 x_j = np.full((xdim, max_size), np.nan) # For example:
257 w_j = np.full((xdim, max_size), np.nan) # A= [#####--
258 for n, var in enumerate(state.x_grids): # #######
259 x_j[n, :grid_sizes[var]] = state.x_grids[var] # ###----]
260 w_j[n, :grid_sizes[var]] = state.weights[var]
262 diff = x_arr[..., np.newaxis] - x_j
263 div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
264 check_interp_pts = np.sum(div_zero_idx) > 0 # whether we are evaluating directly on some interp pts
265 diff[div_zero_idx] = 1
266 quotient = w_j / diff # (..., xdim, Nx)
267 qsum = np.nansum(quotient, axis=-1) # (..., xdim)
268 y = np.zeros(x_arr.shape[:-1] + (ydim,)) # (..., ydim)
270 # Loop over multi-indices and compute tensor-product lagrange polynomials
271 indices = [range(s) for s in grid_sizes.values()]
272 for i, j in enumerate(itertools.product(*indices)):
273 L_j = quotient[..., dims, j] / qsum # (..., xdim)
275 # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
276 if check_interp_pts:
277 other_pts = np.copy(div_zero_idx)
278 other_pts[div_zero_idx[..., dims, j]] = False
279 L_j[div_zero_idx[..., dims, j]] = 1
280 L_j[np.any(other_pts, axis=-1)] = 0
282 # Add multivariate basis polynomial contribution to interpolation output
283 y += np.prod(L_j, axis=-1, keepdims=True) * yi_arr[i, :]
285 # Unpack the outputs back into a Dataset
286 y_ret = {}
287 start_idx = 0
288 for var, arr in yi.items():
289 num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
290 end_idx = start_idx + num_vals
291 y_ret[var] = y[..., start_idx:end_idx]
292 if len(arr.shape) == 1:
293 y_ret[var] = np.squeeze(y_ret[var], axis=-1) # for scalars
294 start_idx = end_idx
295 return y_ret
297 def gradient(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
298 """Evaluate the gradient/Jacobian at points `x` using the interpolator."""
299 # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
300 xi, yi = training_data
301 x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
302 yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)
304 xdim = x_arr.shape[-1]
305 ydim = yi_arr.shape[-1]
306 grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
307 max_size = max(grid_sizes.values())
309 # Create ragged edge matrix of interpolation pts and weights
310 x_j = np.full((xdim, max_size), np.nan) # For example:
311 w_j = np.full((xdim, max_size), np.nan) # A= [#####--
312 for n, var in enumerate(state.x_grids): # #######
313 x_j[n, :grid_sizes[var]] = state.x_grids[var] # ###----]
314 w_j[n, :grid_sizes[var]] = state.weights[var]
316 # Compute values ahead of time that will be needed for the gradient
317 diff = x_arr[..., np.newaxis] - x_j
318 div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
319 check_interp_pts = np.sum(div_zero_idx) > 0
320 diff[div_zero_idx] = 1
321 quotient = w_j / diff # (..., xdim, Nx)
322 qsum = np.nansum(quotient, axis=-1) # (..., xdim)
323 sqsum = np.nansum(w_j / diff ** 2, axis=-1) # (..., xdim)
324 jac = np.zeros(x_arr.shape[:-1] + (ydim, xdim)) # (..., ydim, xdim)
326 # Loop over multi-indices and compute derivative of tensor-product lagrange polynomials
327 indices = [range(s) for s in grid_sizes.values()]
328 for k, var in enumerate(grid_sizes):
329 dims = [idx for idx in range(xdim) if idx != k]
330 for i, j in enumerate(itertools.product(*indices)):
331 j_dims = [j[idx] for idx in dims]
332 L_j = quotient[..., dims, j_dims] / qsum[..., dims] # (..., xdim-1)
334 # Partial derivative of L_j with respect to x_k
335 dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
336 (sqsum[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))
338 # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
339 if check_interp_pts:
340 other_pts = np.copy(div_zero_idx)
341 other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
342 L_j[div_zero_idx[..., dims, j_dims]] = 1
343 L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0
345 # Set derivatives when x is at the interpolation points (i.e. x==x_j)
346 p_idx = [idx for idx in range(grid_sizes[var]) if idx != j[k]]
347 w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
348 curr_j_idx = div_zero_idx[..., k, j[k]]
349 other_j_idx = np.any(other_pts[..., k, :], axis=-1)
350 dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
351 (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]), axis=-1)
352 dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
353 (x_arr[other_j_idx, k] - x_j[k, j[k]]))
355 dLJ_dx = np.expand_dims(dLJ_dx, axis=-1) * np.prod(L_j, axis=-1, keepdims=True) # (..., 1)
357 # Add contribution to the Jacobian
358 jac[..., k] += dLJ_dx * yi_arr[i, :]
360 # Unpack the outputs back into a Dataset (array of length xdim for each y_var giving partial derivatives)
361 jac_ret = {}
362 start_idx = 0
363 for var, arr in yi.items():
364 num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
365 end_idx = start_idx + num_vals
366 jac_ret[var] = jac[..., start_idx:end_idx, :] # (..., ydim, xdim)
367 if len(arr.shape) == 1:
368 jac_ret[var] = np.squeeze(jac_ret[var], axis=-2) # for scalars: (..., xdim) partial derivatives
369 start_idx = end_idx
370 return jac_ret
372 def hessian(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
373 """Evaluate the Hessian at points `x` using the interpolator."""
374 # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
375 xi, yi = training_data
376 x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
377 yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)
379 xdim = x_arr.shape[-1]
380 ydim = yi_arr.shape[-1]
381 grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
382 grid_size_list = list(grid_sizes.values())
383 max_size = max(grid_size_list)
385 # Create ragged edge matrix of interpolation pts and weights
386 x_j = np.full((xdim, max_size), np.nan) # For example:
387 w_j = np.full((xdim, max_size), np.nan) # A= [#####--
388 for n, var in enumerate(state.x_grids): # #######
389 x_j[n, :grid_sizes[var]] = state.x_grids[var] # ###----]
390 w_j[n, :grid_sizes[var]] = state.weights[var]
392 # Compute values ahead of time that will be needed for the gradient
393 diff = x_arr[..., np.newaxis] - x_j
394 div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
395 check_interp_pts = np.sum(div_zero_idx) > 0
396 diff[div_zero_idx] = 1
397 quotient = w_j / diff # (..., xdim, Nx)
398 qsum = np.nansum(quotient, axis=-1) # (..., xdim)
399 qsum_p = -np.nansum(w_j / diff ** 2, axis=-1) # (..., xdim)
400 qsum_pp = 2 * np.nansum(w_j / diff ** 3, axis=-1) # (..., xdim)
402 # Loop over multi-indices and compute 2nd derivative of tensor-product lagrange polynomials
403 hess = np.zeros(x_arr.shape[:-1] + (ydim, xdim, xdim)) # (..., ydim, xdim, xdim)
404 indices = [range(s) for s in grid_size_list]
405 for m in range(xdim):
406 for n in range(m, xdim):
407 dims = [idx for idx in range(xdim) if idx not in [m, n]]
408 for i, j in enumerate(itertools.product(*indices)):
409 j_dims = [j[idx] for idx in dims]
410 L_j = quotient[..., dims, j_dims] / qsum[..., dims] # (..., xdim-2)
412 # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
413 if check_interp_pts:
414 other_pts = np.copy(div_zero_idx)
415 other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
416 L_j[div_zero_idx[..., dims, j_dims]] = 1
417 L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0
419 # Cross-terms in Hessian
420 if m != n:
421 # Partial derivative of L_j with respect to x_m and x_n
422 d2LJ_dx2 = np.ones(x_arr.shape[:-1])
423 for k in [m, n]:
424 dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
425 (-qsum_p[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))
427 # Set derivatives when x is at the interpolation points (i.e. x==x_j)
428 if check_interp_pts:
429 p_idx = [idx for idx in range(grid_size_list[k]) if idx != j[k]]
430 w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
431 curr_j_idx = div_zero_idx[..., k, j[k]]
432 other_j_idx = np.any(other_pts[..., k, :], axis=-1)
433 dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
434 (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]),
435 axis=-1)
436 dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
437 (x_arr[other_j_idx, k] - x_j[k, j[k]]))
439 d2LJ_dx2 *= dLJ_dx
441 d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True) # (..., 1)
442 hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]
443 hess[..., n, m] += d2LJ_dx2 * yi_arr[i, :]
445 # Diagonal terms in Hessian:
446 else:
447 front_term = w_j[m, j[m]] / (qsum[..., m] * diff[..., m, j[m]])
448 first_term = (-qsum_pp[..., m] / qsum[..., m]) + 2 * (qsum_p[..., m] / qsum[..., m]) ** 2
449 second_term = (2 * (qsum_p[..., m] / (qsum[..., m] * diff[..., m, j[m]]))
450 + 2 / diff[..., m, j[m]] ** 2)
451 d2LJ_dx2 = front_term * (first_term + second_term)
453 # Set derivatives when x is at the interpolation points (i.e. x==x_j)
454 if check_interp_pts:
455 curr_j_idx = div_zero_idx[..., m, j[m]]
456 other_j_idx = np.any(other_pts[..., m, :], axis=-1)
457 if np.any(curr_j_idx) or np.any(other_j_idx):
458 p_idx = [idx for idx in range(grid_size_list[m]) if idx != j[m]]
459 w_j_large = np.broadcast_to(w_j[m, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
460 x_j_large = np.broadcast_to(x_j[m, :], x_arr.shape[:-1] + x_j.shape[-1:]).copy()
462 # if these points are at the current j interpolation point
463 d2LJ_dx2[curr_j_idx] = (2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
464 (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]), # noqa: E501
465 axis=-1) ** 2 +
466 2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
467 (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]) ** 2, # noqa: E501
468 axis=-1))
470 # if these points are at any other interpolation point
471 other_pts_inv = other_pts.copy()
472 other_pts_inv[other_j_idx, m, :grid_size_list[m]] = np.invert(
473 other_pts[other_j_idx, m, :grid_size_list[m]]) # noqa: E501
474 curr_x_j = x_j_large[other_pts[..., m, :]].reshape((-1, 1))
475 other_x_j = x_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
476 curr_w_j = w_j_large[other_pts[..., m, :]].reshape((-1, 1))
477 other_w_j = w_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
478 curr_div = w_j[m, j[m]] / np.squeeze(curr_w_j, axis=-1)
479 curr_diff = np.squeeze(curr_x_j, axis=-1) - x_j[m, j[m]]
480 d2LJ_dx2[other_j_idx] = ((-2 * curr_div / curr_diff) * (np.nansum(
481 (other_w_j / curr_w_j) / (curr_x_j - other_x_j), axis=-1) + 1 / curr_diff))
483 d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True) # (..., 1)
484 hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]
486 # Unpack the outputs back into a Dataset (matrix (xdim, xdim) for each y_var giving 2nd partial derivatives)
487 hess_ret = {}
488 start_idx = 0
489 for var, arr in yi.items():
490 num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
491 end_idx = start_idx + num_vals
492 hess_ret[var] = hess[..., start_idx:end_idx, :, :] # (..., ydim, xdim, xdim)
493 if len(arr.shape) == 1:
494 hess_ret[var] = np.squeeze(hess_ret[var], axis=-3) # for scalars: (..., xdim, xdim) partial derivatives
495 start_idx = end_idx
496 return hess_ret
499@dataclass
500class Linear(Interpolator, StringSerializable):
501 """Implementation of linear regression using `sklearn`. The `Linear` interpolator uses a pipeline of
502 `PolynomialFeatures` and a linear model from `sklearn.linear_model` to approximate the input-output mapping
503 with a linear combination of polynomial features. Defaults to Ridge regression (L2 regularization) with
504 polynomials of degree 1 (i.e. normal linear regression).
506 :ivar regressor: the scikit-learn linear model to use (e.g. 'Ridge', 'Lasso', 'ElasticNet', etc.).
507 :ivar scaler: the scikit-learn preprocessing scaler to use (e.g. 'MinMaxScaler', 'StandardScaler', etc.). If None,
508 no scaling is applied (default).
509 :ivar regressor_opts: options to pass to the regressor constructor
510 (see [scikit-learn](https://scikit-learn.org/stable/) documentation).
511 :ivar scaler_opts: options to pass to the scaler constructor
512 :ivar polynomial_opts: options to pass to the `PolynomialFeatures` constructor (e.g. 'degree', 'include_bias').
513 """
514 regressor: str = 'Ridge'
515 scaler: str = None
516 regressor_opts: dict = field(default_factory=dict)
517 scaler_opts: dict = field(default_factory=dict)
518 polynomial_opts: dict = field(default_factory=lambda: {'degree': 1, 'include_bias': False})
520 def __post_init__(self):
521 try:
522 getattr(linear_model, self.regressor)
523 except AttributeError:
524 raise ImportError(f"Regressor '{self.regressor}' not found in sklearn.linear_model")
526 if self.scaler is not None:
527 try:
528 getattr(preprocessing, self.scaler)
529 except AttributeError:
530 raise ImportError(f"Scaler '{self.scaler}' not found in sklearn.preprocessing")
532 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
533 old_state: LinearState, input_domains: dict[str, tuple]) -> InterpolatorState:
534 """Train a new linear regression model.
536 :param beta: if not empty, then the first element is the number of degrees to add to the polynomial features.
537 For example, if `beta=(1,)`, then the polynomial degree will be increased by 1. If the degree
538 is already set to 1 in `polynomial_opts` (default), then the new degree will be 2.
539 :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
540 :param old_state: the old linear state to refine (only used to get the order of input/output variables)
541 :param input_domains: (not used for `Linear`)
542 :returns: the new linear state
543 """
544 polynomial_opts = self.polynomial_opts.copy()
545 degree = polynomial_opts.pop('degree', 1)
546 if beta != ():
547 degree += beta[0]
549 pipe = []
550 if self.scaler is not None:
551 pipe.append(('scaler', getattr(preprocessing, self.scaler)(**self.scaler_opts)))
552 pipe.extend([('poly', PolynomialFeatures(degree=degree, **polynomial_opts)),
553 ('linear', getattr(linear_model, self.regressor)(**self.regressor_opts))])
554 regressor = Pipeline(pipe)
556 xtrain, ytrain = training_data
558 # Get order of variables for inputs and outputs
559 if old_state is not None:
560 x_vars = old_state.x_vars
561 y_vars = old_state.y_vars
562 else:
563 x_vars = list(xtrain.keys())
564 y_vars = list(ytrain.keys())
566 # Convert to (N, xdim) and (N, ydim) arrays
567 x_arr = np.concatenate([xtrain[var][..., np.newaxis] for var in x_vars], axis=-1)
568 y_arr = np.concatenate([ytrain[var][..., np.newaxis] for var in y_vars], axis=-1)
570 regressor.fit(x_arr, y_arr)
572 return LinearState(regressor=regressor, x_vars=x_vars, y_vars=y_vars)
574 def predict(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
575 """Predict the output of the model at points `x` using the linear regressor provided in `state`.
577 :param x: the input Dataset `dict` mapping input variables to prediction locations
578 :param state: the state containing the linear regressor to use
579 :param training_data: not used for `Linear` (since the regressor is already trained in `state`)
580 """
581 # Convert to (N, xdim) array for sklearn
582 x_arr = np.concatenate([x[var][..., np.newaxis] for var in state.x_vars], axis=-1)
583 loop_shape = x_arr.shape[:-1]
584 x_arr = x_arr.reshape((-1, x_arr.shape[-1]))
586 y_arr = state.regressor.predict(x_arr)
587 y_arr = y_arr.reshape(loop_shape + (len(state.y_vars),)) # (..., ydim)
589 # Unpack the outputs back into a Dataset
590 return {var: y_arr[..., i] for i, var in enumerate(state.y_vars)}
592 def gradient(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
593 raise NotImplementedError
595 def hessian(self, x: Dataset, state: LinearState, training_data: tuple[Dataset, Dataset]):
596 raise NotImplementedError