Coverage for src/amisc/training.py: 88% (265 statements)
1"""Classes for storing and managing training data for surrogate models. The `TrainingData` interface also
2specifies how new training data should be sampled over the input space (i.e. experimental design).
4Includes:
6- `TrainingData` — an interface for storing surrogate training data.
7- `SparseGrid` — a class for storing training data in a sparse grid format.
8"""
9from __future__ import annotations
11import copy
12import itertools
13from abc import ABC, abstractmethod
14from dataclasses import dataclass, field
15from typing import Any, ClassVar
17import numpy as np
18from numpy.typing import ArrayLike
19from scipy.optimize import direct
21from amisc.serialize import PickleSerializable, Serializable
22from amisc.typing import LATENT_STR_ID, Dataset, MultiIndex
23from amisc.utils import _RidgeRegression
25__all__ = ['TrainingData', 'SparseGrid']


class TrainingData(Serializable, ABC):
    """Interface for storing and collecting surrogate training data. `TrainingData` subclasses should implement:

    - `get` - retrieve the training data
    - `set` - store the training data
    - `refine` - generate new design points for the parent `Component` model
    - `clear` - clear all training data
    - `set_errors` - store error information (if desired)
    - `impute_missing_data` - fill in missing values in the training data (if desired)
    """

    @abstractmethod
    def get(self, alpha: MultiIndex, beta: MultiIndex, y_vars: list[str] = None,
            skip_nan: bool = False) -> tuple[Dataset, Dataset]:
        """Return the training data for a given multi-index pair.

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        :param y_vars: the keys of the outputs to return (if `None`, return all outputs)
        :param skip_nan: skip any data points with remaining `nan` values if `skip_nan=True`
        :returns: `dicts` of model inputs `x_train` and outputs `y_train`
        """
        raise NotImplementedError

    @abstractmethod
    def set(self, alpha: MultiIndex, beta: MultiIndex, coords: list[Any], yi_dict: Dataset):
        """Store training data for a given multi-index pair.

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        :param coords: locations for storing the `yi` values in the underlying data structure
        :param yi_dict: a `dict` of model output `yi` values; each entry should be the same length as `coords`
        """
        raise NotImplementedError

    @abstractmethod
    def set_errors(self, alpha: MultiIndex, beta: MultiIndex, coords: list[Any], errors: list[dict]):
        """Store error information for a given multi-index pair (subclasses that do not track errors may simply
        `pass`).

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        :param coords: locations for storing the error information in the underlying data structure
        :param errors: a list of error dictionaries; should be the same length as `coords`
        """
        raise NotImplementedError

    @abstractmethod
    def impute_missing_data(self, alpha: MultiIndex, beta: MultiIndex):
        """Impute missing values in the training data for a given multi-index pair (subclasses that do not impute
        may simply `pass`).

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        """
        raise NotImplementedError

    @abstractmethod
    def refine(self, alpha: MultiIndex, beta: MultiIndex, input_domains: dict[str, tuple],
               weight_fcns: dict[str, callable] = None) -> tuple[list[Any], Dataset]:
        """Return new design/training points for a given multi-index pair and their coordinates/locations in the
        `TrainingData` storage structure.

        !!! Example
            ```python
            domains = {'x1': (0, 1), 'x2': (0, 1)}
            alpha, beta = (0, 1), (1, 1)
            coords, x_train = training_data.refine(alpha, beta, domains)
            y_train = my_model(x_train)
            training_data.set(alpha, beta, coords, y_train)
            ```

        The returned data coordinates `coords` should be any object that can be used to locate the corresponding
        `x_train` training points in the `TrainingData` storage structure. These `coords` will be passed back to the
        `set` function to store the training data at a later time (i.e. after model evaluation).

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        :param input_domains: a `dict` specifying domain bounds for each input variable
        :param weight_fcns: a `dict` of weighting functions for each input variable
        :returns: a list of new data coordinates `coords` and the corresponding training points `x_train`
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self):
        """Clear all training data."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, config: dict) -> TrainingData:
        """Create a `TrainingData` object from a `dict` configuration. Currently, only `method='sparse-grid'` is
        supported for the `SparseGrid` class.
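
        !!! Example
            A minimal sketch; any remaining keys are forwarded to the `SparseGrid` constructor
            (`knots_per_level` is one of its dataclass fields shown below):

            ```python
            training_data = TrainingData.from_dict({'method': 'sparse-grid', 'knots_per_level': 2})
            ```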
119 """
120 method = config.pop('method', 'sparse-grid').lower()
121 match method:
122 case 'sparse-grid':
123 return SparseGrid(**config)
124 case other:
125 raise NotImplementedError(f"Unknown training data method: {other}")


@dataclass
class SparseGrid(TrainingData, PickleSerializable):
    """A class for storing training data in a sparse grid format. The `SparseGrid` class stores training points
    by their coordinate location in a larger tensor-product grid, and obtains new training data by refining
    a single 1d grid at a time.

    !!! Note "MISC and sparse grids"
        MISC itself can be thought of as an extension of the well-known sparse grid technique, so this class
        readily integrates with the MISC implementation in `Component`. Sparse grids mitigate the curse of
        dimensionality for input spaces up to about `dim = 10-15` (which would otherwise be infeasible with a
        full tensor-product grid of the same size).

    !!! Info "About points in a sparse grid"
        A sparse grid approximates a full tensor-product grid $(N_1, N_2, ..., N_d)$, where $N_i$ is the number of
        grid points along dimension $i$, for a $d$-dimensional space. Each point is uniquely identified in the
        sparse grid by a list of indices $(j_1, j_2, ..., j_d)$, where $j_i = 0, 1, ..., N_i - 1$. We refer to this
        unique identifier as a "grid coordinate". In the `SparseGrid` data structure, these coordinates are used
        along with the `alpha` fidelity index to uniquely locate the training data for a given multi-index pair.
    :ivar collocation_rule: the collocation rule to use for generating new grid points (only 'leja' is supported)
    :ivar knots_per_level: the number of grid knots/points per level in the `beta` fidelity multi-index
    :ivar expand_latent_method: method for expanding latent grids, either 'round-robin' or 'tensor-product'
    :ivar opt_args: extra arguments for the global 1d `direct` optimizer
    :ivar betas: a set of all `beta` multi-indices that have been seen so far
    :ivar x_grids: a `dict` of grid points for each 1d input dimension
    :ivar yi_map: a `dict` of model outputs for each grid coordinate
    :ivar yi_nan_map: a `dict` of imputed model outputs for each grid coordinate where the model failed (or gave nan)
    :ivar error_map: a `dict` of error information for each grid coordinate where the model failed
    :ivar latent_size: the number of latent coefficients for each variable (0 if scalar)
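
    !!! Example "Grid coordinates (illustrative)"
        As a sketch, two scalar inputs with grid sizes $(3, 2)$ give six grid coordinates, each locating one
        training point in the tensor-product grid:

        ```python
        import itertools
        coords = list(itertools.product(range(3), range(2)))
        # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
        ```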
157 """
158 MAX_IMPUTE_SIZE: ClassVar[int] = 10 # don't try to impute large arrays
160 collocation_rule: str = 'leja'
161 knots_per_level: int = 2
162 expand_latent_method: str = 'round-robin' # or 'tensor-product', for converting beta to latent grid sizes
163 opt_args: dict = field(default_factory=lambda: {'locally_biased': False, 'maxfun': 300}) # for leja optimizer
165 betas: set[MultiIndex] = field(default_factory=set)
166 x_grids: dict[str, ArrayLike] = field(default_factory=dict)
167 yi_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, ArrayLike]]] = field(default_factory=dict)
168 yi_nan_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, ArrayLike]]] = field(default_factory=dict)
169 error_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, Any]]] = field(default_factory=dict)
170 latent_size: dict[str, int] = field(default_factory=dict) # keep track of latent grid sizes for each variable

    def clear(self):
        """Clear all training data."""
        self.betas.clear()
        self.x_grids.clear()
        self.yi_map.clear()
        self.yi_nan_map.clear()
        self.error_map.clear()
        self.latent_size.clear()

    def get_by_coord(self, alpha: MultiIndex, coords: list, y_vars: list = None, skip_nan: bool = False):
        """Get training data from the sparse grid for a given `alpha` and list of grid coordinates. Try to replace
        `nan` values with imputed values. Skip any data points with remaining `nan` values if `skip_nan=True`.

        :param alpha: the model fidelity indices
        :param coords: a list of grid coordinates to locate the `yi` values in the sparse grid data structure
        :param y_vars: the keys of the outputs to return (if `None`, return all outputs)
        :param skip_nan: skip any data points with remaining `nan` values if `skip_nan=True` (only numeric outputs)
        :returns: `dicts` of model inputs `xi_dict` and outputs `yi_dict`
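
        !!! Example
            A minimal sketch, assuming data has already been `set` for `alpha=(0,)` at these coordinates:

            ```python
            xi, yi = grid.get_by_coord((0,), [(0, 0), (1, 0)], skip_nan=True)
            ```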
190 """
191 N = len(coords)
192 is_numeric = {}
193 is_singleton = {}
194 xi_dict = self._extract_grid_points(coords)
195 yi_dict = {}
197 first_yi = next(iter(self.yi_map[alpha].values()))
198 if y_vars is None:
199 y_vars = first_yi.keys()
201 for var in y_vars:
202 yi = np.atleast_1d(first_yi[var])
203 is_numeric[var] = self._is_numeric(yi)
204 is_singleton[var] = self._is_singleton(yi)
205 yi_dict[var] = np.empty(N, dtype=np.float64 if is_numeric[var] and is_singleton[var] else object)
207 for i, coord in enumerate(coords):
208 try:
209 yi_curr = self.yi_map[alpha][coord]
210 for var in y_vars:
211 yi = arr if (arr := self.yi_nan_map[alpha].get(coord, {}).get(var)) is not None else yi_curr[var]
212 yi_dict[var][i] = yi if is_singleton[var] else np.atleast_1d(yi)
214 except KeyError as e:
215 raise KeyError(f"Can't access sparse grid data for alpha={alpha}, coord={coord}. "
216 f"Make sure the data has been set first.") from e
218 # Delete nans if requested (only for numeric singleton outputs)
219 if skip_nan:
220 nan_idx = np.full(N, False)
221 for var in y_vars:
222 if is_numeric[var] and is_singleton[var]:
223 nan_idx |= np.isnan(yi_dict[var])
225 xi_dict = {k: v[~nan_idx] for k, v in xi_dict.items()}
226 yi_dict = {k: v[~nan_idx] for k, v in yi_dict.items()}
228 return xi_dict, yi_dict # Both with elements of shape (N, ...) for N grid points

    def get(self, alpha: MultiIndex, beta: MultiIndex, y_vars: list[str] = None, skip_nan: bool = False):
        """Get the training data from the sparse grid for a given `alpha` and `beta` pair."""
        return self.get_by_coord(alpha, list(self._expand_grid_coords(beta)), y_vars=y_vars, skip_nan=skip_nan)

    def set_errors(self, alpha: MultiIndex, beta: MultiIndex, coords: list, errors: list[dict]):
        """Store error information in the sparse grid for a given multi-index pair."""
        for coord, error in zip(coords, errors):
            self.error_map[alpha][coord] = copy.deepcopy(error)

    def set(self, alpha: MultiIndex, beta: MultiIndex, coords: list, yi_dict: dict[str, ArrayLike]):
        """Store model output `yi_dict` values.

        :param alpha: the model fidelity indices
        :param beta: the surrogate fidelity indices
        :param coords: a list of grid coordinates to locate the `yi` values in the sparse grid data structure
        :param yi_dict: a `dict` of model output `yi` values
        """
        for i, coord in enumerate(coords):  # First dim of yi is loop dim aligning with coords
            new_yi = {}
            for var, yi in yi_dict.items():
                yi = np.atleast_1d(yi[i])
                new_yi[var] = (float(yi[0]) if self._is_numeric(yi) else yi[0]) if self._is_singleton(yi) else yi.tolist()  # noqa: E501
            self.yi_map[alpha][coord] = copy.deepcopy(new_yi)

    def impute_missing_data(self, alpha: MultiIndex, beta: MultiIndex):
        """Impute missing values in the sparse grid for a given multi-index pair by linear regression imputation."""
        imputer, xi_all, yi_all = None, None, None

        # Only impute (small-length) numeric quantities
        yi_dict = next(iter(self.yi_map[alpha].values()))
        output_vars = [var for var in self._numeric_outputs(yi_dict)
                       if len(np.ravel(yi_dict[var])) <= self.MAX_IMPUTE_SIZE]

        for coord, yi_dict in self.yi_map[alpha].items():
            if any([np.any(np.isnan(yi_dict[var])) for var in output_vars]):
                if imputer is None:
                    # Grab all 'good' interpolation points and train a simple linear regression fit
                    xi_all, yi_all = self.get(alpha, beta, y_vars=output_vars, skip_nan=True)
                    if len(xi_all) == 0 or len(next(iter(xi_all.values()))) == 0:
                        continue  # possible if no good data has been set yet

                    N = next(iter(xi_all.values())).shape[0]  # number of grid points
                    xi_mat = np.concatenate([xi_all[var][:, np.newaxis] if len(xi_all[var].shape) == 1 else
                                             xi_all[var] for var in xi_all.keys()], axis=-1)
                    yi_mat = np.concatenate([yi_all[var][:, np.newaxis] if len(yi_all[var].shape) == 1 else
                                             yi_all[var].reshape((N, -1)) for var in output_vars], axis=-1)
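                    # Note: xi_mat stacks one column per input (shape (N, n_inputs)), and yi_mat flattens any
                    # small array-valued outputs per grid point (shape (N, n_outputs)) before regression.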
                    imputer = _RidgeRegression(alpha=1.0)
                    imputer.fit(xi_mat, yi_mat)

                # Run the imputer for this coordinate
                x_interp = self._extract_grid_points(coord)
                x_interp = np.concatenate([x_interp[var][:, np.newaxis] if len(x_interp[var].shape) == 1 else
                                           x_interp[var] for var in x_interp.keys()], axis=-1)
                y_interp = imputer.predict(x_interp)

                # Unpack the imputed value
                y_impute = {}
                start_idx = 0
                for var in output_vars:
                    var_shape = yi_all[var].shape[1:] or (1,)
                    end_idx = start_idx + int(np.prod(var_shape))
                    yi = np.atleast_1d(y_interp[0, start_idx:end_idx]).reshape(var_shape)
                    nan_idx = np.isnan(np.atleast_1d(yi_dict[var]))
                    yi[~nan_idx] = np.atleast_1d(yi_dict[var])[~nan_idx]  # Only keep imputed values where yi is nan
                    y_impute[var] = float(yi[0]) if self._is_singleton(yi) else yi.tolist()
                    start_idx = end_idx

                self.yi_nan_map[alpha][coord] = copy.deepcopy(y_impute)

    def refine(self, alpha: MultiIndex, beta: MultiIndex, input_domains: dict, weight_fcns: dict = None):
        """Refine the sparse grid for a given `alpha` and `beta` pair and the given collocation rules. Return any
        new grid points that do not have model evaluations saved yet.

        !!! Note
            The `beta` multi-index is used to determine the number of collocation points in each input dimension.
            The length of `beta` should therefore match the number of input variables (all latent coefficients of
            a single variable share one `beta` entry).
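
        !!! Example
            A minimal sketch for two scalar inputs (`my_model` is a stand-in for the parent component's model):

            ```python
            grid = SparseGrid()
            domains = {'x1': (0, 1), 'x2': (0, 1)}
            coords, x_new = grid.refine((0,), (0, 0), domains)  # initialize: one midpoint knot per dimension
            grid.set((0,), (0, 0), coords, my_model(x_new))
            coords, x_new = grid.refine((0,), (1, 0), domains)  # refine x1 from 1 to 3 knots; 2 new points
            ```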
307 """
308 weight_fcns = weight_fcns or {}
310 # Initialize a sparse grid for beta=(0, 0, ..., 0)
311 if np.sum(beta) == 0:
312 if len(self.x_grids) == 0:
313 num_latent = {}
314 for var in input_domains:
315 if LATENT_STR_ID in var:
316 base_id = var.split(LATENT_STR_ID)[0]
317 num_latent[base_id] = 1 if base_id not in num_latent else num_latent[base_id] + 1
318 else:
319 num_latent[var] = 0
320 self.latent_size = num_latent
322 new_pt = {}
323 domains = iter(input_domains.items())
324 for grid_size in self.beta_to_knots(beta):
325 if isinstance(grid_size, int): # scalars
326 var, domain = next(domains)
327 new_pt[var] = self.collocation_1d(grid_size, domain, method=self.collocation_rule,
328 wt_fcn=weight_fcns.get(var, None),
329 opt_args=self.opt_args).tolist()
330 else: # latent coeffs
331 for s in grid_size:
332 var, domain = next(domains)
333 new_pt[var] = self.collocation_1d(s, domain, method=self.collocation_rule,
334 wt_fcn=weight_fcns.get(var, None),
335 opt_args=self.opt_args).tolist()
336 self.x_grids = new_pt
337 self.betas.add(beta)
338 self.yi_map.setdefault(alpha, dict())
339 self.yi_nan_map.setdefault(alpha, dict())
340 self.error_map.setdefault(alpha, dict())
341 new_coords = list(self._expand_grid_coords(beta))
342 return new_coords, self._extract_grid_points(new_coords)
        # Otherwise, refine the sparse grid
        for beta_old in self.betas:
            # Get the first lower neighbor in the sparse grid and refine the 1d grid if necessary
            if self.is_one_level_refinement(beta_old, beta):
                new_grid_size = self.beta_to_knots(beta)
                inputs = zip(self.x_grids.keys(), self.x_grids.values(), input_domains.values())

                for new_size in new_grid_size:
                    if isinstance(new_size, int):  # scalar grid
                        var, grid, domain = next(inputs)
                        if len(grid) < new_size:
                            num_new_pts = new_size - len(grid)
                            self.x_grids[var] = self.collocation_1d(num_new_pts, domain, grid,
                                                                    opt_args=self.opt_args,
                                                                    wt_fcn=weight_fcns.get(var, None),
                                                                    method=self.collocation_rule).tolist()
                    else:  # latent grid
                        for s_new in new_size:
                            var, grid, domain = next(inputs)
                            if len(grid) < s_new:
                                num_new_pts = s_new - len(grid)
                                self.x_grids[var] = self.collocation_1d(num_new_pts, domain, grid,
                                                                        opt_args=self.opt_args,
                                                                        wt_fcn=weight_fcns.get(var, None),
                                                                        method=self.collocation_rule).tolist()
                break

        new_coords = []
        for coord in self._expand_grid_coords(beta):
            if coord not in self.yi_map[alpha]:
                # If we have not computed this grid coordinate yet
                new_coords.append(coord)

        new_pts = self._extract_grid_points(new_coords)

        self.betas.add(beta)
        return new_coords, new_pts

    def _extract_grid_points(self, coords: list[tuple] | tuple):
        """Extract the `x` grid points located at `coords` from `x_grids` and return as the `pts` dictionary."""
        if not isinstance(coords, list):
            coords = [coords]
        pts = {var: np.empty(len(coords)) for var in self.x_grids}

        for k, coord in enumerate(coords):
            grids = iter(self.x_grids.items())
            for idx in coord:
                if isinstance(idx, int):  # scalar grid point
                    var, grid = next(grids)
                    pts[var][k] = grid[idx]
                else:  # latent coefficients
                    for i in idx:
                        var, grid = next(grids)
                        pts[var][k] = grid[i]

        return pts

    def _expand_grid_coords(self, beta: MultiIndex):
        """Iterable over all grid coordinates for a given `beta`, accounting for scalars and latent coefficients."""
        grid_size = self.beta_to_knots(beta)
        grid_coords = []
        for s in grid_size:
            if isinstance(s, int):  # scalars
                grid_coords.append(range(s))
            else:  # latent coefficients
                grid_coords.append(itertools.product(*[range(latent_size) for latent_size in s]))

        yield from itertools.product(*grid_coords)

    @staticmethod
    def _is_singleton(arr: np.ndarray):
        return len(arr.shape) == 1 and arr.shape[0] == 1

    @staticmethod
    def _is_numeric(arr: np.ndarray):
        return np.issubdtype(arr.dtype, np.number)

    @classmethod
    def _numeric_outputs(cls, yi_dict: dict[str, ArrayLike]) -> list[str]:
        """Return a list of the output variables that have numeric data."""
        output_vars = []
        for var in yi_dict.keys():
            try:
                if cls._is_numeric(np.atleast_1d(yi_dict[var])):
                    output_vars.append(var)
            except Exception:
                continue
        return output_vars

    @staticmethod
    def is_one_level_refinement(beta_old: tuple, beta_new: tuple) -> bool:
        """Check if a new `beta` multi-index is a one-level refinement from a previous `beta`.

        !!! Example
            Refining from `(0, 1, 2)` to the new multi-index `(1, 1, 2)` is a one-level refinement. But refining to
            either `(2, 1, 2)` or `(1, 2, 2)` is not, since more than one refinement occurs at the same time.
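
            In code, this check behaves as:

            ```python
            SparseGrid.is_one_level_refinement((0, 1, 2), (1, 1, 2))  # True
            SparseGrid.is_one_level_refinement((0, 1, 2), (2, 1, 2))  # False (two levels in one dimension)
            SparseGrid.is_one_level_refinement((0, 1, 2), (1, 2, 2))  # False (two dimensions refined at once)
            ```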

        :param beta_old: the starting multi-index
        :param beta_new: the new refined multi-index
        :returns: whether `beta_new` is a one-level refinement from `beta_old`
        """
        level_diff = np.array(beta_new, dtype=int) - np.array(beta_old, dtype=int)
        ind = np.nonzero(level_diff)[0]
        return ind.shape[0] == 1 and level_diff[ind] == 1

    def beta_to_knots(self, beta: MultiIndex, knots_per_level: int = None, latent_size: dict = None,
                      expand_latent_method: str = None) -> tuple:
        """Convert a `beta` multi-index to the number of knots per dimension in the sparse grid.

        :param beta: refinement level indices
        :param knots_per_level: level-to-grid-size multiplier, i.e. number of new points (or knots) for each
            `beta` level
        :param latent_size: the number of latent coefficients for each variable (0 if scalar); the number of
            variables and their order should match the `beta` multi-index
        :param expand_latent_method: method for expanding latent grids, either 'round-robin' or 'tensor-product'
        :returns: the number of knots/points per dimension for the sparse grid
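
        !!! Example
            A worked sketch with the default `knots_per_level=2`: a scalar dimension at level `beta[i] = 2` gets
            `2 * 2 + 1 = 5` knots, while a variable with two latent coefficients at `beta[i] = 1` under
            'round-robin' refines one latent grid at a time:

            ```python
            grid = SparseGrid()
            grid.beta_to_knots((2,), latent_size={'x': 0})             # (5,)
            grid.beta_to_knots((1,), latent_size={'x': 2})             # ((3, 1),)
            grid.beta_to_knots((1,), latent_size={'x': 2},
                               expand_latent_method='tensor-product')  # ((3, 3),)
            ```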
458 """
459 knots_per_level = knots_per_level or self.knots_per_level
460 latent_size = latent_size or self.latent_size
461 expand_latent_method = expand_latent_method or self.expand_latent_method
463 grid_size = []
464 for i, (var, num_latent) in enumerate(latent_size.items()):
465 if num_latent > 0:
466 match expand_latent_method:
467 case 'round-robin':
468 if beta[i] == 0:
469 grid_size.append((1,) * num_latent) # initializes all latent grids to 1
470 else:
471 latent_refine_idx = (beta[i] - 1) % num_latent
472 latent_refine_num = ((beta[i] - 1) // num_latent) + 1
473 latent_beta = tuple([latent_refine_num] * (latent_refine_idx + 1) +
474 [latent_refine_num - 1] * (num_latent - latent_refine_idx - 1))
475 latent_grid = [knots_per_level * latent_beta[j] + 1 for j in range(num_latent)]
476 grid_size.append(tuple(latent_grid))
477 case 'tensor-product':
478 grid_size.append((knots_per_level * beta[i] + 1,) * num_latent)
479 case other:
480 raise NotImplementedError(f"Unknown method for expanding latent grids: {other}")
481 else:
482 grid_size.append(knots_per_level * beta[i] + 1)
484 return tuple(grid_size)

    @staticmethod
    def collocation_1d(N: int, z_bds: tuple, z_pts: np.ndarray = None,
                       wt_fcn: callable = None, method='leja', opt_args=None) -> np.ndarray:
        """Find the next `N` points in the 1d sequence of `z_pts` using the provided collocation method.

        :param N: number of new points to add to the sequence
        :param z_bds: bounds on the 1d domain
        :param z_pts: current univariate sequence `(Nz,)`, starts at the middle of `z_bds` if `None`
        :param wt_fcn: weighting function, uses a constant weight if `None`, callable as `wt_fcn(z)`
        :param method: collocation method to use, currently only 'leja' is supported
        :param opt_args: extra arguments for the global 1d `direct` optimizer
        :returns: the univariate sequence `z_pts` augmented by `N` new points
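
        !!! Example
            Each new Leja point maximizes the weighted distance product to the current sequence,
            $z^* = \operatorname{argmax}_z \, w(z) \prod_j |z - z_j|$. A minimal sketch on the unit interval
            (the two new points land near opposite boundaries):

            ```python
            pts = SparseGrid.collocation_1d(3, (0, 1))  # e.g. approximately [0.5, ~1.0, ~0.0]
            ```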
498 """
499 opt_args = opt_args or {}
500 if wt_fcn is None:
501 wt_fcn = lambda z: 1
502 if z_pts is None:
503 z_pts = (z_bds[1] + z_bds[0]) / 2
504 N = N - 1
505 z_pts = np.atleast_1d(z_pts)
507 match method:
508 case 'leja':
509 # Construct Leja sequence by maximizing the Leja objective sequentially
510 for i in range(N):
511 obj_fun = lambda z: -wt_fcn(np.array(z)) * np.prod(np.abs(z - z_pts))
512 res = direct(obj_fun, [z_bds], **opt_args) # Use global DIRECT optimization over 1d domain
513 z_star = res.x
514 z_pts = np.concatenate((z_pts, z_star))
515 case other:
516 raise NotImplementedError(f"Unknown collocation method: {other}")
518 return z_pts