1"""Provides interpolator classes. Interpolators approximate the input → output mapping of a model given
2a set of training data. The training data consists of input-output pairs, and the interpolator can be
3refined with new training data.
5Includes:
7- `Interpolator`: Abstract class providing basic structure of an interpolator
8- `Lagrange`: Concrete implementation for tensor-product barycentric Lagrange interpolation
9- `InterpolatorState`: Interface for a dataclass that stores the internal state of an interpolator
10- `LagrangeState`: The internal state for a barycentric Lagrange polynomial interpolator
11"""
from __future__ import annotations

import copy
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np

from amisc.serialize import Base64Serializable, Serializable, StringSerializable
from amisc.typing import Dataset, MultiIndex

__all__ = ["InterpolatorState", "LagrangeState", "Interpolator", "Lagrange"]


class InterpolatorState(Serializable, ABC):
    """Interface for a dataclass that stores the internal state of an interpolator (e.g. weights and biases)."""
    pass


@dataclass
class LagrangeState(InterpolatorState, Base64Serializable):
    """The internal state for a barycentric Lagrange polynomial interpolator.

    :ivar weights: the 1d interpolation grid weights
    :ivar x_grids: the 1d interpolation grids
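
    !!! Example
        A minimal sketch (values are for illustration only); equality compares the stored grids and weights
        with `np.allclose`:
        ```python
        s1 = LagrangeState(weights={'x': np.array([2.0, -4.0, 2.0])},
                           x_grids={'x': np.array([0.0, 0.5, 1.0])})
        s2 = LagrangeState(weights={'x': np.array([2.0, -4.0, 2.0])},
                           x_grids={'x': np.array([0.0, 0.5, 1.0])})
        assert s1 == s2
        ```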
38 """
39 weights: dict[str, np.ndarray] = field(default_factory=dict)
40 x_grids: dict[str, np.ndarray] = field(default_factory=dict)
42 def __eq__(self, other):
43 if isinstance(other, LagrangeState):
44 try:
45 return all([np.allclose(self.weights[var], other.weights[var]) for var in self.weights]) and \
46 all([np.allclose(self.x_grids[var], other.x_grids[var]) for var in self.x_grids])
47 except IndexError:
48 return False
49 else:
50 return False
53class Interpolator(Serializable, ABC):
54 """Interface for an interpolator object that approximates a model. An interpolator should:
56 - `refine` - take an old state and new training data and produce a new "refined" state (e.g. new weights/biases)
57 - `predict` - interpolate from the training data to a new set of points (i.e. approximate the underlying model)
58 - `gradient` - compute the grdient/Jacobian at new points (if you want)
59 - `hessian` - compute the 2nd derivative/Hessian at new points (if you want)
61 Currently, only the `Lagrange` interpolator is supported and can be constructed from a configuration `dict`
62 via `Interpolator.from_dict()`.
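
    !!! Example
        ```python
        interp = Interpolator.from_dict({'method': 'lagrange'})  # returns a default Lagrange() instance
        ```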
63 """
65 @abstractmethod
66 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
67 old_state: InterpolatorState, input_domains: dict[str, tuple]) -> InterpolatorState:
68 """Refine the interpolator state with new training data.
70 :param beta: a multi-index specifying the fidelity "levels" of the new interpolator state (starts at (0,... 0))
71 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data
72 :param old_state: the previous state of the interpolator (None if initializing the first state)
73 :param input_domains: a `dict` mapping input variables to their corresponding domains
74 :returns: the new "refined" interpolator state
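
        !!! Example
            A minimal sketch for a single scalar input `x` on `[0, 1]` (hypothetical data, with `interpolator`
            being any concrete implementation, e.g. `Lagrange()`):
            ```python
            xtrain = {'x': np.array([0.0, 0.5, 1.0])}
            ytrain = {'y': np.array([0.0, 0.25, 1.0])}
            state = interpolator.refine((0,), (xtrain, ytrain), None, {'x': (0.0, 1.0)})
            ```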
75 """
76 raise NotImplementedError
78 @abstractmethod
79 def predict(self, x: dict | Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
80 """Interpolate the output of the model at points `x` using the given state and training data
82 :param x: the input Dataset `dict` mapping input variables to locations at which to compute the interpolator
83 :param state: the current state of the interpolator
84 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
85 :returns: a Dataset `dict` mapping output variables to interpolator outputs
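
        !!! Example
            Continuing the `refine` sketch above (hypothetical data):
            ```python
            y_pred = interpolator.predict({'x': np.linspace(0, 1, 100)}, state, (xtrain, ytrain))
            # y_pred['y'].shape == (100,)
            ```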
86 """
87 raise NotImplementedError
89 def __call__(self, *args, **kwargs):
90 return self.predict(*args, **kwargs)
92 @abstractmethod
93 def gradient(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
94 """Evaluate the gradient/Jacobian at points `x` using the interpolator.
96 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Jacobian
97 :param state: the current state of the interpolator
98 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
99 :returns: a Dataset `dict` mapping output variables to Jacobian matrices of shape `(ydim, xdim)` -- for
100 scalar outputs, the Jacobian is returned as `(xdim,)`
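
        !!! Example
            Continuing the sketch above, for one scalar input and one scalar output (hypothetical data):
            ```python
            jac = interpolator.gradient({'x': np.array([0.25, 0.75])}, state, (xtrain, ytrain))
            # jac['y'].shape == (2, 1) -- one partial derivative per input dimension at each of the 2 points
            ```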
101 """
102 raise NotImplementedError
104 @abstractmethod
105 def hessian(self, x: Dataset, state: InterpolatorState, training_data: tuple[Dataset, Dataset]) -> Dataset:
106 """Evaluate the Hessian at points `x` using the interpolator.
108 :param x: the input Dataset `dict` mapping input variables to locations at which to evaluate the Hessian
109 :param state: the current state of the interpolator
110 :param training_data: a tuple of `xi, yi` Datasets for the input/output training data for the current state
111 :returns: a Dataset `dict` mapping output variables to Hessian matrices of shape `(xdim, xdim)`
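
        !!! Example
            Continuing the sketch above, for one scalar input and one scalar output (hypothetical data):
            ```python
            hess = interpolator.hessian({'x': np.array([0.25, 0.75])}, state, (xtrain, ytrain))
            # hess['y'].shape == (2, 1, 1) -- the (xdim, xdim) Hessian matrix at each of the 2 points
            ```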
112 """
113 raise NotImplementedError
115 @classmethod
116 def from_dict(cls, config: dict) -> Interpolator:
117 """Create an `Interpolator` object from a `dict` config. Only `method='lagrange'` is supported for now."""
118 method = config.pop('method', 'lagrange').lower()
119 match method:
120 case 'lagrange':
121 return Lagrange(**config)
122 case other:
123 raise NotImplementedError(f"Unknown interpolator method: {other}")
126@dataclass
127class Lagrange(Interpolator, StringSerializable):
128 """Implementation of a tensor-product barycentric Lagrange polynomial interpolator. A `LagrangeState` stores
129 the 1d interpolation grids and weights for each input dimension. `Lagrange` computes the tensor-product
130 of 1d Lagrange polynomials to approximate a multi-variate function.
132 :ivar interval_capacity: tuning knob for Lagrange interpolation (see Berrut and Trefethen 2004)
133 """
134 interval_capacity: float = 4.0
136 @staticmethod
137 def _extend_grids(x_grids: dict[str, np.ndarray], x_points: dict[str, np.ndarray]):
138 """Extend the 1d `x` grids with any new points from `x_points`, skipping duplicates. This will preserve the
139 order of new points in the extended grid without duplication. This will maintain the same order as
140 `SparseGrid.x_grids` if that is the underlying training data structure.
142 !!! Example
143 ```python
144 x = {'x': np.array([0, 1, 2])}
145 new_x = {'x': np.array([3, 0, 2, 3, 1, 4])}
146 extended_x = Lagrange._extend_grids(x, new_x)
147 # gives {'x': np.array([0, 1, 2, 3, 4])}
148 ```
150 !!! Warning
151 This will only work for 1d grids; all `x_grids` should be scalar quantities. Field quantities should
152 already be passed in as several separate 1d latent coefficients.
154 :param x_grids: the current 1d interpolation grids
155 :param x_points: the new points to extend the interpolation grids with
156 :returns: the extended grids
157 """
158 extended_grids = copy.deepcopy(x_grids)
159 for var, new_pts in x_points.items():
160 # Get unique new values that are not already in the grid (maintain their order; keep only first one)
161 u = x_grids[var] if var in x_grids else np.array([])
162 u, ind = np.unique(new_pts[~np.isin(new_pts, u)], return_index=True)
163 u = u[np.argsort(ind)]
164 extended_grids[var] = u if var not in x_grids else np.concatenate((x_grids[var], u), axis=0)
166 return extended_grids
168 def refine(self, beta: MultiIndex, training_data: tuple[Dataset, Dataset],
169 old_state: LagrangeState, input_domains: dict[str, tuple]) -> LagrangeState:
170 """Refine the interpolator state with new training data.
172 :param beta: the refinement level indices for the interpolator (not used for `Lagrange`)
173 :param training_data: a tuple of dictionaries containing the new training data (`xtrain`, `ytrain`)
174 :param old_state: the old interpolator state to refine (None if initializing)
175 :param input_domains: a `dict` of each input variable's domain; input keys should match `xtrain` keys
176 :returns: the new interpolator state
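
        !!! Example
            A minimal sketch for a single scalar input (hypothetical data; `beta` and the `ytrain` Dataset are
            unused by `Lagrange.refine`):
            ```python
            interp = Lagrange()
            state = interp.refine((0,), ({'x': np.array([0.0, 1.0])}, {}), None, {'x': (0.0, 1.0)})
            # refine with a new grid point -- the barycentric weights are updated incrementally
            state = interp.refine((1,), ({'x': np.array([0.0, 1.0, 0.5])}, {}), state, {'x': (0.0, 1.0)})
            ```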
177 """
178 xtrain, ytrain = training_data # Lagrange only really needs the xtrain data to update barycentric weights/grids
180 # Initialize the interpolator state
181 if old_state is None:
182 x_grids = self._extend_grids({}, xtrain)
183 weights = {}
184 for var, grid in x_grids.items():
185 bds = input_domains[var]
186 Nx = grid.shape[0]
187 C = (bds[1] - bds[0]) / self.interval_capacity # Interval capacity (see Berrut and Trefethen 2004)
188 xj = grid.reshape((Nx, 1))
189 xi = xj.reshape((1, Nx))
190 dist = (xj - xi) / C
191 np.fill_diagonal(dist, 1) # Ignore product when i==j
192 weights[var] = (1.0 / np.prod(dist, axis=1)) # (Nx,)
194 # Otherwise, refine the interpolator state
195 else:
196 x_grids = self._extend_grids(old_state.x_grids, xtrain)
197 weights = copy.deepcopy(old_state.weights)
198 for var, grid in x_grids.items():
199 bds = input_domains[var]
200 Nx_old = old_state.x_grids[var].shape[0]
201 Nx_new = grid.shape[0]
202 if Nx_new > Nx_old:
203 weights[var] = np.pad(weights[var], [(0, Nx_new - Nx_old)], mode='constant', constant_values=np.nan)
204 C = (bds[1] - bds[0]) / self.interval_capacity
205 for j in range(Nx_old, Nx_new):
206 weights[var][:j] *= (C / (grid[:j] - grid[j]))
207 weights[var][j] = np.prod(C / (grid[j] - grid[:j]))
209 return LagrangeState(weights=weights, x_grids=x_grids)
211 def predict(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
212 """Predict the output of the model at points `x` with barycentric Lagrange interpolation."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        # Inputs `x` may come in unordered, but they should get realigned with the internal `x_grids` state
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())
        dims = list(range(xdim))

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A = [#####--
        for n, var in enumerate(state.x_grids):            #      #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #      ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0  # whether we are evaluating directly on some interp pts
        diff[div_zero_idx] = 1
        quotient = w_j / diff                      # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)        # (..., xdim)
        y = np.zeros(x_arr.shape[:-1] + (ydim,))   # (..., ydim)

        # Loop over multi-indices and compute tensor-product Lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for i, j in enumerate(itertools.product(*indices)):
            L_j = quotient[..., dims, j] / qsum  # (..., xdim)

            # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
            if check_interp_pts:
                other_pts = np.copy(div_zero_idx)
                other_pts[div_zero_idx[..., dims, j]] = False
                L_j[div_zero_idx[..., dims, j]] = 1
                L_j[np.any(other_pts, axis=-1)] = 0

            # Add multivariate basis polynomial contribution to interpolation output
            y += np.prod(L_j, axis=-1, keepdims=True) * yi_arr[i, :]

        # Unpack the outputs back into a Dataset
        y_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            y_ret[var] = y[..., start_idx:end_idx]
            if len(arr.shape) == 1:
                y_ret[var] = np.squeeze(y_ret[var], axis=-1)  # for scalars
            start_idx = end_idx
        return y_ret

    def gradient(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the gradient/Jacobian at points `x` using the interpolator."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        max_size = max(grid_sizes.values())

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A = [#####--
        for n, var in enumerate(state.x_grids):            #      #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #      ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the gradient
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                            # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)              # (..., xdim)
        sqsum = np.nansum(w_j / diff ** 2, axis=-1)      # (..., xdim)
        jac = np.zeros(x_arr.shape[:-1] + (ydim, xdim))  # (..., ydim, xdim)

        # Loop over multi-indices and compute derivative of tensor-product Lagrange polynomials
        indices = [range(s) for s in grid_sizes.values()]
        for k, var in enumerate(grid_sizes):
            dims = [idx for idx in range(xdim) if idx != k]
            for i, j in enumerate(itertools.product(*indices)):
                j_dims = [j[idx] for idx in dims]
                L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-1)

                # Partial derivative of L_j with respect to x_k
                dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                          (sqsum[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                if check_interp_pts:
                    other_pts = np.copy(div_zero_idx)
                    other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                    L_j[div_zero_idx[..., dims, j_dims]] = 1
                    L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                    p_idx = [idx for idx in range(grid_sizes[var]) if idx != j[k]]
                    w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                    curr_j_idx = div_zero_idx[..., k, j[k]]
                    other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                    dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                    (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]), axis=-1)
                    dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                           (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                dLJ_dx = np.expand_dims(dLJ_dx, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)

                # Add contribution to the Jacobian
                jac[..., k] += dLJ_dx * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (array of length xdim for each y_var giving partial derivatives)
        jac_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            jac_ret[var] = jac[..., start_idx:end_idx, :]  # (..., ydim, xdim)
            if len(arr.shape) == 1:
                jac_ret[var] = np.squeeze(jac_ret[var], axis=-2)  # for scalars: (..., xdim) partial derivatives
            start_idx = end_idx
        return jac_ret

    def hessian(self, x: Dataset, state: LagrangeState, training_data: tuple[Dataset, Dataset]):
        """Evaluate the Hessian at points `x` using the interpolator."""
        # Convert `x` and `yi` to 2d arrays: (N, xdim) and (N, ydim)
        xi, yi = training_data
        x_arr = np.concatenate([x[var][..., np.newaxis] for var in xi], axis=-1)
        yi_arr = np.concatenate([yi[var][..., np.newaxis] for var in yi], axis=-1)

        xdim = x_arr.shape[-1]
        ydim = yi_arr.shape[-1]
        grid_sizes = {var: grid.shape[-1] for var, grid in state.x_grids.items()}
        grid_size_list = list(grid_sizes.values())
        max_size = max(grid_size_list)

        # Create ragged edge matrix of interpolation pts and weights
        x_j = np.full((xdim, max_size), np.nan)            # For example:
        w_j = np.full((xdim, max_size), np.nan)            # A = [#####--
        for n, var in enumerate(state.x_grids):            #      #######
            x_j[n, :grid_sizes[var]] = state.x_grids[var]  #      ###----]
            w_j[n, :grid_sizes[var]] = state.weights[var]

        # Compute values ahead of time that will be needed for the Hessian
        diff = x_arr[..., np.newaxis] - x_j
        div_zero_idx = np.isclose(diff, 0, rtol=1e-4, atol=1e-8)
        check_interp_pts = np.sum(div_zero_idx) > 0
        diff[div_zero_idx] = 1
        quotient = w_j / diff                               # (..., xdim, Nx)
        qsum = np.nansum(quotient, axis=-1)                 # (..., xdim)
        qsum_p = -np.nansum(w_j / diff ** 2, axis=-1)       # (..., xdim)
        qsum_pp = 2 * np.nansum(w_j / diff ** 3, axis=-1)   # (..., xdim)

        # Loop over multi-indices and compute 2nd derivative of tensor-product Lagrange polynomials
        hess = np.zeros(x_arr.shape[:-1] + (ydim, xdim, xdim))  # (..., ydim, xdim, xdim)
        indices = [range(s) for s in grid_size_list]
        for m in range(xdim):
            for n in range(m, xdim):
                dims = [idx for idx in range(xdim) if idx not in [m, n]]
                for i, j in enumerate(itertools.product(*indices)):
                    j_dims = [j[idx] for idx in dims]
                    L_j = quotient[..., dims, j_dims] / qsum[..., dims]  # (..., xdim-2)

                    # Set L_j(x==x_j)=1 for the current j and set L_j(x==x_j)=0 for x_j = x_i, i != j
                    if check_interp_pts:
                        other_pts = np.copy(div_zero_idx)
                        other_pts[div_zero_idx[..., list(range(xdim)), j]] = False
                        L_j[div_zero_idx[..., dims, j_dims]] = 1
                        L_j[np.any(other_pts[..., dims, :], axis=-1)] = 0

                    # Cross-terms in Hessian
                    if m != n:
                        # Partial derivative of L_j with respect to x_m and x_n
                        d2LJ_dx2 = np.ones(x_arr.shape[:-1])
                        for k in [m, n]:
                            dLJ_dx = ((w_j[k, j[k]] / (qsum[..., k] * diff[..., k, j[k]])) *
                                      (-qsum_p[..., k] / qsum[..., k] - 1 / diff[..., k, j[k]]))

                            # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                            if check_interp_pts:
                                p_idx = [idx for idx in range(grid_size_list[k]) if idx != j[k]]
                                w_j_large = np.broadcast_to(w_j[k, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                curr_j_idx = div_zero_idx[..., k, j[k]]
                                other_j_idx = np.any(other_pts[..., k, :], axis=-1)
                                dLJ_dx[curr_j_idx] = -np.nansum((w_j[k, p_idx] / w_j[k, j[k]]) /
                                                                (x_arr[curr_j_idx, k, np.newaxis] - x_j[k, p_idx]),
                                                                axis=-1)
                                dLJ_dx[other_j_idx] = ((w_j[k, j[k]] / w_j_large[other_pts[..., k, :]]) /
                                                       (x_arr[other_j_idx, k] - x_j[k, j[k]]))

                            d2LJ_dx2 *= dLJ_dx

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]
                        hess[..., n, m] += d2LJ_dx2 * yi_arr[i, :]

                    # Diagonal terms in Hessian
                    else:
                        front_term = w_j[m, j[m]] / (qsum[..., m] * diff[..., m, j[m]])
                        first_term = (-qsum_pp[..., m] / qsum[..., m]) + 2 * (qsum_p[..., m] / qsum[..., m]) ** 2
                        second_term = (2 * (qsum_p[..., m] / (qsum[..., m] * diff[..., m, j[m]]))
                                       + 2 / diff[..., m, j[m]] ** 2)
                        d2LJ_dx2 = front_term * (first_term + second_term)

                        # Set derivatives when x is at the interpolation points (i.e. x==x_j)
                        if check_interp_pts:
                            curr_j_idx = div_zero_idx[..., m, j[m]]
                            other_j_idx = np.any(other_pts[..., m, :], axis=-1)
                            if np.any(curr_j_idx) or np.any(other_j_idx):
                                p_idx = [idx for idx in range(grid_size_list[m]) if idx != j[m]]
                                w_j_large = np.broadcast_to(w_j[m, :], x_arr.shape[:-1] + w_j.shape[-1:]).copy()
                                x_j_large = np.broadcast_to(x_j[m, :], x_arr.shape[:-1] + x_j.shape[-1:]).copy()

                                # if these points are at the current j interpolation point
                                d2LJ_dx2[curr_j_idx] = (2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]),  # noqa: E501
                                                                      axis=-1) ** 2 +
                                                        2 * np.nansum((w_j[m, p_idx] / w_j[m, j[m]]) /
                                                                      (x_arr[curr_j_idx, m, np.newaxis] - x_j[m, p_idx]) ** 2,  # noqa: E501
                                                                      axis=-1))

                                # if these points are at any other interpolation point
                                other_pts_inv = other_pts.copy()
                                other_pts_inv[other_j_idx, m, :grid_size_list[m]] = np.invert(
                                    other_pts[other_j_idx, m, :grid_size_list[m]])  # noqa: E501
                                curr_x_j = x_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_x_j = x_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_w_j = w_j_large[other_pts[..., m, :]].reshape((-1, 1))
                                other_w_j = w_j_large[other_pts_inv[..., m, :]].reshape((-1, len(p_idx)))
                                curr_div = w_j[m, j[m]] / np.squeeze(curr_w_j, axis=-1)
                                curr_diff = np.squeeze(curr_x_j, axis=-1) - x_j[m, j[m]]
                                d2LJ_dx2[other_j_idx] = ((-2 * curr_div / curr_diff) * (np.nansum(
                                    (other_w_j / curr_w_j) / (curr_x_j - other_x_j), axis=-1) + 1 / curr_diff))

                        d2LJ_dx2 = np.expand_dims(d2LJ_dx2, axis=-1) * np.prod(L_j, axis=-1, keepdims=True)  # (..., 1)
                        hess[..., m, n] += d2LJ_dx2 * yi_arr[i, :]

        # Unpack the outputs back into a Dataset (matrix (xdim, xdim) for each y_var giving 2nd partial derivatives)
        hess_ret = {}
        start_idx = 0
        for var, arr in yi.items():
            num_vals = arr.shape[-1] if len(arr.shape) > 1 else 1
            end_idx = start_idx + num_vals
            hess_ret[var] = hess[..., start_idx:end_idx, :, :]  # (..., ydim, xdim, xdim)
            if len(arr.shape) == 1:
                hess_ret[var] = np.squeeze(hess_ret[var], axis=-3)  # for scalars: (..., xdim, xdim) partial derivatives
            start_idx = end_idx
        return hess_ret