1"""The `System` object is a framework for multidisciplinary models. It manages multiple single discipline component
2models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates.
4Features:
6- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections
7- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver with anderson acceleration
8- Top-level interface for training and using surrogates of each component model
9- Adaptive experimental design for choosing training data efficiently
10- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates
11- Detailed logging and traceback information
12- Supports parallel or vectorized execution of component models
13- Abstract and flexible interfacing with component models
14- Easy serialization and deserialization to/from YAML files
15- Supports approximating field quantities via compression
17Includes:
19- `TrainHistory` — a history of training iterations for the system surrogate
20- `System` — the top-level object for managing multidisciplinary models
21"""
# ruff: noqa: E702
from __future__ import annotations

import copy
import datetime
import functools
import logging
import os
import pickle
import random
import string
import time
import warnings
from collections import ChainMap, UserList, deque
from concurrent.futures import ALL_COMPLETED, Executor, wait
from datetime import timezone
from pathlib import Path
from typing import Annotated, Callable, ClassVar, Literal, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import yaml
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from amisc.component import Component, IndexSet, MiscTree
from amisc.serialize import Serializable, _builtin
from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex, TrainIteration
from amisc.utils import (
    _combine_latent_arrays,
    constrained_lls,
    format_inputs,
    format_outputs,
    get_logger,
    relative_error,
    to_model_dataset,
    to_surrogate_dataset,
)
from amisc.variable import VariableList

__all__ = ['TrainHistory', 'System']


class TrainHistory(UserList, Serializable):
    """Stores the training history of a system surrogate as a list of `TrainIteration` objects."""

    def __init__(self, data: list = None):
        data = data or []
        super().__init__(self._validate_data(data))

    def serialize(self) -> list[dict]:
        """Return a list of each result in the history serialized to a `dict`."""
        ret_list = []
        for res in self:
            new_res = res.copy()
            new_res['alpha'] = str(res['alpha'])
            new_res['beta'] = str(res['beta'])
            ret_list.append(new_res)
        return ret_list

    @classmethod
    def deserialize(cls, serialized_data: list[dict]) -> TrainHistory:
        """Deserialize a list of `dict` objects into a `TrainHistory` object."""
        return TrainHistory(serialized_data)

    @classmethod
    def _validate_data(cls, data: list[dict]) -> list[TrainIteration]:
        return [cls._validate_item(item) for item in data]

    @classmethod
    def _validate_item(cls, item: dict):
        """Format a `TrainIteration` `dict` item before appending to the history."""
        item.setdefault('test_error', None)
        item.setdefault('overhead_s', 0.0)
        item.setdefault('model_s', 0.0)
        item['alpha'] = MultiIndex(item['alpha'])
        item['beta'] = MultiIndex(item['beta'])
        item['num_evals'] = int(item['num_evals'])
        item['added_cost'] = float(item['added_cost'])
        item['added_error'] = float(item['added_error'])
        item['overhead_s'] = float(item['overhead_s'])
        item['model_s'] = float(item['model_s'])
        return item

    def append(self, item: dict):
        super().append(self._validate_item(item))

    def __add__(self, other):
        other_list = other.data if isinstance(other, TrainHistory) else other
        return TrainHistory(data=self.data + other_list)

    def extend(self, items):
        super().extend([self._validate_item(item) for item in items])

    def insert(self, index, item):
        super().insert(index, self._validate_item(item))

    def __setitem__(self, key, value):
        super().__setitem__(key, self._validate_item(value))

    def __eq__(self, other):
        """Two `TrainHistory` objects are equal if they have the same length and all items are equal
        (NaN values in matching positions are treated as equal).
        """
        if not isinstance(other, TrainHistory):
            return False
        if len(self) != len(other):
            return False
        for item_self, item_other in zip(self, other):
            for key in item_self:
                if key in item_other:
                    val_self = item_self[key]
                    val_other = item_other[key]
                    if isinstance(val_self, float) and isinstance(val_other, float):
                        if not (np.isnan(val_self) and np.isnan(val_other)) and val_self != val_other:
                            return False
                    elif isinstance(val_self, dict) and isinstance(val_other, dict):
                        for v, err in val_self.items():
                            if v in val_other:
                                err_other = val_other[v]
                                if not (np.isnan(err) and np.isnan(err_other)) and err != err_other:
                                    return False
                            else:
                                return False
                    elif val_self != val_other:
                        return False
                else:
                    return False
        return True


class _Converged:
    """Helper class to track which samples have converged during `System.predict()`."""
    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.valid_idx = np.full(num_samples, True)       # All samples are valid by default
        self.converged_idx = np.full(num_samples, False)  # For FPI convergence

    def reset_convergence(self):
        self.converged_idx = np.full(self.num_samples, False)

    @property
    def curr_idx(self):
        return np.logical_and(self.valid_idx, ~self.converged_idx)


def _merge_shapes(target_shape, arr):
    """Helper to reshape and broadcast an array into the target shape."""
    shape1, shape2 = target_shape, arr.shape
    if len(shape2) > len(shape1):
        shape1, shape2 = shape2, shape1
    result = []
    for i in range(len(shape1)):
        if i < len(shape2):
            # Prefer the non-singleton dimension when the two shapes disagree
            result.append(shape2[i] if shape1[i] == 1 else shape1[i])
        else:
            result.append(1)
    arr = arr.reshape(tuple(result))
    return np.broadcast_to(arr, target_shape).copy()


class System(BaseModel, Serializable):
    """
    Multidisciplinary (MD) surrogate framework top-level class. Construct a `System` from a list of
    `Component` models.

    !!! Example
        ```python
        def f1(x):
            y = x ** 2
            return y

        def f2(y):
            z = y + 1
            return z

        system = System(f1, f2)
        ```

    A `System` object can be saved/loaded from `.yml` files using the `!System` yaml tag.

    :ivar name: the name of the system
    :ivar components: list of `Component` models that make up the MD system
    :ivar train_history: history of training iterations for the system surrogate (filled in during training)

    :ivar _root_dir: root directory where all surrogate build products are saved to file
    :ivar _logger: logger object for the system
    """
    yaml_tag: ClassVar[str] = u'!System'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True,
                              extra='allow')

    name: Annotated[str, Field(default_factory=lambda: "System_" + "".join(random.choices(string.digits, k=3)))]
    components: Callable | Component | list[Callable | Component]
    train_history: list[dict] | TrainHistory = TrainHistory()
    amisc_version: str = None

    _root_dir: Optional[str]
    _logger: Optional[logging.Logger] = None

    def __init__(self, /, *args, components=None, root_dir=None, **kwargs):
        """Construct a `System` object from a list of `Component` models in `*args` or `components`. If
        a `root_dir` is provided, then a new directory will be created under `root_dir` with the name
        `amisc_{timestamp}`. This directory will be used to save all build products and log files.

        :param components: list of `Component` models that make up the MD system
        :param root_dir: root directory where all surrogate build products are saved to file (optional)
        """
        if components is None:
            components = []
            for a in args:
                if isinstance(a, Component) or callable(a):
                    components.append(a)
                else:
                    try:
                        components.extend(a)
                    except TypeError as e:
                        raise ValueError(f"Invalid component: {a}") from e

        import amisc
        amisc_version = kwargs.pop('amisc_version', amisc.__version__)
        super().__init__(components=components, amisc_version=amisc_version, **kwargs)
        self.root_dir = root_dir

    def __repr__(self):
        s = f'---- {self.name} ----\n'
        s += f'amisc version: {self.amisc_version}\n'
        s += f'Refinement level: {self.refine_level}\n'
        s += f'Components: {", ".join([comp.name for comp in self.components])}\n'
        s += f'Inputs: {", ".join([var.name for var in self.inputs()])}\n'
        s += f'Outputs: {", ".join([var.name for var in self.outputs()])}'
        return s

    def __str__(self):
        return self.__repr__()

    @field_validator('components')
    @classmethod
    def _validate_components(cls, comps) -> list[Component]:
        if not isinstance(comps, list):
            comps = [comps]
        comps = [Component.deserialize(c) for c in comps]

        # Merge all variables to avoid name conflicts
        merged_vars = VariableList.merge(*[comp.inputs for comp in comps], *[comp.outputs for comp in comps])
        for comp in comps:
            comp.inputs.update({var.name: var for var in merged_vars.values() if var in comp.inputs})
            comp.outputs.update({var.name: var for var in merged_vars.values() if var in comp.outputs})

        return comps

    @field_validator('train_history')
    @classmethod
    def _validate_train_history(cls, history) -> TrainHistory:
        if isinstance(history, TrainHistory):
            return history
        else:
            return TrainHistory.deserialize(history)

    def graph(self) -> nx.DiGraph:
        """Build a directed graph of the system components based on their input-output relationships.
        graph = nx.DiGraph()
        model_deps = {}
        for comp in self.components:
            graph.add_node(comp.name)
            for output in comp.outputs:
                model_deps[output] = comp.name
        for comp in self.components:
            for in_var in comp.inputs:
                if in_var in model_deps:
                    graph.add_edge(model_deps[in_var], comp.name)

        return graph

    def _save_on_error(func):
        """Gracefully exit and save the `System` object on any errors."""
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except:
                if self.root_dir is not None:
                    self.save_to_file(f'{self.name}_error.yml')
                    self.logger.critical(f'An error occurred during execution of "{func.__name__}". Saving '
                                         f'System object to {self.name}_error.yml', exc_info=True)
                    self.logger.info(f'Final system surrogate on exit: \n {self}')
                raise
        return wrap
    _save_on_error = staticmethod(_save_on_error)

    def insert_components(self, components: list | Callable | Component):
        """Insert new components into the system."""
        components = components if isinstance(components, list) else [components]
        self.components = self.components + components

    def swap_component(self, old_component: str | Component, new_component: Callable | Component):
        """Replace an old component with a new component."""
        old_name = old_component if isinstance(old_component, str) else old_component.name
        comps = [comp if comp.name != old_name else new_component for comp in self.components]
        self.components = comps

    def remove_component(self, component: str | Component):
        """Remove a component from the system."""
        comp_name = component if isinstance(component, str) else component.name
        self.components = [comp for comp in self.components if comp.name != comp_name]

    def inputs(self) -> VariableList:
        """Collect all inputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object, excluding variables that are also outputs of
        any component.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all inputs from the components.
        """
        all_inputs = ChainMap(*[comp.inputs for comp in self.components])
        return VariableList({k: all_inputs[k] for k in all_inputs.keys() - self.outputs().keys()})

    def outputs(self) -> VariableList:
        """Collect all outputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all outputs from the components.
        """
        return VariableList({k: v for k, v in ChainMap(*[comp.outputs for comp in self.components]).items()})

    def coupling_variables(self) -> VariableList:
        """Collect all coupling variables (i.e. variables that are an output of one component and an input of
        another) from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all coupling variables from the components.
        """
        all_outputs = self.outputs()
        return VariableList({k: all_outputs[k] for k in (all_outputs.keys() &
                                                         ChainMap(*[comp.inputs for comp in self.components]).keys())})

    def variables(self):
        """Iterator over all variables in the system (inputs and outputs)."""
        yield from ChainMap(self.inputs(), self.outputs()).values()

    @property
    def refine_level(self) -> int:
        """The total number of training iterations."""
        return len(self.train_history)

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: logging.Logger):
        self._logger = logger
        for comp in self.components:
            comp.logger = logger

    @staticmethod
    def timestamp() -> str:
        """Return a UTC timestamp string in the isoformat `YYYY-MM-DDTHH.MM.SS`."""
        return datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')

    @property
    def root_dir(self):
        """Return the root directory of the surrogate (if available), otherwise `None`."""
        return Path(self._root_dir) if self._root_dir is not None else None

    @root_dir.setter
    def root_dir(self, root_dir: str | Path):
        """Set the root directory for all build products. If `root_dir` is `None`, then no products will be saved.
        Otherwise, log files, model outputs, surrogate files, etc. will be saved under this directory.

        !!! Note "`amisc` root directory"
            If `root_dir` is not `None`, then a new directory will be created under `root_dir` with the name
            `amisc_{timestamp}`. This directory will be used to save all build products. If `root_dir` matches the
            `amisc_*` format, then it will be used directly.

        :param root_dir: the root directory for all build products
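
        !!! Example
            A minimal sketch (the directory name is hypothetical):

            ```python
            system.root_dir = 'results'  # creates results/amisc_{timestamp}/ with surrogates/ and components/
            ```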
399 """
400 if root_dir is not None:
401 parts = Path(root_dir).resolve().parts
402 if parts[-1].startswith('amisc_'):
403 self._root_dir = Path(root_dir).resolve().as_posix()
404 if not self.root_dir.is_dir():
405 os.mkdir(self.root_dir)
406 else:
407 root_dir = Path(root_dir) / ('amisc_' + self.timestamp())
408 os.mkdir(root_dir)
409 self._root_dir = Path(root_dir).resolve().as_posix()
411 log_file = None
412 if not (pth := self.root_dir / 'surrogates').is_dir():
413 os.mkdir(pth)
414 if not (pth := self.root_dir / 'components').is_dir():
415 os.mkdir(pth)
416 for comp in self.components:
417 if comp.model_kwarg_requested('output_path'):
418 if not (comp_pth := pth / comp.name).is_dir():
419 os.mkdir(comp_pth)
420 for f in os.listdir(self.root_dir):
421 if f.endswith('.log'):
422 log_file = (self.root_dir / f).resolve().as_posix()
423 break
424 if log_file is None:
425 log_file = (self.root_dir / f'amisc_{self.timestamp()}.log').resolve().as_posix()
426 self.set_logger(log_file=log_file)
428 else:
429 self._root_dir = None
430 self.set_logger(log_file=None)

    def set_logger(self, log_file: str | Path | bool = None, stdout: bool = None, logger: logging.Logger = None,
                   level: int = logging.INFO):
        """Set a new `logging.Logger` object.

        :param log_file: log to this file if str or Path (defaults to whatever is currently set or empty);
                         set `False` to remove file logging or set `True` to create a default log file in the root dir
        :param stdout: whether to connect the logger to console (defaults to whatever is currently set or `False`)
        :param logger: the logging object to use (this will override the `log_file` and `stdout` arguments if set);
                       if `None`, then a new logger is created according to `log_file` and `stdout`
        :param level: the logging level to set the logger to (defaults to `logging.INFO`)
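
        !!! Example
            A minimal sketch:

            ```python
            system.set_logger(log_file=True, stdout=True)  # default log file in root dir + console output
            ```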
442 """
443 # Decide whether to use stdout
444 if stdout is None:
445 stdout = False
446 if self._logger is not None:
447 for handler in self._logger.handlers:
448 if isinstance(handler, logging.StreamHandler):
449 stdout = True
450 break
452 # Decide what log_file to use (if any)
453 if log_file is True:
454 log_file = pth / f'amisc_{self.timestamp()}.log' if (pth := self.root_dir) is not None else (
455 f'amisc_{self.timestamp()}.log')
456 elif log_file is None:
457 if self._logger is not None:
458 for handler in self._logger.handlers:
459 if isinstance(handler, logging.FileHandler):
460 log_file = handler.baseFilename
461 break
462 elif log_file is False:
463 log_file = None
465 self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level)
467 for comp in self.components:
468 comp.set_logger(log_file=log_file, stdout=stdout, logger=logger, level=level)

    def sample_inputs(self, size: tuple | int,
                      component: str = 'System',
                      normalize: bool = True,
                      use_pdf: bool | str | list[str] = False,
                      include: str | list[str] = None,
                      exclude: str | list[str] = None,
                      nominal: dict[str, float] = None) -> Dataset:
        """Return samples of the inputs according to provided options. Will return samples in the
        normalized/compressed space of the surrogate by default. See [`to_model_dataset`][amisc.utils.to_model_dataset]
        to convert the samples to be usable by the true model directly.

        :param size: tuple or integer specifying shape or number of samples to obtain
        :param component: which component to sample inputs for (defaults to full system exogenous inputs)
        :param normalize: whether to normalize the samples (defaults to True)
        :param use_pdf: whether to sample from variable pdfs (defaults to False, which will instead sample from the
                        variable domain bounds). If a string or list of strings is provided, then only those variables
                        or variable categories will be sampled using their pdfs.
        :param include: a list of variables or variable categories to include in the sampling. Defaults to using all
                        input variables.
        :param exclude: a list of variables or variable categories to exclude from the sampling. Empty by default.
        :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty. Specify nominal
                        values as unnormalized (will be normalized if `normalize=True`)
        :returns: `dict` of `(*size,)` samples for each selected input variable
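
        !!! Example
            A minimal sketch (the variable name `x1` and the category `'calibration'` are hypothetical):

            ```python
            samples = system.sample_inputs(100, use_pdf='calibration', exclude=['x1'])
            # samples = {'x2': np.ndarray(100,), 'x3': np.ndarray(100,), ...}
            ```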
493 """
494 size = (size, ) if isinstance(size, int) else size
495 nominal = nominal or dict()
496 inputs = self.inputs() if component == 'System' else self[component].inputs
497 if include is None:
498 include = []
499 if not isinstance(include, list):
500 include = [include]
501 if exclude is None:
502 exclude = []
503 if not isinstance(exclude, list):
504 exclude = [exclude]
505 if isinstance(use_pdf, str):
506 use_pdf = [use_pdf]
508 selected_inputs = []
509 for var in inputs:
510 if len(include) == 0 or var.name in include or var.category in include:
511 if var.name not in exclude and var.category not in exclude:
512 selected_inputs.append(var)
514 samples = {}
515 for var in selected_inputs:
516 # Sample from latent variable domains for field quantities
517 if var.compression is not None:
518 latent = var.sample_domain(size)
519 for i in range(latent.shape[-1]):
520 samples[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i]
522 # Sample scalars normally
523 else:
524 if (domain := var.get_domain()) is None:
525 raise RuntimeError(f"Trying to sample variable '{var}' with empty domain. Please set a domain "
526 f"for this variable. Samples outside the provided domain will be rejected.")
527 lb, ub = domain
528 pdf = (var.name in use_pdf or var.category in use_pdf) if isinstance(use_pdf, list) else use_pdf
529 nom = nominal.get(var.name, None)
531 x_sample = var.sample(size, nominal=nom) if pdf else var.sample_domain(size)
532 good_idx = (x_sample < ub) & (x_sample > lb)
533 num_reject = np.sum(~good_idx)
535 while num_reject > 0:
536 new_sample = var.sample((num_reject,), nominal=nom) if pdf else var.sample_domain((num_reject,))
537 x_sample[~good_idx] = new_sample
538 good_idx = (x_sample < ub) & (x_sample > lb)
539 num_reject = np.sum(~good_idx)
541 samples[var.name] = var.normalize(x_sample) if normalize else x_sample
543 return samples

    def simulate_fit(self):
        """Loop back through the training history and simulate each iteration. Will yield the internal data structures
        of each `Component` surrogate after each iteration of training (without needing to call `fit()` or any
        of the underlying models). This might be useful, for example, for computing the surrogate predictions on
        a new test set or viewing cumulative training costs.

        !!! Example
            Say you have a new test set: `(new_xtest, new_ytest)`, and you want to compute the accuracy of the
            surrogate fit at each iteration of the training history:

            ```python
            for train_iter, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test in system.simulate_fit():
                # Do something with the surrogate data structures
                new_ysurr = system.predict(new_xtest, index_set=active_sets, misc_coeff=misc_coeff_train)
                train_error = relative_error(new_ysurr, new_ytest)
            ```

        :return: a generator of the active index sets, candidate index sets, and MISC coefficients
                 of each component model at each iteration of the training history
        """
        # "Simulated" data structures for each component
        active_sets = {comp.name: IndexSet() for comp in self.components}       # active index sets for each component
        candidate_sets = {comp.name: IndexSet() for comp in self.components}    # candidate sets for each component
        misc_coeff_train = {comp.name: MiscTree() for comp in self.components}  # MISC coeff for active sets
        misc_coeff_test = {comp.name: MiscTree() for comp in self.components}   # MISC coeff for active + candidate sets

        for train_result in self.train_history:
            # The selected refinement component and indices
            comp_star = train_result['component']
            alpha_star = train_result['alpha']
            beta_star = train_result['beta']
            comp = self[comp_star]

            # Get forward neighbors for the selected index
            neighbors = comp._neighbors(alpha_star, beta_star, active_set=active_sets[comp_star], forward=True)

            # "Activate" the index in the simulated data structure
            s = set()
            s.add((alpha_star, beta_star))
            comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star],
                                   misc_coeff=misc_coeff_train[comp_star])

            if (alpha_star, beta_star) in candidate_sets[comp_star]:
                candidate_sets[comp_star].remove((alpha_star, beta_star))
            else:
                # Only for the initial index, which didn't come from the candidate set
                comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                       misc_coeff=misc_coeff_test[comp_star])
            active_sets[comp_star].update(s)

            comp.update_misc_coeff(neighbors, index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                   misc_coeff=misc_coeff_test[comp_star])  # neighbors will only ever pass here once
            candidate_sets[comp_star].update(neighbors)

            # Caller can now do whatever they want as if the system surrogate were at this training iteration
            # See the "index_set" and "misc_coeff" overrides for `System.predict()` for example
            yield train_result, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test

    def add_output(self):
        """Add an output variable retroactively to a component surrogate. The user should provide a callable that
        takes a save path and extracts the model output data for a given training point/location.
        """
        # TODO:
        # - Loop back through the surrogate training history
        # - Simulate activate_index and extract the model output from file rather than calling the model
        # - Update all interpolator states
        raise NotImplementedError

    @_save_on_error
    def fit(self, targets: list = None,
            num_refine: int = 100,
            max_iter: int = 20,
            max_tol: float = 1e-3,
            runtime_hr: float = 1.,
            estimate_bounds: bool = False,
            update_bounds: bool = True,
            test_set: tuple | str | Path = None,
            start_test_check: int = None,
            save_interval: int = 0,
            plot_interval: int = 1,
            cache_interval: int = 0,
            executor: Executor = None,
            weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'):
        """Train the system surrogate adaptively by iterative refinement until an end condition is met.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param max_iter: the maximum number of refinement steps to take
        :param max_tol: the max allowable value in relative L2 error to achieve
        :param runtime_hr: the threshold wall clock time (hr) at which to stop further refinement (will go
                           until all models finish the current iteration)
        :param estimate_bounds: whether to estimate bounds for the coupling variables; will only try to estimate from
                                the `test_set` if provided (defaults to `False`). Otherwise, you should manually
                                provide domains for all coupling variables.
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param test_set: `tuple` of `(xtest, ytest)` to show convergence of the surrogate to the true model. The test
                         set inputs and outputs are specified as `dicts` of `np.ndarrays` with keys corresponding to
                         the variable names. Can also pass a path to a `.pkl` file that has the test set data as
                         `{'test_set': (xtest, ytest)}`.
        :param start_test_check: the iteration at which to start checking the test set error (defaults to the number
                                 of components); surrogate evaluation isn't useful during initialization, so you
                                 should allow at least one iteration per component before checking test set error
        :param save_interval: number of refinement steps between each progress save, none if 0; `System.root_dir`
                              must be specified to save to file
        :param plot_interval: how often to plot the error indicator and test set error (defaults to every iteration);
                              will only plot and save to file if a root directory is set
        :param cache_interval: how often to cache component data in order to speed up future training iterations (at
                               the cost of additional memory usage); defaults to 0 (no caching)
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations (optional, but
                         recommended for expensive models)
        :param weight_fcns: a `dict` of weight functions to apply to each input variable for training data selection;
                            defaults to using the pdf of each variable. If None, then no weighting is applied.
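
        !!! Example
            A minimal sketch of a typical training run (the budget values are illustrative; the test set is kept
            unnormalized, as expected by `test_set_performance`):

            ```python
            xtest = system.sample_inputs(1000, normalize=False)
            ytest = system.predict(xtest, use_model='best', normalized_inputs=False)  # ground truth
            system.fit(max_iter=50, runtime_hr=2.0, test_set=(xtest, ytest))
            ```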
657 """
658 start_test_check = start_test_check or sum([1 for _ in self.components if _.has_surrogate])
659 targets = targets or self.outputs()
660 xtest, ytest = self._get_test_set(test_set)
661 max_iter = self.refine_level + max_iter
663 # Estimate bounds from test set if provided (override current bounds if they are set)
664 if estimate_bounds:
665 if ytest is not None:
666 y_samples = to_surrogate_dataset(ytest, self.outputs(), del_fields=True)[0] # normalize/compress
667 _combine_latent_arrays(y_samples)
668 coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_samples}
669 y_min, y_max = {}, {}
670 for var in coupling_vars.values():
671 y_min[var] = np.nanmin(y_samples[var], axis=0)
672 y_max[var] = np.nanmax(y_samples[var], axis=0)
673 if var.compression is not None:
674 new_domain = list(zip(y_min[var].tolist(), y_max[var].tolist()))
675 var.update_domain(new_domain, override=True)
676 else:
677 new_domain = (float(y_min[var]), float(y_max[var]))
678 var.update_domain(var.denormalize(new_domain), override=True)
679 del y_samples
680 else:
681 self.logger.warning('Could not estimate bounds for coupling variables: no test set provided. '
682 'Make sure you manually provide (good) coupling variable domains.')
684 # Track convergence progress on the error indicator and test set (plot to file)
685 if self.root_dir is not None:
686 err_record = [res['added_error'] for res in self.train_history]
687 err_fig, err_ax = plt.subplots(figsize=(6, 5), layout='tight')
689 if xtest is not None and ytest is not None:
690 num_plot = min(len(targets), 3)
691 test_record = np.full((self.refine_level, num_plot), np.nan)
692 t_fig, t_ax = plt.subplots(1, num_plot, figsize=(3.5 * num_plot, 4), layout='tight', squeeze=False,
693 sharey='row')
694 for j, res in enumerate(self.train_history):
695 for i, var in enumerate(targets[:num_plot]):
696 if (perf := res.get('test_error')) is not None:
697 test_record[j, i] = perf[var]

        total_overhead = 0.0
        total_model_wall_time = 0.0
        t_start = time.time()
        while True:
            # Adaptive refinement step
            t_iter_start = time.time()
            train_result = self.refine(targets=targets, num_refine=num_refine, update_bounds=update_bounds,
                                       executor=executor, weight_fcns=weight_fcns)
            if train_result['component'] is None:
                self._print_title_str('Termination criteria reached: No candidates left to refine')
                break

            # Keep track of algorithmic overhead (before and after call_model for this iteration)
            m_start, m_end = self[train_result['component']].get_model_timestamps()  # Start and end of call_model
            if m_start is not None and m_end is not None:
                train_result['overhead_s'] = (m_start - t_iter_start) + (time.time() - m_end)
                train_result['model_s'] = m_end - m_start
            else:
                train_result['overhead_s'] = time.time() - t_iter_start
                train_result['model_s'] = 0.0
            total_overhead += train_result['overhead_s']
            total_model_wall_time += train_result['model_s']

            curr_error = train_result['added_error']

            # Plot progress of the error indicator
            if self.root_dir is not None:
                err_record.append(curr_error)

                if plot_interval > 0 and self.refine_level % plot_interval == 0:
                    err_ax.clear(); err_ax.set_yscale('log'); err_ax.grid()
                    err_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    err_ax.plot(err_record, '-k')
                    err_ax.set_xlabel('Iteration'); err_ax.set_ylabel('Relative error indicator')
                    err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.pdf'), format='pdf',
                                    bbox_inches='tight')

            # Save performance on a test set
            if xtest is not None and ytest is not None:
                # Don't compute if components are uninitialized
                perf = self.test_set_performance(xtest, ytest) if self.refine_level + 1 >= start_test_check else (
                    {str(var): np.nan for var in ytest if COORDS_STR_ID not in var})
                train_result['test_error'] = perf.copy()

                if self.root_dir is not None:
                    test_record = np.vstack((test_record, np.array([perf[var] for var in targets[:num_plot]])))

                    if plot_interval > 0 and self.refine_level % plot_interval == 0:
                        for i in range(num_plot):
                            with warnings.catch_warnings():
                                warnings.simplefilter("ignore", UserWarning)
                                t_ax[0, i].clear(); t_ax[0, i].set_yscale('log'); t_ax[0, i].grid()
                                t_ax[0, i].xaxis.set_major_locator(MaxNLocator(integer=True))
                                t_ax[0, i].plot(test_record[:, i], '-k')
                                t_ax[0, i].set_title(self.outputs()[targets[i]].get_tex(units=True))
                                t_ax[0, i].set_xlabel('Iteration')
                                t_ax[0, i].set_ylabel('Test set relative error' if i == 0 else '')
                        t_fig.savefig(str(Path(self.root_dir) / 'test_set_error.pdf'), format='pdf',
                                      bbox_inches='tight')

            self.train_history.append(train_result)

            if self.root_dir is not None and save_interval > 0 and self.refine_level % save_interval == 0:
                iter_name = f'{self.name}_iter{self.refine_level}'
                if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                    os.mkdir(pth)
                self.save_to_file(f'{iter_name}.yml', save_dir=pth)  # Save to an iteration-specific directory

            if cache_interval > 0 and self.refine_level % cache_interval == 0:
                for comp in self.components:
                    comp.cache()

            # Check all end conditions
            if self.refine_level >= max_iter:
                self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}')
                break
            if curr_error < max_tol:
                self._print_title_str(f'Termination criteria reached: relative error {curr_error} < tol {max_tol}')
                break
            if ((time.time() - t_start) / 3600.0) >= runtime_hr:
                t_end = time.time()
                actual = datetime.timedelta(seconds=t_end - t_start)
                target = datetime.timedelta(seconds=runtime_hr * 3600)
                train_surplus = ((t_end - t_start) - runtime_hr * 3600) / 3600
                self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}')
                self.logger.info(f'Surplus wall time: {train_surplus:.3f}/{runtime_hr:.3f} hours '
                                 f'(+{100 * train_surplus / runtime_hr:.2f}%)')
                break

        self.logger.info(f'Model evaluation algorithm efficiency: '
                         f'{100 * total_model_wall_time / (total_model_wall_time + total_overhead):.2f}%')

        if self.root_dir is not None:
            iter_name = f'{self.name}_iter{self.refine_level}'
            if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                os.mkdir(pth)
            self.save_to_file(f'{iter_name}.yml', save_dir=pth)

            if xtest is not None and ytest is not None:
                self._save_test_set((xtest, ytest))

        self.logger.info(f'Final system surrogate: \n {self}')

    def test_set_performance(self, xtest: Dataset, ytest: Dataset, index_set='test') -> Dataset:
        """Compute the relative L2 error on a test set for the given target output variables.

        :param xtest: `dict` of test set input samples (unnormalized)
        :param ytest: `dict` of test set output samples (unnormalized)
        :param index_set: index set to use for prediction (defaults to 'test')
        :returns: `dict` of relative L2 errors for each target output variable
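
        !!! Example
            A minimal sketch (assuming an existing test set of unnormalized samples; output names are hypothetical):

            ```python
            perf = system.test_set_performance(xtest, ytest)
            # perf = {'y1': 1e-3, 'y2': 5e-4, ...} relative L2 error per output variable
            ```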
807 """
808 targets = [var for var in ytest.keys() if COORDS_STR_ID not in var and var in self.outputs()]
809 coords = {var: ytest[var] for var in ytest if COORDS_STR_ID in var}
810 xtest = to_surrogate_dataset(xtest, self.inputs(), del_fields=True)[0]
811 ysurr = self.predict(xtest, index_set=index_set, targets=targets)
812 ysurr = to_model_dataset(ysurr, self.outputs(), del_latent=True, **coords)[0]
813 perf = {}
814 for var in targets:
815 # Handle relative error for object arrays (field qtys)
816 ytest_obj = np.issubdtype(ytest[var].dtype, np.object_)
817 ysurr_obj = np.issubdtype(ysurr[var].dtype, np.object_)
818 if ytest_obj or ysurr_obj:
819 _iterable = np.ndindex(ytest[var].shape) if ytest_obj else np.ndindex(ysurr[var].shape)
820 num, den = [], []
821 for index in _iterable:
822 pred, targ = ysurr[var][index], ytest[var][index]
823 num.append((pred - targ)**2)
824 den.append(targ ** 2)
825 perf[var] = float(np.sqrt(sum([np.sum(n) for n in num]) / sum([np.sum(d) for d in den])))
826 else:
827 perf[var] = float(relative_error(ysurr[var], ytest[var]))
829 return perf

    def refine(self, targets: list = None, num_refine: int = 100, update_bounds: bool = True, executor: Executor = None,
               weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf') -> TrainIteration:
        """Perform a single adaptive refinement step on the system surrogate.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param weight_fcns: weight functions for choosing new training data for each input variable; defaults to
                            the PDFs of each variable. If None, then no weighting is applied.
        :returns: `dict` of the refinement results indicating the chosen component and candidate index
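
        !!! Example
            A minimal sketch of a manual refinement loop (as an alternative to calling `fit()`):

            ```python
            for _ in range(10):
                result = system.refine(num_refine=50)
                if result['component'] is None:
                    break  # no candidates left to refine
            ```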
842 """
843 self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}')
844 targets = targets or self.outputs()
846 # Check for uninitialized components and refine those first
847 for comp in self.components:
848 if len(comp.active_set) == 0 and comp.has_surrogate:
849 alpha_star = (0,) * len(comp.model_fidelity)
850 beta_star = (0,) * len(comp.max_beta)
851 self.logger.info(f"Initializing component {comp.name}: adding {(alpha_star, beta_star)} to active set")
852 model_dir = (pth / 'components' / comp.name) if (pth := self.root_dir) is not None else None
853 comp.activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
854 weight_fcns=weight_fcns)
855 num_evals = comp.get_cost(alpha_star, beta_star)
856 cost_star = max(1., comp.model_costs.get(alpha_star, 1.) * num_evals) # Cpu time (s)
857 err_star = np.nan
858 return {'component': comp.name, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
859 'added_cost': float(cost_star), 'added_error': float(err_star)}
861 # Compute entire integrated-surrogate on a random test set for global system QoI error estimation
862 x_samples = self.sample_inputs(num_refine)
863 y_curr = self.predict(x_samples, index_set='train', targets=targets)
864 _combine_latent_arrays(y_curr)
865 coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_curr}
867 y_min, y_max = None, None
868 if update_bounds:
869 y_min = {var: np.nanmin(y_curr[var], axis=0, keepdims=True) for var in coupling_vars} # (1, ydim)
870 y_max = {var: np.nanmax(y_curr[var], axis=0, keepdims=True) for var in coupling_vars} # (1, ydim)
872 # Find the candidate surrogate with the largest error indicator
873 error_max, error_indicator = -np.inf, -np.inf
874 comp_star, alpha_star, beta_star, err_star, cost_star = None, None, None, -np.inf, 0
875 for comp in self.components:
876 if not comp.has_surrogate: # Skip analytic models that don't need a surrogate
877 continue
879 self.logger.info(f"Estimating error for component '{comp.name}'...")
881 if len(comp.candidate_set) > 0:
882 candidates = list(comp.candidate_set)
883 if executor is None:
884 ret = [self.predict(x_samples, targets=targets, index_set={comp.name: {(alpha, beta)}},
885 incremental={comp.name: True})
886 for alpha, beta in candidates]
887 else:
888 temp_buffer = self._remove_unpickleable()
889 futures = [executor.submit(self.predict, x_samples, targets=targets,
890 index_set={comp.name: {(alpha, beta)}}, incremental={comp.name: True})
891 for alpha, beta in candidates]
892 wait(futures, timeout=None, return_when=ALL_COMPLETED)
893 ret = [f.result() for f in futures]
894 self._restore_unpickleable(temp_buffer)
896 for i, y_cand in enumerate(ret):
897 alpha, beta = candidates[i]
898 _combine_latent_arrays(y_cand)
899 error = {}
900 for var, arr in y_cand.items():
901 if var in targets:
902 error[var] = relative_error(arr, y_curr[var], skip_nan=True)
904 if update_bounds and var in coupling_vars:
905 y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True)
906 y_max[var] = np.nanmax(np.concatenate((y_max[var], arr), axis=0), axis=0, keepdims=True)
908 delta_error = np.nanmax([np.nanmax(error[var]) for var in error]) # Max error over all target QoIs
909 num_evals = comp.get_cost(alpha, beta)
910 delta_work = max(1., comp.model_costs.get(alpha, 1.) * num_evals) # Cpu time (s)
911 error_indicator = delta_error / delta_work
913 self.logger.info(f"Candidate multi-index: {(alpha, beta)}. Relative error: {delta_error}. "
914 f"Error indicator: {error_indicator}.")
916 if error_indicator > error_max:
917 error_max = error_indicator
918 comp_star, alpha_star, beta_star, err_star, cost_star = (
919 comp.name, alpha, beta, delta_error, delta_work)
920 else:
921 self.logger.info(f"Component '{comp.name}' has no available candidates left!")
923 # Update all coupling variable ranges
924 if update_bounds:
925 for var in coupling_vars.values():
926 if np.all(~np.isnan(y_min[var])) and np.all(~np.isnan(y_max[var])):
927 if var.compression is not None:
928 new_domain = list(zip(np.squeeze(y_min[var], axis=0).tolist(),
929 np.squeeze(y_max[var], axis=0).tolist()))
930 var.update_domain(new_domain)
931 else:
932 new_domain = (y_min[var][0], y_max[var][0])
933 var.update_domain(var.denormalize(new_domain)) # bds will be in norm space from predict() call
935 # Add the chosen multi-index to the chosen component
936 if comp_star is not None:
937 self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{comp_star}'.")
938 model_dir = (pth / 'components' / comp_star) if (pth := self.root_dir) is not None else None
939 self[comp_star].activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
940 weight_fcns=weight_fcns)
941 num_evals = self[comp_star].get_cost(alpha_star, beta_star)
942 else:
943 self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}")
944 num_evals = 0
946 # Return the results of the refinement step
947 return {'component': comp_star, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
948 'added_cost': float(cost_star), 'added_error': float(err_star)}

    def predict(self, x: dict | Dataset,
                max_fpi_iter: int = 100,
                anderson_mem: int = 10,
                fpi_tol: float = 1e-10,
                use_model: str | tuple | dict = None,
                model_dir: str | Path = None,
                verbose: bool = False,
                index_set: dict[str, IndexSet | Literal['train', 'test']] = 'test',
                misc_coeff: dict[str, MiscTree] = None,
                normalized_inputs: bool = True,
                incremental: dict[str, bool] = False,
                targets: list[str] = None,
                executor: Executor = None,
                var_shape: dict[str, tuple] = None) -> Dataset:
        """Evaluate the system surrogate at inputs `x`. Return `y = system(x)`.

        !!! Warning "Computing the true model with feedback loops"
            You can use this function to predict outputs for your MD system using the full-order models rather than
            the surrogate, by specifying `use_model`. This is convenient because the `System` manages all the
            coupled information flow between models automatically. However, it is *highly* recommended not to use
            the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using
            anything more computationally demanding than the surrogate.

        :param x: `dict` of input samples for each variable in the system
        :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine
        :param anderson_mem: hyperparameter for tuning the convergence of FPI with Anderson acceleration
        :param fpi_tol: tolerance limit for convergence of fixed-point iteration
        :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate;
                          specify a `dict` of the above to assign different model fidelities for different components
        :param model_dir: directory to save model outputs if `use_model` is specified
        :param verbose: whether to print out iteration progress during execution
        :param index_set: `dict(comp=[indices])` to override the active set for a component, defaults to using the
                          `test` set for every component. Can also specify `train` for any component or a valid
                          `IndexSet` object. If `incremental` is specified, will be overwritten with `train`.
        :param misc_coeff: `dict(comp=MiscTree)` to override the default coefficients for a component, passes through
                           along with `index_set` and `incremental` to `comp.predict()`.
        :param normalized_inputs: true if the passed inputs are compressed/normalized for surrogate evaluation
                                  (default), such as inputs returned by `sample_inputs`. Set to `False` if you are
                                  passing inputs as the true models would expect them instead (i.e. not normalized).
        :param incremental: whether to add `index_set` to the current active set for each component (temporarily);
                            this will set `index_set='train'` for all other components (since incremental will
                            augment the "training" active sets, not the "testing" candidate sets)
        :param targets: list of output variables to return, defaults to returning all system outputs
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param var_shape: (Optional) `dict` of shapes for field quantity inputs in `x` -- you would only specify this
                          if passing field qtys directly to the models (i.e. not using `sample_inputs`)
        :returns: `dict` of output variables - the surrogate approximation of the system outputs (or the true model)
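
        !!! Example
            A minimal sketch (variable names and sizes are illustrative):

            ```python
            x = system.sample_inputs(100)                 # {'x1': (100,), ...}
            y_surr = system.predict(x)                    # surrogate: {'y1': (100,), ...}
            y_true = system.predict(x, use_model='best')  # full-order models instead of the surrogate
            ```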
997 """
998 # Format inputs and allocate space
999 var_shape = var_shape or {}
1000 x, loop_shape = format_inputs(x, var_shape=var_shape) # {'x': (N, *var_shape)}
1001 y = {}
1002 all_inputs = ChainMap(x, y) # track all inputs (including coupling vars in y)
1003 N = int(np.prod(loop_shape))
1004 t1 = 0
1005 output_dir = None
1006 norm_status = {var: normalized_inputs for var in x} # keep track of whether inputs are normalized or not
1007 graph = self.graph()
1009 # Keep track of what outputs are computed
1010 is_computed = {}
1011 for var in (targets or self.outputs()):
1012 if (v := self.outputs().get(var, None)) is not None:
1013 if v.compression is not None:
1014 for field in v.compression.fields:
1015 is_computed[field] = False
1016 else:
1017 is_computed[var] = False
1019 def _set_default(struct: dict, default=None):
1020 """Helper to set a default value for each component key in a `dict`. Ensures all components have a value."""
1021 if struct is not None:
1022 if not isinstance(struct, dict):
1023 struct = {node: struct for node in graph.nodes} # use same for each component
1024 else:
1025 struct = {node: default for node in graph.nodes}
1026 return {node: struct.get(node, default) for node in graph.nodes}
1028 # Ensure use_model, index_set, and incremental are specified for each component model
1029 use_model = _set_default(use_model, None)
1030 incremental = _set_default(incremental, False) # default to train if incremental anywhere
1031 index_set = _set_default(index_set, 'train' if any([incremental[node] for node in graph.nodes]) else 'test')
1032 misc_coeff = _set_default(misc_coeff, None)
1034 samples = _Converged(N) # track convergence of samples

        def _gather_comp_inputs(comp, coupling=None):
            """Helper to gather inputs for a component, making sure they are normalized correctly. Any coupling
            variables passed in will be used in preference over `all_inputs`.
            """
            # Will access but not modify: all_inputs, use_model, norm_status
            field_coords = {}
            comp_input = {}
            coupling = coupling or {}

            # Take coupling variables as a priority
            comp_input.update({var: np.copy(arr[samples.curr_idx, ...]) for var, arr in
                               coupling.items() if str(var).split(LATENT_STR_ID)[0] in comp.inputs})
            # Gather all other inputs
            for var, arr in all_inputs.items():
                var_id = str(var).split(LATENT_STR_ID)[0]
                if var_id in comp.inputs and var not in coupling:
                    comp_input[var] = np.copy(arr[samples.curr_idx, ...])

            # Gather field coordinates
            for var in comp.inputs:
                coords_str = f'{var}{COORDS_STR_ID}'
                if (coords := all_inputs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords[samples.curr_idx, ...]
                elif (coords := comp.model_kwargs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords

            # Gather extra fields (will never be in coupling since field couplings should always be latent coeff)
            for var in comp.inputs:
                if var not in comp_input and var.compression is not None:
                    for field in var.compression.fields:
                        if field in all_inputs:
                            comp_input[field] = np.copy(all_inputs[field][samples.curr_idx, ...])

            call_model = use_model.get(comp.name, None) is not None

            # Make sure we format all inputs for model evaluation (i.e. denormalize)
            if call_model:
                norm_inputs = {var: arr for var, arr in comp_input.items() if norm_status[var]}
                if len(norm_inputs) > 0:
                    denorm_inputs, fc = to_model_dataset(norm_inputs, comp.inputs, del_latent=True, **field_coords)
                    for var in norm_inputs:
                        del comp_input[var]
                    field_coords.update(fc)
                    comp_input.update(denorm_inputs)

            # Otherwise, make sure we format inputs for surrogate evaluation (i.e. normalize)
            else:
                denorm_inputs = {var: arr for var, arr in comp_input.items() if not norm_status[var]}
                if len(denorm_inputs) > 0:
                    norm_inputs, _ = to_surrogate_dataset(denorm_inputs, comp.inputs, del_fields=True, **field_coords)

                    for var in denorm_inputs:
                        del comp_input[var]
                    comp_input.update(norm_inputs)

            return comp_input, field_coords, call_model

        # Convert the system into a DAG by grouping strongly-connected components
        dag = nx.condensation(graph)

        # Compute component models in topological order
        for supernode in nx.topological_sort(dag):
            if np.all(list(is_computed.values())):
                break  # Exit early if all selected return qois are computed

            scc = [n for n in dag.nodes[supernode]['members']]
            samples.reset_convergence()

            # Compute single component feedforward output (no FPI needed)
            if len(scc) == 1:
                if verbose:
                    self.logger.info(f"Running component '{scc[0]}'...")
                    t1 = time.time()

                # Gather inputs
                comp = self[scc[0]]
                comp_input, field_coords, call_model = _gather_comp_inputs(comp)

                # Compute outputs
                if model_dir is not None:
                    output_dir = Path(model_dir) / scc[0]
                    if not output_dir.exists():
                        os.mkdir(output_dir)
                comp_output = comp.predict(comp_input, use_model=use_model.get(scc[0]), model_dir=output_dir,
                                           index_set=index_set.get(scc[0]), incremental=incremental.get(scc[0]),
                                           misc_coeff=misc_coeff.get(scc[0]), executor=executor, **field_coords)

                for var, arr in comp_output.items():
                    if var == 'errors':
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N,), None, dtype=object))
                        global_indices = np.arange(N)[samples.curr_idx]

                        for local_idx, err_info in arr.items():
                            global_idx = int(global_indices[local_idx])
                            err_info['index'] = global_idx
                            y[var][global_idx] = err_info
                        continue

                    is_numeric = np.issubdtype(arr.dtype, np.number)
                    if is_numeric:  # for scalars or vectorized field quantities
                        output_shape = arr.shape[1:]
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N, *output_shape), np.nan))
                        y[var][samples.curr_idx, ...] = arr

                    else:  # for fields returned as object arrays
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N,), None, dtype=object))
                        y[var][samples.curr_idx] = arr

                # Update valid indices and status for component outputs
                for var in comp_output:
                    if str(var).split(LATENT_STR_ID)[0] in comp.outputs:
                        is_numeric = np.issubdtype(y[var].dtype, np.number)
                        new_valid = ~np.any(np.isnan(y[var]), axis=tuple(range(1, y[var].ndim))) if is_numeric else (
                            [False if arr is None else ~np.any(np.isnan(arr)) for i, arr in enumerate(y[var])]
                        )
                        samples.valid_idx = np.logical_and(samples.valid_idx, new_valid)

                        is_computed[str(var).split(LATENT_STR_ID)[0]] = True
                        norm_status[var] = not call_model

                if verbose:
                    self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s")

            # Handle FPI for SCCs with more than one component
            else:
                # Set the initial guess for all coupling vars (middle of domain)
                scc_inputs = ChainMap(*[self[comp].inputs for comp in scc])
                scc_outputs = ChainMap(*[self[comp].outputs for comp in scc])
                coupling_vars = [scc_inputs.get(var) for var in (scc_inputs.keys() - x.keys()) if var in scc_outputs]
                coupling_prev = {}
                for var in coupling_vars:
                    if (domain := var.get_domain()) is None:
                        raise RuntimeError(f"Coupling variable '{var}' has an empty domain. All coupling variables "
                                           f"require a domain for the fixed-point iteration (FPI) solver.")

                    if isinstance(domain, list):  # Latent coefficients are the coupling variables
                        for i, d in enumerate(domain):
                            lb, ub = d
                            coupling_prev[f'{var.name}{LATENT_STR_ID}{i}'] = np.broadcast_to((lb + ub) / 2, (N,)).copy()
                            norm_status[f'{var.name}{LATENT_STR_ID}{i}'] = True
                    else:
                        lb, ub = var.normalize(domain)
                        shape = (N,) + (1,) * len(var_shape.get(var, ()))
                        coupling_prev[var] = np.broadcast_to((lb + ub) / 2, shape).copy()
                        norm_status[var] = True

                residual_hist = deque(maxlen=anderson_mem)
                coupling_hist = deque(maxlen=anderson_mem)

                def _end_conditions_met():
                    """Helper to compute the residual, update history, and check end conditions."""
                    residual = {}
                    converged_idx = np.full(N, True)
                    for var in coupling_prev:
                        residual[var] = y[var] - coupling_prev[var]
                        var_conv = np.all(np.abs(residual[var]) <= fpi_tol, axis=tuple(range(1, residual[var].ndim)))
                        converged_idx = np.logical_and(converged_idx, var_conv)
                        samples.valid_idx = np.logical_and(samples.valid_idx, ~np.isnan(coupling_prev[var]))
                    samples.converged_idx = np.logical_or(samples.converged_idx, converged_idx)

                    for var in coupling_prev:
                        coupling_prev[var][samples.curr_idx, ...] = y[var][samples.curr_idx, ...]
                    residual_hist.append(copy.deepcopy(residual))
                    coupling_hist.append(copy.deepcopy(coupling_prev))

                    if int(np.sum(samples.curr_idx)) == 0:
                        if verbose:
                            self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with tol '
                                             f'{fpi_tol}. Final time: {time.time() - t1} s')
                        return True

                    max_error = np.max([np.max(np.abs(res[samples.curr_idx, ...])) for res in residual.values()])
                    if verbose:
                        self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s')

                    if k >= max_fpi_iter:
                        self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: '
                                            f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.')
                        for var in coupling_prev:
                            y[var][~samples.converged_idx, ...] = np.nan
                        samples.valid_idx = np.logical_and(samples.valid_idx, samples.converged_idx)
                        return True
                    else:
                        return False
1224 # Main FPI loop
1225 if verbose:
1226 self.logger.info(f"Initializing FPI for SCC {scc} ...")
1227 t1 = time.time()
1228 k = 0
1229 while True:
1230 for node in scc:
1231 # Gather inputs from exogenous and coupling sources
1232 comp = self[node]
1233 comp_input, kwds, call_model = _gather_comp_inputs(comp, coupling=coupling_prev)
1235 # Compute outputs (just don't do this FPI with expensive real models, please..)
1236 comp_output = comp.predict(comp_input, use_model=use_model.get(node), model_dir=None,
1237 index_set=index_set.get(node), incremental=incremental.get(node),
1238 misc_coeff=misc_coeff.get(node), executor=executor, **kwds)
1240 for var, arr in comp_output.items():
1241 if var == 'errors':
1242 if y.get(var) is None:
1243 y.setdefault(var, np.full((N,), None, dtype=object))
1244 global_indices = np.arange(N)[samples.curr_idx]
1246 for local_idx, err_info in arr.items():
1247 global_idx = int(global_indices[local_idx])
1248 err_info['index'] = global_idx
1249 y[var][global_idx] = err_info
1250 continue
1252 if np.issubdtype(arr.dtype, np.number): # scalars and vectorized field quantities
1253 output_shape = arr.shape[1:]
1254 if y.get(var) is not None:
1255 if output_shape != y.get(var).shape[1:]:
1256 y[var] = _merge_shapes((N, *output_shape), y[var])
1257 else:
1258 y.setdefault(var, np.full((N, *output_shape), np.nan))
1259 y[var][samples.curr_idx, ...] = arr
1260 else: # fields returned as object arrays
1261 if y.get(var) is None:
1262 y.setdefault(var, np.full((N,), None, dtype=object))
1263 y[var][samples.curr_idx] = arr
1265 if str(var).split(LATENT_STR_ID)[0] in comp.outputs:
1266 norm_status[var] = not call_model
1267 is_computed[str(var).split(LATENT_STR_ID)[0]] = True
1269 # Compute residual and check end conditions
1270 if _end_conditions_met():
1271 break
1273 # Skip Anderson acceleration on the first iteration
1274 if k == 0:
1275 k += 1
1276 continue
1278 # Iterate with Anderson acceleration (only on samples that have not yet converged)
1279 N_curr = int(np.sum(samples.curr_idx))
1280 mk = len(residual_hist)  # Current history length (at most anderson_mem)
1281 var_shapes = []
1282 xdims = []
1283 for var in coupling_prev:
1284 shape = coupling_prev[var].shape[1:]
1285 var_shapes.append(shape)
1286 xdims.append(int(np.prod(shape)))
1287 N_couple = int(np.sum(xdims))
1288 res_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of residual history
1289 coupling_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of coupling history
1290 for i, (coupling_iter, residual_iter) in enumerate(zip(coupling_hist, residual_hist)):
1291 start_idx = 0
1292 for j, var in enumerate(coupling_prev):
1293 end_idx = start_idx + xdims[j]
1294 coupling_snap[:, start_idx:end_idx, i] = coupling_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501
1295 res_snap[:, start_idx:end_idx, i] = residual_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501
1296 start_idx = end_idx
1297 C = np.ones((N_curr, 1, mk))
1298 b = np.zeros((N_curr, N_couple, 1))
1299 d = np.ones((N_curr, 1, 1))
1300 alpha = np.expand_dims(constrained_lls(res_snap, b, C, d), axis=-3) # (..., 1, mk, 1)
1301 coupling_new = np.squeeze(coupling_snap[:, :, np.newaxis, :] @ alpha, axis=(-1, -2))
1302 start_idx = 0
1303 for j, var in enumerate(coupling_prev):
1304 end_idx = start_idx + xdims[j]
1305 coupling_prev[var][samples.curr_idx, ...] = coupling_new[:, start_idx:end_idx].reshape((N_curr, *var_shapes[j])) # noqa: E501
1306 start_idx = end_idx
1307 k += 1
1309 # Return all component outputs; samples that didn't converge during FPI are left as np.nan
1310 return format_outputs(y, loop_shape)
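# The Anderson update above solves, for each unconverged sample, a constrained least-squares
# problem: find weights alpha minimizing ||R @ alpha||_2 subject to sum(alpha) = 1, where the
# columns of R are the last `anderson_mem` residuals; the new iterate is the same weighted
# combination of the stored coupling iterates. A minimal single-sample sketch of that update
# (using a KKT system in place of the batched `constrained_lls` helper; all names here are
# hypothetical and for illustration only):
#
#     import numpy as np
#
#     def anderson_step(coupling_hist: np.ndarray, residual_hist: np.ndarray) -> np.ndarray:
#         """Columns of each (n, m) argument are the last m coupling iterates/residuals."""
#         m = residual_hist.shape[1]
#         R = residual_hist
#         # KKT system for: min ||R @ a||^2  subject to  1^T a = 1
#         A = np.block([[2 * R.T @ R, np.ones((m, 1))],
#                       [np.ones((1, m)), np.zeros((1, 1))]])
#         rhs = np.concatenate([np.zeros(m), [1.0]])
#         alpha = np.linalg.solve(A, rhs)[:m]
#         return coupling_hist @ alpha  # next coupling iterate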
1312 def __call__(self, *args, **kwargs):
1313 """Convenience wrapper to allow calling as `ret = System(x)`."""
1314 return self.predict(*args, **kwargs)
1316 def __eq__(self, other):
1317 if not isinstance(other, System):
1318 return False
1319 return (self.components == other.components and
1320 self.name == other.name and
1321 self.train_history == other.train_history)
1323 def __getitem__(self, component: str) -> Component:
1324 """Convenience method to get a `Component` object from the `System`.
1326 :param component: the name of the component to get
1327 :returns: the `Component` object
1328 """
1329 return self.get_component(component)
1331 def get_component(self, comp_name: str) -> Component:
1332 """Return the `Component` object for this component.
1334 :param comp_name: name of the component to return
1335 :raises KeyError: if the component does not exist
1336 :returns: the `Component` object
1337 """
1338 if comp_name.lower() == 'system':
1339 return self
1340 else:
1341 for comp in self.components:
1342 if comp.name == comp_name:
1343 return comp
1344 raise KeyError(f"Component '{comp_name}' not found in system.")
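# Hypothetical usage of the accessors above (component and variable names are illustrative):
#
#     >>> system = System(components=[aero_comp, prop_comp])
#     >>> outputs = system({'x': np.linspace(0, 1, 100)})   # alias for system.predict(...)
#     >>> aero = system['Aero']                             # alias for system.get_component('Aero')
#     >>> system['system'] is system                        # 'system' (any case) returns the System
#     True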
1346 def _print_title_str(self, title_str: str):
1347 """Log an important message."""
1348 self.logger.info('-' * int(len(title_str)/2) + title_str + '-' * int(len(title_str)/2))
1350 def _remove_unpickleable(self) -> dict:
1351 """Remove and return unpickleable attributes before pickling (just the logger)."""
1352 stdout = False
1353 log_file = None
1354 if self._logger is not None:
1355 for handler in self._logger.handlers:
1356 if isinstance(handler, logging.StreamHandler):
1357 stdout = True
1358 break
1359 for handler in self._logger.handlers:
1360 if isinstance(handler, logging.FileHandler):
1361 log_file = handler.baseFilename
1362 break
1364 buffer = {'log_stdout': stdout, 'log_file': log_file}
1365 self.logger = None
1366 return buffer
1368 def _restore_unpickleable(self, buffer: dict):
1369 """Restore the unpickleable attributes from `buffer` after unpickling."""
1370 self.set_logger(log_file=buffer.get('log_file', None), stdout=buffer.get('log_stdout', None))
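# These two helpers are intended to bracket pickling, since logging handlers cannot be
# pickled. A sketch of the round trip (hypothetical usage):
#
#     >>> buffer = system._remove_unpickleable()    # detach the logger, remember its config
#     >>> restored = pickle.loads(pickle.dumps(system))
#     >>> restored._restore_unpickleable(buffer)    # rebuild the logger from the saved config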
1372 def _get_test_set(self, test_set: str | Path | tuple = None) -> tuple:
1373 """Try to load a test set from the root directory if it exists."""
1374 if isinstance(test_set, tuple):
1375 return test_set # (xtest, ytest)
1376 else:
1377 ret = (None, None)
1378 if test_set is not None:
1379 test_set = Path(test_set)
1380 elif self.root_dir is not None:
1381 test_set = self.root_dir / 'test_set.pkl'
1383 if test_set is not None:
1384 if test_set.exists():
1385 with open(test_set, 'rb') as fd:
1386 data = pickle.load(fd)
1387 ret = data['test_set']
1389 return ret
1391 def _save_test_set(self, test_set: tuple = None):
1392 """Save the test set to the root directory if possible."""
1393 if self.root_dir is not None and test_set is not None:
1394 test_file = self.root_dir / 'test_set.pkl'
1395 if not test_file.exists():
1396 with open(test_file, 'wb') as fd:
1397 pickle.dump({'test_set': test_set}, fd)
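# The implied on-disk format is a pickle of {'test_set': (xtest, ytest)} at
# root_dir/'test_set.pkl'. A sketch of how a test set might be produced and cached
# (hypothetical usage; 'best' requests the highest-fidelity model, as in plot_slice):
#
#     >>> xtest = system.sample_inputs((500,), use_pdf=False)
#     >>> ytest = system.predict(xtest, use_model='best')
#     >>> system._save_test_set((xtest, ytest))     # no-op if the file already exists
#     >>> xtest2, ytest2 = system._get_test_set()   # reloads the same tuple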
1399 def save_to_file(self, filename: str, save_dir: str | Path = None, dumper=None):
1400 """Save surrogate to file. Defaults to `root/surrogates/filename.yml` with the default yaml encoder.
1402 :param filename: the name of the save file
1403 :param save_dir: the directory to save the file to (defaults to `root/surrogates` or `cwd()`)
1404 :param dumper: the encoder to use (defaults to the `amisc` yaml encoder)
1405 """
1406 from amisc import YamlLoader
1407 encoder = dumper or YamlLoader
1408 save_dir = save_dir or self.root_dir or Path.cwd()
1409 if Path(save_dir) == self.root_dir:
1410 save_dir = self.root_dir / 'surrogates'
1411 encoder.dump(self, Path(save_dir) / filename)
1413 @staticmethod
1414 def load_from_file(filename: str | Path, root_dir: str | Path = None, loader=None):
1415 """Load surrogate from file. Defaults to yaml loading. Tries to infer `amisc` directory structure.
1417 :param filename: the name of the file to load
1418 :param root_dir: set this as the surrogate's root directory (by default, will try to infer it from the `amisc_` directory format)
1419 :param loader: the decoder to use (defaults to the `amisc` yaml loader)
1420 """
1421 from amisc import YamlLoader
1422 encoder = loader or YamlLoader
1423 system = encoder.load(filename)
1424 root_dir = root_dir or system.root_dir
1426 # Try to infer amisc_root/surrogates/iter/filename structure
1427 if root_dir is None:
1428 parts = Path(filename).resolve().parts
1429 if len(parts) > 1 and parts[-2].startswith('amisc_'):
1430 root_dir = Path(filename).resolve().parent
1431 elif len(parts) > 2 and parts[-3].startswith('amisc_'):
1432 root_dir = Path(filename).resolve().parent.parent
1433 elif len(parts) > 3 and parts[-4].startswith('amisc_'):
1434 root_dir = Path(filename).resolve().parent.parent.parent
1436 system.root_dir = root_dir
1437 return system
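# A sketch of the save/load round trip (hypothetical paths; assuming a lossless
# serialization, the restored system should compare equal via __eq__):
#
#     >>> system.save_to_file('surrogate.yml')   # -> root_dir/'surrogates'/'surrogate.yml'
#     >>> restored = System.load_from_file(system.root_dir / 'surrogates' / 'surrogate.yml')
#     >>> restored == system
#     True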
1439 def clear(self):
1440 """Clear all surrogate model data and reset the system."""
1441 for comp in self.components:
1442 comp.clear()
1443 self.train_history.clear()
1445 def plot_slice(self, inputs: list[str] = None,
1446 outputs: list[str] = None,
1447 num_steps: int = 20,
1448 show_surr: bool = True,
1449 show_model: str | tuple | list = None,
1450 save_dir: str | Path = None,
1451 executor: Executor = None,
1452 nominal: dict[str, float] = None,
1453 random_walk: bool = False,
1454 from_file: str | Path = None,
1455 subplot_size_in: float = 3.):
1456 """Helper function to plot 1d slices of the surrogate and/or model outputs over the inputs. A single
1457 "slice" works by smoothly stepping from the lower bound of an input to its upper bound, while holding all other
1458 inputs constant at their nominal values (or smoothly varying them if `random_walk=True`).
1459 This function is useful for visualizing the behavior of the system surrogate and/or model(s) over a
1460 single input variable at a time.
1462 :param inputs: list of input variables to take 1d slices of (defaults to first 3 in `System.inputs`)
1463 :param outputs: list of model output variables to plot 1d slices of (defaults to first 3 in `System.outputs`)
1464 :param num_steps: the number of points to take in the 1d slice for each input variable; this amounts to a total
1465 of `num_steps*len(inputs)` model/surrogate evaluations
1466 :param show_surr: whether to show the surrogate prediction
1467 :param show_model: also compute and plot model predictions, `list` of ['best', 'worst', tuple(alpha), etc.]
1468 :param save_dir: base directory to save model outputs and plots (if specified)
1469 :param executor: a `concurrent.futures.Executor` object to parallelize model or surrogate evaluations
1470 :param nominal: `dict` of `var->nominal` to use as constant values for all non-sliced variables (use
1471 unnormalized values only; use `var_LATENT0` to specify nominal latent values)
1472 :param random_walk: whether to slice in a random d-dimensional direction instead of holding all non-slice
1473 variables constant at `nominal`
1474 :param from_file: path to a `.pkl` file to load a saved slice from disk
1475 :param subplot_size_in: side length size of each square subplot in inches
1476 :returns: `fig, ax` with `len(inputs)` by `len(outputs)` subplots
1477 """
1478 # Manage loading important quantities from file (if provided)
1479 input_slices, output_slices_model, output_slices_surr = None, None, None
1480 if from_file is not None:
1481 with open(Path(from_file), 'rb') as fd:
1482 slice_data = pickle.load(fd)
1483 inputs = slice_data['inputs'] # Must use same input slices as save file
1484 show_model = slice_data['show_model'] # Must use same model data as save file
1485 outputs = slice_data.get('outputs') if outputs is None else outputs
1486 input_slices = slice_data['input_slices']
1487 save_dir = None # Don't run or save any models if loading from file
1489 # Set default values (take up to the first 3 inputs and outputs)
1490 all_inputs = self.inputs()
1491 all_outputs = self.outputs()
1492 rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
1493 if save_dir is not None:
1494 os.mkdir(Path(save_dir) / f'slice_{rand_id}')
1495 if nominal is None:
1496 nominal = dict()
1497 inputs = all_inputs[:3] if inputs is None else inputs
1498 outputs = all_outputs[:3] if outputs is None else outputs
1500 if show_model is not None and not isinstance(show_model, list):
1501 show_model = [show_model]
1503 # Handle field quantities (use the requested latent variable directly, or default to the first latent dimension)
1504 for i, var in enumerate(list(inputs)):
1505 if LATENT_STR_ID not in str(var) and all_inputs[var].compression is not None:
1506 inputs[i] = f'{var}{LATENT_STR_ID}0'
1507 for i, var in enumerate(list(outputs)):
1508 if LATENT_STR_ID not in str(var) and all_outputs[var].compression is not None:
1509 outputs[i] = f'{var}{LATENT_STR_ID}0'
1511 bds = all_inputs.get_domains()
1512 xlabels = [all_inputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else
1513 all_inputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) +
1514 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in inputs]
1516 ylabels = [all_outputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else
1517 all_outputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) +
1518 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in outputs]
1520 # Construct slices of model inputs (if not provided)
1521 if input_slices is None:
1522 input_slices = {} # Each input variable with shape (num_steps, num_slice)
1523 for i in range(len(inputs)):
1524 if random_walk:
1525 # Make a random straight-line walk across the d-dimensional input hypercube
1526 r0 = self.sample_inputs((1,), use_pdf=False)
1527 rf = self.sample_inputs((1,), use_pdf=False)
1529 for var, bd in bds.items():
1530 if var == inputs[i]:
1531 r0[var] = np.atleast_1d(bd[0]) # Start slice at this lower bound
1532 rf[var] = np.atleast_1d(bd[1]) # Slice up to this upper bound
1534 step_size = (rf[var] - r0[var]) / (num_steps - 1)
1535 arr = r0[var] + step_size * np.arange(num_steps)
1537 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else (
1538 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1))
1539 else:
1540 # Otherwise, only slice one variable
1541 for var, bd in bds.items():
1542 nom = nominal.get(var, np.mean(bd)) if LATENT_STR_ID in str(var) else (
1543 all_inputs[var].normalize(nominal.get(var, all_inputs[var].get_nominal())))
1544 arr = np.linspace(bd[0], bd[1], num_steps) if var == inputs[i] else np.full(num_steps, nom)
1546 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else (
1547 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1))
1549 # Walk through each model that is requested by show_model
1550 if show_model is not None:
1551 if from_file is not None:
1552 output_slices_model = slice_data['output_slices_model']
1553 else:
1554 output_slices_model = list()
1555 for model in show_model:
1556 output_dir = None
1557 if save_dir is not None:
1558 output_dir = (Path(save_dir) / f'slice_{rand_id}' /
1559 str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", ''))
1560 os.mkdir(output_dir)
1561 output_slices_model.append(self.predict(input_slices, use_model=model, model_dir=output_dir,
1562 executor=executor))
1563 if show_surr:
1564 output_slices_surr = self.predict(input_slices, executor=executor) \
1565 if from_file is None else slice_data['output_slices_surr']
1567 # Make len(outputs) by len(inputs) grid of subplots
1568 fig, axs = plt.subplots(len(outputs), len(inputs), sharex='col', sharey='row', squeeze=False)
1569 for i, output_var in enumerate(outputs):
1570 for j, input_var in enumerate(inputs):
1571 ax = axs[i, j]
1572 x = input_slices[input_var][:, j]
1574 if show_model is not None:
1575 c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else (
1576 plt.get_cmap('jet')(np.linspace(0, 1, len(show_model))))
1577 for k in range(len(show_model)):
1578 model_str = (str(show_model[k]).replace('{', '').replace('}', '')
1579 .replace(':', '=').replace("'", ''))
1580 model_ret = to_surrogate_dataset(output_slices_model[k], all_outputs)[0]
1581 y_model = model_ret[output_var][:, j]
1582 label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model',
1583 'worst': 'Low-fidelity'}.get(model_str, model_str)
1584 ax.plot(x, y_model, ls='-', c=c[k, :], label=label)
1586 if show_surr:
1587 y_surr = output_slices_surr[output_var][:, j]
1588 ax.plot(x, y_surr, '--r', label='Surrogate')
1590 ax.set_xlabel(xlabels[j] if i == len(outputs) - 1 else '')
1591 ax.set_ylabel(ylabels[i] if j == 0 else '')
1592 if i == 0 and j == len(inputs) - 1:
1593 ax.legend()
1594 fig.set_size_inches(subplot_size_in * len(inputs), subplot_size_in * len(outputs))
1595 fig.tight_layout()
1597 # Save results (unless we were already loading from a save file)
1598 if from_file is None and save_dir is not None:
1599 fname = f'in={",".join([str(v) for v in inputs])}_out={",".join([str(v) for v in outputs])}'
1600 fname = f'slice_rand{rand_id}_' + fname if random_walk else f'slice_nom{rand_id}_' + fname
1601 fdir = Path(save_dir) / f'slice_{rand_id}'
1602 fig.savefig(fdir / f'{fname}.pdf', bbox_inches='tight', format='pdf')
1603 save_dict = {'inputs': inputs, 'outputs': outputs, 'show_model': show_model, 'show_surr': show_surr,
1604 'nominal': nominal, 'random_walk': random_walk, 'input_slices': input_slices,
1605 'output_slices_model': output_slices_model, 'output_slices_surr': output_slices_surr}
1606 with open(fdir / f'{fname}.pkl', 'wb') as fd:
1607 pickle.dump(save_dict, fd)
1609 return fig, axs
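# Hypothetical usage of plot_slice (variable names are illustrative): this evaluates
# num_steps points per input slice for the surrogate and for each entry of show_model,
# and, when save_dir is given, writes a .pdf/.pkl pair under save_dir/slice_<id> that
# can be reloaded later via from_file:
#
#     >>> fig, axs = system.plot_slice(inputs=['mach', 'T'], outputs=['thrust'],
#     ...                              num_steps=30, show_model=['best', 'worst'],
#     ...                              save_dir=system.root_dir)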
1611 def get_allocation(self):
1612 """Get a breakdown of cost allocation during training.
1614 :returns: `cost_alloc, model_cost, overhead_cost, model_evals` - the cost allocation per model/fidelity,
1615 the model evaluation cost per iteration (in seconds of CPU time), the algorithmic overhead cost per
1616 iteration, and the total number of model evaluations at each training iteration
1617 """
1618 cost_alloc = dict() # Cost allocation (cpu time in s) per node and model fidelity
1619 model_cost = [] # Cost of model evaluations (CPU time in s) per iteration
1620 overhead_cost = [] # Algorithm overhead costs (CPU time in s) per iteration
1621 model_evals = [] # Number of model evaluations at each training iteration
1623 prev_cands = {comp.name: IndexSet() for comp in self.components} # empty candidate sets
1625 # Add cumulative training costs
1626 for train_res, active_sets, cand_sets, misc_coeff_train, misc_coeff_test in self.simulate_fit():
1627 comp = train_res['component']
1628 alpha = train_res['alpha']
1629 beta = train_res['beta']
1630 overhead = train_res['overhead_s']
1632 cost_alloc.setdefault(comp, dict())
1634 new_cands = cand_sets[comp].union({(alpha, beta)}) - prev_cands[comp] # newly computed candidates
1636 iter_cost = 0.
1637 iter_eval = 0
1638 for alpha_new, beta_new in new_cands:
1639 cost_alloc[comp].setdefault(alpha_new, 0.)
1641 added_eval = self[comp].get_cost(alpha_new, beta_new)
1642 single_cost = self[comp].model_costs.get(alpha_new, 1.)
1644 iter_cost += added_eval * single_cost
1645 iter_eval += added_eval
1647 cost_alloc[comp][alpha_new] += added_eval * single_cost
1649 overhead_cost.append(overhead)
1650 model_cost.append(iter_cost)
1651 model_evals.append(iter_eval)
1652 prev_cands[comp] = cand_sets[comp].union({(alpha, beta)})
1654 return cost_alloc, np.atleast_1d(model_cost), np.atleast_1d(overhead_cost), np.atleast_1d(model_evals)
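# A sketch of summarizing the returned allocation (hypothetical usage):
#
#     >>> cost_alloc, model_cost, overhead_cost, model_evals = system.get_allocation()
#     >>> total_cpu_s = model_cost.sum() + overhead_cost.sum()
#     >>> for comp_name, alpha_dict in cost_alloc.items():
#     ...     for alpha, cpu_s in alpha_dict.items():
#     ...         print(f'{comp_name} @ fidelity {alpha}: {cpu_s / total_cpu_s:.1%} of total cost')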
1656 def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02):
1657 """Plot bar charts showing cost allocation during training.
1659 !!! Warning "Beta feature"
1660 The default settings work reasonably well, but they may not suit your use case. This is provided mainly as
1661 a template for making cost-allocation bar charts; feel free to copy and adapt it in your own code.
1663 :param cmap: the colormap string identifier for `plt`
1664 :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text
1665 :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow;
1666 below this amount, the bar is too narrow and no text is printed
1667 :returns: `fig, ax`, Figure and Axes objects
1668 """
1669 # Get total cost
1670 cost_alloc, model_cost, _, _ = self.get_allocation()
1671 total_cost = np.cumsum(model_cost)[-1]
1673 # Remove nodes with cost=0 from alloc dicts (i.e. analytical models)
1674 remove_nodes = []
1675 for node, alpha_dict in cost_alloc.items():
1676 if len(alpha_dict) == 0:
1677 remove_nodes.append(node)
1678 for node in remove_nodes:
1679 del cost_alloc[node]
1681 # Bar chart showing cost allocation breakdown for MF system at final iteration
1682 fig, ax = plt.subplots(figsize=(6, 5), layout='tight')
1683 width = 0.7
1684 x = np.arange(len(cost_alloc))
1685 xlabels = list(cost_alloc.keys()) # One bar for each component
1686 cmap = plt.get_cmap(cmap)
1688 for j, (node, alpha_dict) in enumerate(cost_alloc.items()):
1689 bottom = 0
1690 c_intervals = np.linspace(0, 1, len(alpha_dict))
1691 bars = [(alpha, cost, cost / total_cost) for alpha, cost in alpha_dict.items()]
1692 bars = sorted(bars, key=lambda ele: ele[2], reverse=True)
1693 for i, (alpha, cost, frac) in enumerate(bars):
1694 p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1,
1695 edgecolor=[0, 0, 0], bottom=bottom)
1696 bottom += frac
1697 num_evals = round(cost / self[node].model_costs.get(alpha, 1.))
1698 if frac > text_bar_width:
1699 ax.bar_label(p, labels=[f'{alpha}, {num_evals}'], label_type='center')
1700 elif frac > arrow_bar_width:
1701 xy = (x[j] + width / 2, bottom - frac / 2) # Label smaller bars with text off to the side
1702 ax.annotate(f'{alpha}, {num_evals}', xy, xytext=(xy[0] + 0.2, xy[1]),
1703 arrowprops={'arrowstyle': '->', 'linewidth': 1})
1704 else:
1705 pass # Don't label really small bars
1706 ax.set_xlabel('')
1707 ax.set_ylabel('Fraction of total cost')
1708 ax.set_xticks(x, xlabels)
1709 ax.set_xlim(left=-1, right=x[-1] + 1)
1711 if self.root_dir is not None:
1712 fig.savefig(Path(self.root_dir) / 'mf_allocation.pdf', bbox_inches='tight', format='pdf')
1714 return fig, ax
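# Hypothetical usage (any matplotlib colormap name may be passed for `cmap`):
#
#     >>> fig, ax = system.plot_allocation(cmap='Greens', text_bar_width=0.1)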
1716 def serialize(self, keep_components=False, serialize_args=None, serialize_kwargs=None) -> dict:
1717 """Convert to a `dict` with only standard Python types for fields.
1719 :param keep_components: whether to serialize the components as well (defaults to False)
1720 :param serialize_args: `dict` of arguments to pass to each component's serialize method
1721 :param serialize_kwargs: `dict` of keyword arguments to pass to each component's serialize method
1722 :returns: a `dict` representation of the `System` object
1723 """
1724 serialize_args = serialize_args or dict()
1725 serialize_kwargs = serialize_kwargs or dict()
1726 d = {}
1727 for key, value in self.__dict__.items():
1728 if value is not None and not key.startswith('_'):
1729 if key == 'components' and not keep_components:
1730 d[key] = [comp.serialize(keep_yaml_objects=False,
1731 serialize_args=serialize_args.get(comp.name),
1732 serialize_kwargs=serialize_kwargs.get(comp.name))
1733 for comp in value]
1734 elif key == 'train_history':
1735 if len(value) > 0:
1736 d[key] = value.serialize()
1737 else:
1738 if not isinstance(value, _builtin):
1739 self.logger.warning(f"Attribute '{key}' of type '{type(value)}' may not be a builtin "
1740 f"Python type. This may cause issues when saving/loading from file.")
1741 d[key] = value
1743 for key, value in self.model_extra.items():
1744 if isinstance(value, _builtin):
1745 d[key] = value
1747 return d
1749 @classmethod
1750 def deserialize(cls, serialized_data: dict) -> System:
1751 """Construct a `System` object from serialized data."""
1752 return cls(**serialized_data)
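# Since deserialize is just cls(**serialized_data), serialize/deserialize form a round
# trip through builtin Python types (a sketch; equality assumes lossless component
# serialization):
#
#     >>> d = system.serialize()
#     >>> restored = System.deserialize(d)
#     >>> restored == system
#     True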
1754 @staticmethod
1755 def _yaml_representer(dumper: yaml.Dumper, system: System):
1756 """Serialize a `System` object to a YAML representation."""
1757 return dumper.represent_mapping(System.yaml_tag, system.serialize(keep_components=True))
1759 @staticmethod
1760 def _yaml_constructor(loader: yaml.Loader, node):
1761 """Convert the `!SystemSurrogate` tag in yaml to a [`System`][amisc.system.System] object."""
1762 if isinstance(node, yaml.SequenceNode):
1763 return [ele if isinstance(ele, System) else System.deserialize(ele)
1764 for ele in loader.construct_sequence(node, deep=True)]
1765 elif isinstance(node, yaml.MappingNode):
1766 return System.deserialize(loader.construct_mapping(node, deep=True))
1767 else:
1768 raise NotImplementedError(f'The "{System.yaml_tag}" yaml tag can only be used on a yaml sequence or '
1769 f'mapping, not a "{type(node)}".')
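# A minimal sketch of wiring these hooks into plain PyYAML (the `amisc` YamlLoader
# presumably does something equivalent internally; `System.yaml_tag` is assumed to be
# the '!SystemSurrogate' tag named in the docstring above):
#
#     >>> yaml.add_representer(System, System._yaml_representer)
#     >>> yaml.add_constructor(System.yaml_tag, System._yaml_constructor)
#     >>> text = yaml.dump(system)                   # emits a '!SystemSurrogate' mapping
#     >>> restored = yaml.load(text, yaml.Loader)    # reconstructs the System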