Coverage for src/amisc/system.py: 87%
1006 statements
1"""The `System` object is a framework for multidisciplinary models. It manages multiple single discipline component
2models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates.
4Features:
6- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections
7- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver with anderson acceleration
8- Top-level interface for training and using surrogates of each component model
9- Adaptive experimental design for choosing training data efficiently
10- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates
11- Detailed logging and traceback information
12- Supports parallel or vectorized execution of component models
13- Abstract and flexible interfacing with component models
14- Easy serialization and deserialization to/from YAML files
15- Supports approximating field quantities via compression
17Includes:
19- `TrainHistory` — a history of training iterations for the system surrogate
20- `System` — the top-level object for managing multidisciplinary models
21"""

# ruff: noqa: E702
from __future__ import annotations

import copy
import datetime
import functools
import logging
import os
import pickle
import random
import string
import time
import warnings
from collections import ChainMap, UserList, deque
from concurrent.futures import ALL_COMPLETED, Executor, wait
from datetime import timezone
from pathlib import Path
from typing import Annotated, Callable, ClassVar, Literal, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import yaml
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from amisc.component import Component, IndexSet, MiscTree
from amisc.serialize import Serializable, _builtin
from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex, TrainIteration
from amisc.utils import (
    _combine_latent_arrays,
    constrained_lls,
    format_inputs,
    format_outputs,
    get_logger,
    relative_error,
    to_model_dataset,
    to_surrogate_dataset,
)
from amisc.variable import VariableList

__all__ = ['TrainHistory', 'System']


class TrainHistory(UserList, Serializable):
    """Stores the training history of a system surrogate as a list of `TrainIteration` objects."""

    def __init__(self, data: list = None):
        data = data or []
        super().__init__(self._validate_data(data))

    def serialize(self) -> list[dict]:
        """Return a list of each result in the history serialized to a `dict`."""
        ret_list = []
        for res in self:
            new_res = res.copy()
            new_res['alpha'] = str(res['alpha'])
            new_res['beta'] = str(res['beta'])
            ret_list.append(new_res)
        return ret_list

    @classmethod
    def deserialize(cls, serialized_data: list[dict]) -> TrainHistory:
        """Deserialize a list of `dict` objects into a `TrainHistory` object."""
        return TrainHistory(serialized_data)

    @classmethod
    def _validate_data(cls, data: list[dict]) -> list[TrainIteration]:
        return [cls._validate_item(item) for item in data]

    @classmethod
    def _validate_item(cls, item: dict):
        """Format a `TrainIteration` `dict` item before appending to the history."""
        item.setdefault('test_error', None)
        item.setdefault('overhead_s', 0.0)
        item.setdefault('model_s', 0.0)
        item['alpha'] = MultiIndex(item['alpha'])
        item['beta'] = MultiIndex(item['beta'])
        item['num_evals'] = int(item['num_evals'])
        item['added_cost'] = float(item['added_cost'])
        item['added_error'] = float(item['added_error'])
        item['overhead_s'] = float(item['overhead_s'])
        item['model_s'] = float(item['model_s'])
        return item

    def append(self, item: dict):
        super().append(self._validate_item(item))

    def __add__(self, other):
        other_list = other.data if isinstance(other, TrainHistory) else other
        return TrainHistory(data=self.data + other_list)

    def extend(self, items):
        super().extend([self._validate_item(item) for item in items])

    def insert(self, index, item):
        super().insert(index, self._validate_item(item))

    def __setitem__(self, key, value):
        super().__setitem__(key, self._validate_item(value))

    def __eq__(self, other):
        """Two `TrainHistory` objects are equal if they have the same length and all items are equal, excluding NaNs."""
        if not isinstance(other, TrainHistory):
            return False
        if len(self) != len(other):
            return False
        for item_self, item_other in zip(self, other):
            for key in item_self:
                if key in item_other:
                    val_self = item_self[key]
                    val_other = item_other[key]
                    if isinstance(val_self, float) and isinstance(val_other, float):
                        if not (np.isnan(val_self) and np.isnan(val_other)) and val_self != val_other:
                            return False
                    elif isinstance(val_self, dict) and isinstance(val_other, dict):
                        for v, err in val_self.items():
                            if v in val_other:
                                err_other = val_other[v]
                                if not (np.isnan(err) and np.isnan(err_other)) and err != err_other:
                                    return False
                            else:
                                return False
                    elif val_self != val_other:
                        return False
                else:
                    return False
        return True
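
# Illustration of the `TrainHistory` container (hypothetical iteration values; the
# required keys mirror `TrainHistory._validate_item` above):
#
#     history = TrainHistory()
#     history.append({'component': 'f1', 'alpha': (), 'beta': (0,), 'num_evals': 10,
#                     'added_cost': 1.0, 'added_error': 0.1})
#     serialized = history.serialize()                         # 'alpha'/'beta' converted to strings
#     restored = TrainHistory.deserialize(serialized)          # round-trips through plain dicts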


class _Converged:
    """Helper class to track which samples have converged during `System.predict()`."""

    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.valid_idx = np.full(num_samples, True)       # All samples are valid by default
        self.converged_idx = np.full(num_samples, False)  # For FPI convergence

    def reset_convergence(self):
        self.converged_idx = np.full(self.num_samples, False)

    @property
    def curr_idx(self):
        return np.logical_and(self.valid_idx, ~self.converged_idx)


def _merge_shapes(target_shape, arr):
    """Helper to merge an array into the target shape."""
    shape1, shape2 = target_shape, arr.shape
    if len(shape2) > len(shape1):
        shape1, shape2 = shape2, shape1
    result = []
    for i in range(len(shape1)):
        if i < len(shape2):
            if shape1[i] == 1:
                result.append(shape2[i])
            elif shape2[i] == 1:
                result.append(shape1[i])
            else:
                result.append(shape1[i])
        else:
            result.append(1)
    arr = arr.reshape(tuple(result))
    return np.broadcast_to(arr, target_shape).copy()
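
# Example of the shape-merging helper (illustrative values): an array with fewer
# dimensions than the target is padded with trailing singleton axes and then
# broadcast out to the full target shape.
#
#     _merge_shapes((4, 3), np.ones(4)).shape           # -> (4, 3)
#     _merge_shapes((4, 3, 2), np.ones((4, 3))).shape   # -> (4, 3, 2)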


class System(BaseModel, Serializable):
    """
    Multidisciplinary (MD) surrogate framework top-level class. Construct a `System` from a list of
    `Component` models.

    !!! Example
        ```python
        def f1(x):
            y = x ** 2
            return y

        def f2(y):
            z = y + 1
            return z

        system = System(f1, f2)
        ```

    A `System` object can be saved/loaded to/from `.yml` files using the `!System` yaml tag.

    :ivar name: the name of the system
    :ivar components: list of `Component` models that make up the MD system
    :ivar train_history: history of training iterations for the system surrogate (filled in during training)

    :ivar _root_dir: root directory where all surrogate build products are saved to file
    :ivar _logger: logger object for the system
    """
    yaml_tag: ClassVar[str] = u'!System'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True,
                              extra='allow')

    name: Annotated[str, Field(default_factory=lambda: "System_" + "".join(random.choices(string.digits, k=3)))]
    components: Callable | Component | list[Callable | Component]
    train_history: list[dict] | TrainHistory = TrainHistory()
    amisc_version: str = None

    _root_dir: Optional[str]
    _logger: Optional[logging.Logger] = None
    def __init__(self, /, *args, components=None, root_dir=None, **kwargs):
        """Construct a `System` object from a list of `Component` models in `*args` or `components`. If
        a `root_dir` is provided, then a new directory will be created under `root_dir` with the name
        `amisc_{timestamp}`. This directory will be used to save all build products and log files.

        :param components: list of `Component` models that make up the MD system
        :param root_dir: root directory where all surrogate build products are saved to file (optional)
        """
        if components is None:
            components = []
            for a in args:
                if isinstance(a, Component) or callable(a):
                    components.append(a)
                else:
                    try:
                        components.extend(a)
                    except TypeError as e:
                        raise ValueError(f"Invalid component: {a}") from e
        import amisc
        amisc_version = kwargs.pop('amisc_version', amisc.__version__)
        super().__init__(components=components, amisc_version=amisc_version, **kwargs)
        self.root_dir = root_dir
    def __repr__(self):
        s = f'---- {self.name} ----\n'
        s += f'amisc version: {self.amisc_version}\n'
        s += f'Refinement level: {self.refine_level}\n'
        s += f'Components: {", ".join([comp.name for comp in self.components])}\n'
        s += f'Inputs: {", ".join([var.name for var in self.inputs()])}\n'
        s += f'Outputs: {", ".join([var.name for var in self.outputs()])}'
        return s

    def __str__(self):
        return self.__repr__()
    @field_validator('components')
    @classmethod
    def _validate_components(cls, comps) -> list[Component]:
        if not isinstance(comps, list):
            comps = [comps]
        comps = [Component.deserialize(c) for c in comps]

        # Merge all variables to avoid name conflicts
        merged_vars = VariableList.merge(*[comp.inputs for comp in comps], *[comp.outputs for comp in comps])
        for comp in comps:
            comp.inputs.update({var.name: var for var in merged_vars.values() if var in comp.inputs})
            comp.outputs.update({var.name: var for var in merged_vars.values() if var in comp.outputs})

        return comps

    @field_validator('train_history')
    @classmethod
    def _validate_train_history(cls, history) -> TrainHistory:
        if isinstance(history, TrainHistory):
            return history
        else:
            return TrainHistory.deserialize(history)
    def graph(self) -> nx.DiGraph:
        """Build a directed graph of the system components based on their input-output relationships."""
        graph = nx.DiGraph()
        model_deps = {}
        for comp in self.components:
            graph.add_node(comp.name)
            for output in comp.outputs:
                model_deps[output] = comp.name
        for comp in self.components:
            for in_var in comp.inputs:
                if in_var in model_deps:
                    graph.add_edge(model_deps[in_var], comp.name)

        return graph
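
    # For instance (hypothetical two-component system): if a component 'f2' consumes the
    # output `y` of a component 'f1', then the graph contains the edge ('f1', 'f2'):
    #
    #     system = System(f1, f2)
    #     list(system.graph().edges)   # -> [('f1', 'f2')]; a feedback loop would add ('f2', 'f1')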
    def _save_on_error(func):
        """Gracefully exit and save the `System` object on any errors."""
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except:
                if self.root_dir is not None:
                    self.save_to_file(f'{self.name}_error.yml')
                    self.logger.critical(f'An error occurred during execution of "{func.__name__}". Saving '
                                         f'System object to {self.name}_error.yml', exc_info=True)
                    self.logger.info(f'Final system surrogate on exit: \n {self}')
                raise
        return wrap
    _save_on_error = staticmethod(_save_on_error)
    def insert_components(self, components: list | Callable | Component):
        """Insert new components into the system."""
        components = components if isinstance(components, list) else [components]
        self.components = self.components + components

    def swap_component(self, old_component: str | Component, new_component: Callable | Component):
        """Replace an old component with a new component."""
        old_name = old_component if isinstance(old_component, str) else old_component.name
        comps = [comp if comp.name != old_name else new_component for comp in self.components]
        self.components = comps

    def remove_component(self, component: str | Component):
        """Remove a component from the system."""
        comp_name = component if isinstance(component, str) else component.name
        self.components = [comp for comp in self.components if comp.name != comp_name]
    def inputs(self) -> VariableList:
        """Collect all inputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object, excluding variables that are also outputs of
        any component.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all inputs from the components.
        """
        all_inputs = ChainMap(*[comp.inputs for comp in self.components])
        return VariableList({k: all_inputs[k] for k in all_inputs.keys() - self.outputs().keys()})

    def outputs(self) -> VariableList:
        """Collect all outputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all outputs from the components.
        """
        return VariableList({k: v for k, v in ChainMap(*[comp.outputs for comp in self.components]).items()})

    def coupling_variables(self) -> VariableList:
        """Collect all coupling variables from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all coupling variables from the components.
        """
        all_outputs = self.outputs()
        return VariableList({k: all_outputs[k] for k in (all_outputs.keys() &
                                                         ChainMap(*[comp.inputs for comp in self.components]).keys())})

    def variables(self):
        """Iterator over all variables in the system (inputs and outputs)."""
        yield from ChainMap(self.inputs(), self.outputs()).values()
    @property
    def refine_level(self) -> int:
        """The total number of training iterations."""
        return len(self.train_history)

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: logging.Logger):
        self._logger = logger
        for comp in self.components:
            comp.logger = logger

    @staticmethod
    def timestamp() -> str:
        """Return a UTC timestamp string in the isoformat `YYYY-MM-DDTHH.MM.SS`."""
        return datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
    @property
    def root_dir(self):
        """Return the root directory of the surrogate (if available), otherwise `None`."""
        return Path(self._root_dir) if self._root_dir is not None else None

    @root_dir.setter
    def root_dir(self, root_dir: str | Path):
        """Set the root directory for all build products. If `root_dir` is `None`, then no products will be saved.
        Otherwise, log files, model outputs, surrogate files, etc. will be saved under this directory.

        !!! Note "`amisc` root directory"
            If `root_dir` is not `None`, then a new directory will be created under `root_dir` with the name
            `amisc_{timestamp}`. This directory will be used to save all build products. If `root_dir` matches the
            `amisc_*` format, then it will be used directly.

        :param root_dir: the root directory for all build products
        """
        if root_dir is not None:
            parts = Path(root_dir).resolve().parts
            if parts[-1].startswith('amisc_'):
                self._root_dir = Path(root_dir).resolve().as_posix()
                if not self.root_dir.is_dir():
                    os.mkdir(self.root_dir)
            else:
                root_dir = Path(root_dir) / ('amisc_' + self.timestamp())
                os.mkdir(root_dir)
                self._root_dir = Path(root_dir).resolve().as_posix()

            log_file = None
            if not (pth := self.root_dir / 'surrogates').is_dir():
                os.mkdir(pth)
            if not (pth := self.root_dir / 'components').is_dir():
                os.mkdir(pth)
            for comp in self.components:
                if comp.model_kwarg_requested('output_path'):
                    if not (comp_pth := pth / comp.name).is_dir():
                        os.mkdir(comp_pth)
            for f in os.listdir(self.root_dir):
                if f.endswith('.log'):
                    log_file = (self.root_dir / f).resolve().as_posix()
                    break
            if log_file is None:
                log_file = (self.root_dir / f'amisc_{self.timestamp()}.log').resolve().as_posix()
            self.set_logger(log_file=log_file)

        else:
            self._root_dir = None
            self.set_logger(log_file=None)
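
    # Resulting on-disk layout when a root directory is set (timestamp varies):
    #
    #     root_dir/
    #       amisc_{timestamp}/
    #         amisc_{timestamp}.log     <- system log file
    #         surrogates/               <- .yml surrogate saves (one subdirectory per saved iteration)
    #         components/
    #           {comp.name}/            <- model output files (only if the model requests 'output_path')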
    def set_logger(self, log_file: str | Path | bool = None, stdout: bool = None, logger: logging.Logger = None,
                   level: int = logging.INFO):
        """Set a new `logging.Logger` object.

        :param log_file: log to this file if str or Path (defaults to whatever is currently set or empty);
                         set `False` to remove file logging or set `True` to create a default log file in the root dir
        :param stdout: whether to connect the logger to console (defaults to whatever is currently set or `False`)
        :param logger: the logging object to use (this will override the `log_file` and `stdout` arguments if set);
                       if `None`, then a new logger is created according to `log_file` and `stdout`
        :param level: the logging level to set the logger to (defaults to `logging.INFO`)
        """
        # Decide whether to use stdout
        if stdout is None:
            stdout = False
            if self._logger is not None:
                for handler in self._logger.handlers:
                    if isinstance(handler, logging.StreamHandler):
                        stdout = True
                        break

        # Decide what log_file to use (if any)
        if log_file is True:
            log_file = pth / f'amisc_{self.timestamp()}.log' if (pth := self.root_dir) is not None else (
                f'amisc_{self.timestamp()}.log')
        elif log_file is None:
            if self._logger is not None:
                for handler in self._logger.handlers:
                    if isinstance(handler, logging.FileHandler):
                        log_file = handler.baseFilename
                        break
        elif log_file is False:
            log_file = None

        self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level)

        for comp in self.components:
            comp.set_logger(log_file=log_file, stdout=stdout, logger=logger, level=level)
    def sample_inputs(self, size: tuple | int,
                      component: str = 'System',
                      normalize: bool = True,
                      use_pdf: bool | str | list[str] = False,
                      include: str | list[str] = None,
                      exclude: str | list[str] = None,
                      nominal: dict[str, float] = None) -> Dataset:
        """Return samples of the inputs according to provided options. Will return samples in the
        normalized/compressed space of the surrogate by default. See [`to_model_dataset`][amisc.utils.to_model_dataset]
        to convert the samples to be usable by the true model directly.

        :param size: tuple or integer specifying shape or number of samples to obtain
        :param component: which component to sample inputs for (defaults to full system exogenous inputs)
        :param normalize: whether to normalize the samples (defaults to True)
        :param use_pdf: whether to sample from variable pdfs (defaults to False, which will instead sample from the
                        variable domain bounds). If a string or list of strings is provided, then only those variables
                        or variable categories will be sampled using their pdfs.
        :param include: a list of variables or variable categories to include in the sampling. Defaults to using all
                        input variables.
        :param exclude: a list of variables or variable categories to exclude from the sampling. Empty by default.
        :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty. Specify nominal
                        values as unnormalized (will be normalized if `normalize=True`)
        :returns: `dict` of `(*size,)` samples for each selected input variable
        """
        size = (size, ) if isinstance(size, int) else size
        nominal = nominal or dict()
        inputs = self.inputs() if component == 'System' else self[component].inputs
        if include is None:
            include = []
        if not isinstance(include, list):
            include = [include]
        if exclude is None:
            exclude = []
        if not isinstance(exclude, list):
            exclude = [exclude]
        if isinstance(use_pdf, str):
            use_pdf = [use_pdf]

        selected_inputs = []
        for var in inputs:
            if len(include) == 0 or var.name in include or var.category in include:
                if var.name not in exclude and var.category not in exclude:
                    selected_inputs.append(var)

        samples = {}
        for var in selected_inputs:
            # Sample from latent variable domains for field quantities
            if var.compression is not None:
                latent = var.sample_domain(size)
                for i in range(latent.shape[-1]):
                    samples[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i]

            # Sample scalars normally
            else:
                lb, ub = var.get_domain()
                pdf = (var.name in use_pdf or var.category in use_pdf) if isinstance(use_pdf, list) else use_pdf
                nom = nominal.get(var.name, None)

                x_sample = var.sample(size, nominal=nom) if pdf else var.sample_domain(size)
                good_idx = (x_sample < ub) & (x_sample > lb)
                num_reject = np.sum(~good_idx)

                while num_reject > 0:
                    new_sample = var.sample((num_reject,), nominal=nom) if pdf else var.sample_domain((num_reject,))
                    x_sample[~good_idx] = new_sample
                    good_idx = (x_sample < ub) & (x_sample > lb)
                    num_reject = np.sum(~good_idx)

                samples[var.name] = var.normalize(x_sample) if normalize else x_sample

        return samples
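
    # Example usage (hypothetical variable name 'x1' and category 'calibration'):
    #
    #     samples = system.sample_inputs(100)                    # all inputs, domain-bounded, normalized
    #     samples = system.sample_inputs((50, 2), use_pdf='x1')  # sample 'x1' from its pdf instead
    #     samples = system.sample_inputs(100, exclude=['calibration'], normalize=False)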
    def simulate_fit(self):
        """Loop back through training history and simulate each iteration. Will yield the internal data structures
        of each `Component` surrogate after each iteration of training (without needing to call `fit()` or any
        of the underlying models). This might be useful, for example, for computing the surrogate predictions on
        a new test set or viewing cumulative training costs.

        !!! Example
            Say you have a new test set: `(new_xtest, new_ytest)`, and you want to compute the accuracy of the
            surrogate fit at each iteration of the training history:

            ```python
            for train_iter, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test in system.simulate_fit():
                # Do something with the surrogate data structures
                new_ysurr = system.predict(new_xtest, index_set=active_sets, misc_coeff=misc_coeff_train)
                train_error = relative_error(new_ysurr, new_ytest)
            ```

        :return: a generator of the active index sets, candidate index sets, and MISC coefficients
                 of each component model at each iteration of the training history
        """
        # "Simulated" data structures for each component
        active_sets = {comp.name: IndexSet() for comp in self.components}        # active index sets
        candidate_sets = {comp.name: IndexSet() for comp in self.components}     # candidate sets
        misc_coeff_train = {comp.name: MiscTree() for comp in self.components}   # MISC coeff for active sets
        misc_coeff_test = {comp.name: MiscTree() for comp in self.components}    # MISC coeff for active + candidate sets

        for train_result in self.train_history:
            # The selected refinement component and indices
            comp_star = train_result['component']
            alpha_star = train_result['alpha']
            beta_star = train_result['beta']
            comp = self[comp_star]

            # Get forward neighbors for the selected index
            neighbors = comp._neighbors(alpha_star, beta_star, active_set=active_sets[comp_star], forward=True)

            # "Activate" the index in the simulated data structure
            s = set()
            s.add((alpha_star, beta_star))
            comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star],
                                   misc_coeff=misc_coeff_train[comp_star])

            if (alpha_star, beta_star) in candidate_sets[comp_star]:
                candidate_sets[comp_star].remove((alpha_star, beta_star))
            else:
                # Only for initial index which didn't come from the candidate set
                comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                       misc_coeff=misc_coeff_test[comp_star])
            active_sets[comp_star].update(s)

            comp.update_misc_coeff(neighbors, index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                   misc_coeff=misc_coeff_test[comp_star])  # neighbors will only ever pass here once
            candidate_sets[comp_star].update(neighbors)

            # Caller can now do whatever they want as if the system surrogate were at this training iteration
            # See the "index_set" and "misc_coeff" overrides for `System.predict()` for example
            yield train_result, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test
    def add_output(self):
        """Add an output variable retroactively to a component surrogate. The user should provide a callable that
        takes a save path and extracts the model output data for a given training point/location.
        """
        # TODO:
        # - Loop back through the surrogate training history
        # - Simulate activate_index and extract the model output from file rather than calling the model
        # - Update all interpolator states
        raise NotImplementedError
    @_save_on_error
    def fit(self, targets: list = None,
            num_refine: int = 100,
            max_iter: int = 20,
            max_tol: float = 1e-3,
            runtime_hr: float = 1.,
            estimate_bounds: bool = False,
            update_bounds: bool = True,
            test_set: tuple | str | Path = None,
            start_test_check: int = None,
            save_interval: int = 0,
            plot_interval: int = 1,
            cache_interval: int = 0,
            executor: Executor = None,
            weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'):
        """Train the system surrogate adaptively by iterative refinement until an end condition is met.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param max_iter: the maximum number of refinement steps to take
        :param max_tol: the max allowable value in relative L2 error to achieve
        :param runtime_hr: the threshold wall clock time (hr) at which to stop further refinement (will go
                           until all models finish the current iteration)
        :param estimate_bounds: whether to estimate bounds for the coupling variables; will only try to estimate from
                                the `test_set` if provided (defaults to `False`). Otherwise, you should manually
                                provide domains for all coupling variables.
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param test_set: `tuple` of `(xtest, ytest)` to show convergence of surrogate to the true model. The test set
                         inputs and outputs are specified as `dicts` of `np.ndarrays` with keys corresponding to the
                         variable names. Can also pass a path to a `.pkl` file that has the test set data as
                         `{'test_set': (xtest, ytest)}`.
        :param start_test_check: the iteration to start checking the test set error (defaults to the number
                                 of components); surrogate evaluation isn't useful during initialization, so you
                                 should at least allow one iteration per component before checking test set error
        :param save_interval: number of refinement steps between each progress save, none if 0; `System.root_dir`
                              must be specified to save to file
        :param plot_interval: how often to plot the error indicator and test set error (defaults to every iteration);
                              will only plot and save to file if a root directory is set
        :param cache_interval: how often to cache component data in order to speed up future training iterations (at
                               the cost of additional memory usage); defaults to 0 (no caching)
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations (optional, but
                         recommended for expensive models)
        :param weight_fcns: a `dict` of weight functions to apply to each input variable for training data selection;
                            defaults to using the pdf of each variable. If None, then no weighting is applied.
        """
        start_test_check = start_test_check or sum([1 for _ in self.components if _.has_surrogate])
        targets = targets or self.outputs()
        xtest, ytest = self._get_test_set(test_set)
        max_iter = self.refine_level + max_iter

        # Estimate bounds from test set if provided (override current bounds if they are set)
        if estimate_bounds:
            if ytest is not None:
                y_samples = to_surrogate_dataset(ytest, self.outputs(), del_fields=True)[0]  # normalize/compress
                _combine_latent_arrays(y_samples)
                coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_samples}
                y_min, y_max = {}, {}
                for var in coupling_vars.values():
                    y_min[var] = np.nanmin(y_samples[var], axis=0)
                    y_max[var] = np.nanmax(y_samples[var], axis=0)
                    if var.compression is not None:
                        new_domain = list(zip(y_min[var].tolist(), y_max[var].tolist()))
                        var.update_domain(new_domain, override=True)
                    else:
                        new_domain = (float(y_min[var]), float(y_max[var]))
                        var.update_domain(var.denormalize(new_domain), override=True)
                del y_samples
            else:
                self.logger.warning('Could not estimate bounds for coupling variables: no test set provided. '
                                    'Make sure you manually provide (good) coupling variable domains.')

        # Track convergence progress on the error indicator and test set (plot to file)
        if self.root_dir is not None:
            err_record = [res['added_error'] for res in self.train_history]
            err_fig, err_ax = plt.subplots(figsize=(6, 5), layout='tight')

            if xtest is not None and ytest is not None:
                num_plot = min(len(targets), 3)
                test_record = np.full((self.refine_level, num_plot), np.nan)
                t_fig, t_ax = plt.subplots(1, num_plot, figsize=(3.5 * num_plot, 4), layout='tight', squeeze=False,
                                           sharey='row')
                for j, res in enumerate(self.train_history):
                    for i, var in enumerate(targets[:num_plot]):
                        if (perf := res.get('test_error')) is not None:
                            test_record[j, i] = perf[var]

        total_overhead = 0.0
        total_model_wall_time = 0.0
        t_start = time.time()
        while True:
            # Adaptive refinement step
            t_iter_start = time.time()
            train_result = self.refine(targets=targets, num_refine=num_refine, update_bounds=update_bounds,
                                       executor=executor, weight_fcns=weight_fcns)
            if train_result['component'] is None:
                self._print_title_str('Termination criteria reached: No candidates left to refine')
                break

            # Keep track of algorithmic overhead (before and after call_model for this iteration)
            m_start, m_end = self[train_result['component']].get_model_timestamps()  # Start and end of call_model
            if m_start is not None and m_end is not None:
                train_result['overhead_s'] = (m_start - t_iter_start) + (time.time() - m_end)
                train_result['model_s'] = m_end - m_start
            else:
                train_result['overhead_s'] = time.time() - t_iter_start
                train_result['model_s'] = 0.0
            total_overhead += train_result['overhead_s']
            total_model_wall_time += train_result['model_s']

            curr_error = train_result['added_error']

            # Plot progress of error indicator
            if self.root_dir is not None:
                err_record.append(curr_error)

                if plot_interval > 0 and self.refine_level % plot_interval == 0:
                    err_ax.clear(); err_ax.set_yscale('log'); err_ax.grid()
                    err_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    err_ax.plot(err_record, '-k')
                    err_ax.set_xlabel('Iteration'); err_ax.set_ylabel('Relative error indicator')
                    err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.pdf'), format='pdf', bbox_inches='tight')

            # Save performance on a test set
            if xtest is not None and ytest is not None:
                # Don't compute if components are uninitialized
                perf = self.test_set_performance(xtest, ytest) if self.refine_level + 1 >= start_test_check else (
                    {str(var): np.nan for var in ytest if COORDS_STR_ID not in var})
                train_result['test_error'] = perf.copy()

                if self.root_dir is not None:
                    test_record = np.vstack((test_record, np.array([perf[var] for var in targets[:num_plot]])))

                    if plot_interval > 0 and self.refine_level % plot_interval == 0:
                        for i in range(num_plot):
                            with warnings.catch_warnings():
                                warnings.simplefilter("ignore", UserWarning)
                                t_ax[0, i].clear(); t_ax[0, i].set_yscale('log'); t_ax[0, i].grid()
                                t_ax[0, i].xaxis.set_major_locator(MaxNLocator(integer=True))
                                t_ax[0, i].plot(test_record[:, i], '-k')
                                t_ax[0, i].set_title(self.outputs()[targets[i]].get_tex(units=True))
                                t_ax[0, i].set_xlabel('Iteration')
                                t_ax[0, i].set_ylabel('Test set relative error' if i == 0 else '')
                        t_fig.savefig(str(Path(self.root_dir) / 'test_set_error.pdf'), format='pdf', bbox_inches='tight')

            self.train_history.append(train_result)

            if self.root_dir is not None and save_interval > 0 and self.refine_level % save_interval == 0:
                iter_name = f'{self.name}_iter{self.refine_level}'
                if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                    os.mkdir(pth)
                self.save_to_file(f'{iter_name}.yml', save_dir=pth)  # Save to an iteration-specific directory

            if cache_interval > 0 and self.refine_level % cache_interval == 0:
                for comp in self.components:
                    comp.cache()

            # Check all end conditions
            if self.refine_level >= max_iter:
                self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}')
                break
            if curr_error < max_tol:
                self._print_title_str(f'Termination criteria reached: relative error {curr_error} < tol {max_tol}')
                break
            if ((time.time() - t_start) / 3600.0) >= runtime_hr:
                t_end = time.time()
                actual = datetime.timedelta(seconds=t_end - t_start)
                target = datetime.timedelta(seconds=runtime_hr * 3600)
                train_surplus = ((t_end - t_start) - runtime_hr * 3600) / 3600
                self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}')
                self.logger.info(f'Surplus wall time: {train_surplus:.3f}/{runtime_hr:.3f} hours '
                                 f'(+{100 * train_surplus / runtime_hr:.2f}%)')
                break

        self.logger.info(f'Model evaluation algorithm efficiency: '
                         f'{100 * total_model_wall_time / (total_model_wall_time + total_overhead):.2f}%')

        if self.root_dir is not None:
            iter_name = f'{self.name}_iter{self.refine_level}'
            if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                os.mkdir(pth)
            self.save_to_file(f'{iter_name}.yml', save_dir=pth)

        if xtest is not None and ytest is not None:
            self._save_test_set((xtest, ytest))

        self.logger.info(f'Final system surrogate: \n {self}')
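
    # Example training setup (illustrative settings; `xtest`/`ytest` are dicts of
    # np.ndarrays keyed on variable names, as described above; `ProcessPoolExecutor`
    # is from `concurrent.futures`):
    #
    #     with ProcessPoolExecutor(max_workers=8) as executor:
    #         system.fit(max_iter=50, runtime_hr=2.0, test_set=(xtest, ytest),
    #                    estimate_bounds=True, executor=executor)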
    def test_set_performance(self, xtest: Dataset, ytest: Dataset, index_set='test') -> Dataset:
        """Compute the relative L2 error on a test set for the given target output variables.

        :param xtest: `dict` of test set input samples (unnormalized)
        :param ytest: `dict` of test set output samples (unnormalized)
        :param index_set: index set to use for prediction (defaults to `'test'`)
        :returns: `dict` of relative L2 errors for each target output variable
        """
        targets = [var for var in ytest.keys() if COORDS_STR_ID not in var and var in self.outputs()]
        coords = {var: ytest[var] for var in ytest if COORDS_STR_ID in var}
        xtest = to_surrogate_dataset(xtest, self.inputs(), del_fields=True)[0]
        ysurr = self.predict(xtest, index_set=index_set, targets=targets)
        ysurr = to_model_dataset(ysurr, self.outputs(), del_latent=True, **coords)[0]
        perf = {}
        for var in targets:
            # Handle relative error for object arrays (field qtys)
            ytest_obj = np.issubdtype(ytest[var].dtype, np.object_)
            ysurr_obj = np.issubdtype(ysurr[var].dtype, np.object_)
            if ytest_obj or ysurr_obj:
                _iterable = np.ndindex(ytest[var].shape) if ytest_obj else np.ndindex(ysurr[var].shape)
                num, den = [], []
                for index in _iterable:
                    pred, targ = ysurr[var][index], ytest[var][index]
                    num.append((pred - targ) ** 2)
                    den.append(targ ** 2)
                perf[var] = float(np.sqrt(sum([np.sum(n) for n in num]) / sum([np.sum(d) for d in den])))
            else:
                perf[var] = float(relative_error(ysurr[var], ytest[var]))

        return perf
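
    # The scalar metric computed above is the relative L2 error for each output:
    #
    #     perf[var] = sqrt( sum((ysurr - ytest)**2) / sum(ytest**2) )
    #
    # i.e. ||ysurr - ytest||_2 / ||ytest||_2, with the sum taken over all test samples
    # (and over all field indices for object-array field quantities).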
    def refine(self, targets: list = None, num_refine: int = 100, update_bounds: bool = True, executor: Executor = None,
               weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf') -> TrainIteration:
        """Perform a single adaptive refinement step on the system surrogate.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param weight_fcns: weight functions for choosing new training data for each input variable; defaults to
                            the PDFs of each variable. If None, then no weighting is applied.
        :returns: `dict` of the refinement results indicating the chosen component and candidate index
        """
        self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}')
        targets = targets or self.outputs()

        # Check for uninitialized components and refine those first
        for comp in self.components:
            if len(comp.active_set) == 0 and comp.has_surrogate:
                alpha_star = (0,) * len(comp.model_fidelity)
                beta_star = (0,) * len(comp.max_beta)
                self.logger.info(f"Initializing component {comp.name}: adding {(alpha_star, beta_star)} to active set")
                model_dir = (pth / 'components' / comp.name) if (pth := self.root_dir) is not None else None
                comp.activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
                                    weight_fcns=weight_fcns)
                num_evals = comp.get_cost(alpha_star, beta_star)
                cost_star = comp.model_costs.get(alpha_star, 1.) * num_evals  # CPU time (s)
                err_star = np.nan
                return {'component': comp.name, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
                        'added_cost': float(cost_star), 'added_error': float(err_star)}

        # Evaluate the full integrated surrogate on a random test set for global system QoI error estimation
        x_samples = self.sample_inputs(num_refine)
        y_curr = self.predict(x_samples, index_set='train', targets=targets)
        _combine_latent_arrays(y_curr)
        coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_curr}

        y_min, y_max = None, None
        if update_bounds:
            y_min = {var: np.nanmin(y_curr[var], axis=0, keepdims=True) for var in coupling_vars}  # (1, ydim)
            y_max = {var: np.nanmax(y_curr[var], axis=0, keepdims=True) for var in coupling_vars}  # (1, ydim)

        # Find the candidate surrogate with the largest error indicator
        error_max, error_indicator = -np.inf, -np.inf
        comp_star, alpha_star, beta_star, err_star, cost_star = None, None, None, -np.inf, 0
        for comp in self.components:
            if not comp.has_surrogate:  # Skip analytic models that don't need a surrogate
                continue

            self.logger.info(f"Estimating error for component '{comp.name}'...")

            if len(comp.candidate_set) > 0:
                candidates = list(comp.candidate_set)
                if executor is None:
                    ret = [self.predict(x_samples, targets=targets, index_set={comp.name: {(alpha, beta)}},
                                        incremental={comp.name: True})
                           for alpha, beta in candidates]
                else:
                    temp_buffer = self._remove_unpickleable()
                    futures = [executor.submit(self.predict, x_samples, targets=targets,
                                               index_set={comp.name: {(alpha, beta)}}, incremental={comp.name: True})
                               for alpha, beta in candidates]
                    wait(futures, timeout=None, return_when=ALL_COMPLETED)
                    ret = [f.result() for f in futures]
                    self._restore_unpickleable(temp_buffer)

                for i, y_cand in enumerate(ret):
                    alpha, beta = candidates[i]
                    _combine_latent_arrays(y_cand)
                    error = {}
                    for var, arr in y_cand.items():
                        if var in targets:
                            error[var] = relative_error(arr, y_curr[var], skip_nan=True)

                        if update_bounds and var in coupling_vars:
                            y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True)
                            y_max[var] = np.nanmax(np.concatenate((y_max[var], arr), axis=0), axis=0, keepdims=True)

                    delta_error = np.nanmax([np.nanmax(error[var]) for var in error])  # Max error over all target QoIs
                    num_evals = comp.get_cost(alpha, beta)
                    delta_work = comp.model_costs.get(alpha, 1.) * num_evals  # CPU time (s)
                    error_indicator = delta_error / delta_work

                    self.logger.info(f"Candidate multi-index: {(alpha, beta)}. Relative error: {delta_error}. "
                                     f"Error indicator: {error_indicator}.")

                    if error_indicator > error_max:
                        error_max = error_indicator
                        comp_star, alpha_star, beta_star, err_star, cost_star = (
                            comp.name, alpha, beta, delta_error, delta_work)
            else:
                self.logger.info(f"Component '{comp.name}' has no available candidates left!")

        # Update all coupling variable ranges
        if update_bounds:
            for var in coupling_vars.values():
                if np.all(~np.isnan(y_min[var])) and np.all(~np.isnan(y_max[var])):
                    if var.compression is not None:
                        new_domain = list(zip(np.squeeze(y_min[var], axis=0).tolist(),
                                              np.squeeze(y_max[var], axis=0).tolist()))
                        var.update_domain(new_domain)
                    else:
                        new_domain = (y_min[var][0], y_max[var][0])
                        var.update_domain(var.denormalize(new_domain))  # bds will be in norm space from predict() call

        # Add the chosen multi-index to the chosen component
        if comp_star is not None:
            self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{comp_star}'.")
            model_dir = (pth / 'components' / comp_star) if (pth := self.root_dir) is not None else None
            self[comp_star].activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
                                           weight_fcns=weight_fcns)
            num_evals = self[comp_star].get_cost(alpha_star, beta_star)
        else:
            self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}")
            num_evals = 0

        # Return the results of the refinement step
        return {'component': comp_star, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
                'added_cost': float(cost_star), 'added_error': float(err_star)}
    def predict(self, x: dict | Dataset,
                max_fpi_iter: int = 100,
                anderson_mem: int = 10,
                fpi_tol: float = 1e-10,
                use_model: str | tuple | dict = None,
                model_dir: str | Path = None,
                verbose: bool = False,
                index_set: dict[str, IndexSet | Literal['train', 'test']] = 'test',
                misc_coeff: dict[str, MiscTree] = None,
                normalized_inputs: bool = True,
                incremental: dict[str, bool] = False,
                targets: list[str] = None,
                executor: Executor = None,
                var_shape: dict[str, tuple] = None) -> Dataset:
        """Evaluate the system surrogate at inputs `x`. Return `y = system(x)`.

        !!! Warning "Computing the true model with feedback loops"
            You can use this function to predict outputs for your MD system using the full-order models rather than
            the surrogate, by specifying `use_model`. This is convenient because the `System` manages all the
            coupled information flow between models automatically. However, it is *highly* recommended not to use
            the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using
            anything more computationally demanding than the surrogate.

        :param x: `dict` of input samples for each variable in the system
        :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine
        :param anderson_mem: hyperparameter for tuning the convergence of FPI with Anderson acceleration
        :param fpi_tol: tolerance limit for convergence of fixed-point iteration
        :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate;
                          specify a `dict` of the above to assign different model fidelities for different components
        :param model_dir: directory to save model outputs if `use_model` is specified
        :param verbose: whether to print out iteration progress during execution
        :param index_set: `dict(comp=[indices])` to override the active set for a component, defaults to using the
                          `test` set for every component. Can also specify `train` for any component or a valid
                          `IndexSet` object. If `incremental` is specified, will be overwritten with `train`.
        :param misc_coeff: `dict(comp=MiscTree)` to override the default coefficients for a component, passes through
                           along with `index_set` and `incremental` to `comp.predict()`.
        :param normalized_inputs: true if the passed inputs are compressed/normalized for surrogate evaluation
                                  (default), such as inputs returned by `sample_inputs`. Set to `False` if you are
                                  passing inputs as the true models would expect them instead (i.e. not normalized).
        :param incremental: whether to add `index_set` to the current active set for each component (temporarily);
                            this will set `index_set='train'` for all other components (since incremental will
                            augment the "training" active sets, not the "testing" candidate sets)
        :param targets: list of output variables to return, defaults to returning all system outputs
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param var_shape: (Optional) `dict` of shapes for field quantity inputs in `x` -- you would only specify this
                          if passing field qtys directly to the models (i.e. not using `sample_inputs`)
        :returns: `dict` of output variables -- the surrogate approximation of the system outputs (or the true model)
        """
        # Format inputs and allocate space
        var_shape = var_shape or {}
        x, loop_shape = format_inputs(x, var_shape=var_shape)  # {'x': (N, *var_shape)}
        y = {}
        all_inputs = ChainMap(x, y)  # track all inputs (including coupling vars in y)
        N = int(np.prod(loop_shape))
        t1 = 0
        output_dir = None
        norm_status = {var: normalized_inputs for var in x}  # keep track of whether inputs are normalized or not
        graph = self.graph()

        # Keep track of what outputs are computed
        is_computed = {}
        for var in (targets or self.outputs()):
            if (v := self.outputs().get(var, None)) is not None:
                if v.compression is not None:
                    for field in v.compression.fields:
                        is_computed[field] = False
                else:
                    is_computed[var] = False

        def _set_default(struct: dict, default=None):
            """Helper to set a default value for each component key in a `dict`. Ensures all components have a value."""
            if struct is not None:
                if not isinstance(struct, dict):
                    struct = {node: struct for node in graph.nodes}  # use same for each component
            else:
                struct = {node: default for node in graph.nodes}
            return {node: struct.get(node, default) for node in graph.nodes}

        # Ensure use_model, index_set, and incremental are specified for each component model
        use_model = _set_default(use_model, None)
        incremental = _set_default(incremental, False)  # default to train if incremental anywhere
        index_set = _set_default(index_set, 'train' if any([incremental[node] for node in graph.nodes]) else 'test')
        misc_coeff = _set_default(misc_coeff, None)

        samples = _Converged(N)  # track convergence of samples
        def _gather_comp_inputs(comp, coupling=None):
            """Helper to gather inputs for a component, making sure they are normalized correctly. Any coupling
            variables passed in will be used in preference over `all_inputs`.
            """
            # Will access but not modify: all_inputs, use_model, norm_status
            field_coords = {}
            comp_input = {}
            coupling = coupling or {}

            # Take coupling variables as a priority
            comp_input.update({var: np.copy(arr[samples.curr_idx, ...]) for var, arr in
                               coupling.items() if str(var).split(LATENT_STR_ID)[0] in comp.inputs})

            # Gather all other inputs
            for var, arr in all_inputs.items():
                var_id = str(var).split(LATENT_STR_ID)[0]
                if var_id in comp.inputs and var not in coupling:
                    comp_input[var] = np.copy(arr[samples.curr_idx, ...])

            # Gather field coordinates
            for var in comp.inputs:
                coords_str = f'{var}{COORDS_STR_ID}'
                if (coords := all_inputs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords[samples.curr_idx, ...]
                elif (coords := comp.model_kwargs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords

            # Gather extra fields (will never be in coupling since field couplings should always be latent coeff)
            for var in comp.inputs:
                if var not in comp_input and var.compression is not None:
                    for field in var.compression.fields:
                        if field in all_inputs:
                            comp_input[field] = np.copy(all_inputs[field][samples.curr_idx, ...])

            call_model = use_model.get(comp.name, None) is not None

            # Make sure we format all inputs for model evaluation (i.e. denormalize)
            if call_model:
                norm_inputs = {var: arr for var, arr in comp_input.items() if norm_status[var]}
                if len(norm_inputs) > 0:
                    denorm_inputs, fc = to_model_dataset(norm_inputs, comp.inputs, del_latent=True, **field_coords)
                    for var in norm_inputs:
                        del comp_input[var]
                    field_coords.update(fc)
                    comp_input.update(denorm_inputs)

            # Otherwise, make sure we format inputs for surrogate evaluation (i.e. normalize)
            else:
                denorm_inputs = {var: arr for var, arr in comp_input.items() if not norm_status[var]}
                if len(denorm_inputs) > 0:
                    norm_inputs, _ = to_surrogate_dataset(denorm_inputs, comp.inputs, del_fields=True, **field_coords)
                    for var in denorm_inputs:
                        del comp_input[var]
                    comp_input.update(norm_inputs)

            return comp_input, field_coords, call_model
        # Convert the system into a DAG by grouping strongly connected components
        dag = nx.condensation(graph)

        # Compute component models in topological order
        for supernode in nx.topological_sort(dag):
            if np.all(list(is_computed.values())):
                break  # Exit early if all selected return QoIs are computed

            scc = [n for n in dag.nodes[supernode]['members']]
            samples.reset_convergence()

            # Compute single component feedforward output (no FPI needed)
            if len(scc) == 1:
                if verbose:
                    self.logger.info(f"Running component '{scc[0]}'...")
                    t1 = time.time()

                # Gather inputs
                comp = self[scc[0]]
                comp_input, field_coords, call_model = _gather_comp_inputs(comp)

                # Compute outputs
                if model_dir is not None:
                    output_dir = Path(model_dir) / scc[0]
                    if not output_dir.exists():
                        os.mkdir(output_dir)
                comp_output = comp.predict(comp_input, use_model=use_model.get(scc[0]), model_dir=output_dir,
                                           index_set=index_set.get(scc[0]), incremental=incremental.get(scc[0]),
                                           misc_coeff=misc_coeff.get(scc[0]), executor=executor, **field_coords)

                for var, arr in comp_output.items():
                    is_numeric = np.issubdtype(arr.dtype, np.number)
                    if is_numeric:  # for scalars or vectorized field quantities
                        output_shape = arr.shape[1:]
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N, *output_shape), np.nan))
                        y[var][samples.curr_idx, ...] = arr

                    else:  # for fields returned as object arrays
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N,), None, dtype=object))
                        y[var][samples.curr_idx] = arr

                    # Update valid indices and status for component outputs
                    if str(var).split(LATENT_STR_ID)[0] in comp.outputs:
                        new_valid = ~np.any(np.isnan(y[var]), axis=tuple(range(1, y[var].ndim))) if is_numeric else (
                            [~np.any(np.isnan(y[var][i])) for i in range(N)]
                        )
                        samples.valid_idx = np.logical_and(samples.valid_idx, new_valid)

                        is_computed[str(var).split(LATENT_STR_ID)[0]] = True
                        norm_status[var] = not call_model

                if verbose:
                    self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s")

            # Handle FPI for SCCs with more than one component
            else:
                # Set the initial guess for all coupling vars (middle of domain)
                scc_inputs = ChainMap(*[self[comp].inputs for comp in scc])
                scc_outputs = ChainMap(*[self[comp].outputs for comp in scc])
                coupling_vars = [scc_inputs.get(var) for var in (scc_inputs.keys() - x.keys()) if var in scc_outputs]
                coupling_prev = {}
                for var in coupling_vars:
                    domain = var.get_domain()
                    if isinstance(domain, list):  # Latent coefficients are the coupling variables
                        for i, d in enumerate(domain):
                            lb, ub = d
                            coupling_prev[f'{var.name}{LATENT_STR_ID}{i}'] = np.broadcast_to((lb + ub) / 2, (N,)).copy()
                            norm_status[f'{var.name}{LATENT_STR_ID}{i}'] = True
                    else:
                        lb, ub = var.normalize(domain)
                        shape = (N,) + (1,) * len(var_shape.get(var, ()))
                        coupling_prev[var] = np.broadcast_to((lb + ub) / 2, shape).copy()
                        norm_status[var] = True

                residual_hist = deque(maxlen=anderson_mem)
                coupling_hist = deque(maxlen=anderson_mem)
                def _end_conditions_met():
                    """Helper to compute residual, update history, and check end conditions."""
                    residual = {}
                    converged_idx = np.full(N, True)
                    for var in coupling_prev:
                        residual[var] = y[var] - coupling_prev[var]
                        var_conv = np.all(np.abs(residual[var]) <= fpi_tol, axis=tuple(range(1, residual[var].ndim)))
                        converged_idx = np.logical_and(converged_idx, var_conv)
                        samples.valid_idx = np.logical_and(samples.valid_idx, ~np.isnan(coupling_prev[var]))
                    samples.converged_idx = np.logical_or(samples.converged_idx, converged_idx)

                    for var in coupling_prev:
                        coupling_prev[var][samples.curr_idx, ...] = y[var][samples.curr_idx, ...]
                    residual_hist.append(copy.deepcopy(residual))
                    coupling_hist.append(copy.deepcopy(coupling_prev))

                    if int(np.sum(samples.curr_idx)) == 0:
                        if verbose:
                            self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with tol '
                                             f'{fpi_tol}. Final time: {time.time() - t1} s')
                        return True

                    max_error = np.max([np.max(np.abs(res[samples.curr_idx, ...])) for res in residual.values()])
                    if verbose:
                        self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s')

                    if k >= max_fpi_iter:
                        self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: '
                                            f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.')
                        for var in coupling_prev:
                            y[var][~samples.converged_idx, ...] = np.nan
                        samples.valid_idx = np.logical_and(samples.valid_idx, samples.converged_idx)
                        return True
                    else:
                        return False
                # Main FPI loop
                if verbose:
                    self.logger.info(f"Initializing FPI for SCC {scc} ...")
                    t1 = time.time()
                k = 0
                while True:
                    for node in scc:
                        # Gather inputs from exogenous and coupling sources
                        comp = self[node]
                        comp_input, kwds, call_model = _gather_comp_inputs(comp, coupling=coupling_prev)

                        # Compute outputs (just don't do this FPI with expensive real models, please..)
                        comp_output = comp.predict(comp_input, use_model=use_model.get(node), model_dir=None,
                                                   index_set=index_set.get(node), incremental=incremental.get(node),
                                                   misc_coeff=misc_coeff.get(node), executor=executor, **kwds)
                        for var, arr in comp_output.items():
                            if np.issubdtype(arr.dtype, np.number):  # scalars and vectorized field quantities
                                output_shape = arr.shape[1:]
                                if y.get(var) is not None:
                                    if output_shape != y.get(var).shape[1:]:
                                        y[var] = _merge_shapes((N, *output_shape), y[var])
                                else:
                                    y.setdefault(var, np.full((N, *output_shape), np.nan))
                                y[var][samples.curr_idx, ...] = arr
                            else:  # fields returned as object arrays
                                if y.get(var) is None:
                                    y.setdefault(var, np.full((N,), None, dtype=object))
                                y[var][samples.curr_idx] = arr

                            if str(var).split(LATENT_STR_ID)[0] in comp.outputs:
                                norm_status[var] = not call_model
                                is_computed[str(var).split(LATENT_STR_ID)[0]] = True

                    # Compute residual and check end conditions
                    if _end_conditions_met():
                        break

                    # Skip Anderson acceleration on first iteration
                    if k == 0:
                        k += 1
                        continue
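
                    # Note: Anderson acceleration replaces the plain fixed-point update with a
                    # linear combination of the last `mk` iterates. The weights `alpha` solve a
                    # constrained least-squares problem (one per sample, via `constrained_lls`):
                    #
                    #     min_alpha || R @ alpha ||_2   subject to   sum(alpha) = 1
                    #
                    # where R is the (N_couple, mk) matrix of stacked residual snapshots built
                    # below (C = ones and d = 1 encode the constraint; b = 0 is the target); the
                    # new coupling iterate is the alpha-weighted sum of the coupling snapshots.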
                    # Iterate with Anderson acceleration (only iterate on samples that are not yet converged)
                    N_curr = int(np.sum(samples.curr_idx))
                    mk = len(residual_hist)  # Max of Anderson memory
                    var_shapes = []
                    xdims = []
                    for var in coupling_prev:
                        shape = coupling_prev[var].shape[1:]
                        var_shapes.append(shape)
                        xdims.append(int(np.prod(shape)))
                    N_couple = int(np.sum(xdims))
                    res_snap = np.empty((N_curr, N_couple, mk))       # Shortened snapshot of residual history
                    coupling_snap = np.empty((N_curr, N_couple, mk))  # Shortened snapshot of coupling history
                    for i, (coupling_iter, residual_iter) in enumerate(zip(coupling_hist, residual_hist)):
                        start_idx = 0
                        for j, var in enumerate(coupling_prev):
                            end_idx = start_idx + xdims[j]
                            coupling_snap[:, start_idx:end_idx, i] = coupling_iter[var][samples.curr_idx, ...].reshape((N_curr, -1))  # noqa: E501
                            res_snap[:, start_idx:end_idx, i] = residual_iter[var][samples.curr_idx, ...].reshape((N_curr, -1))  # noqa: E501
                            start_idx = end_idx
                    C = np.ones((N_curr, 1, mk))
                    b = np.zeros((N_curr, N_couple, 1))
                    d = np.ones((N_curr, 1, 1))
                    alpha = np.expand_dims(constrained_lls(res_snap, b, C, d), axis=-3)  # (..., 1, mk, 1)
                    coupling_new = np.squeeze(coupling_snap[:, :, np.newaxis, :] @ alpha, axis=(-1, -2))
                    start_idx = 0
                    for j, var in enumerate(coupling_prev):
                        end_idx = start_idx + xdims[j]
                        coupling_prev[var][samples.curr_idx, ...] = coupling_new[:, start_idx:end_idx].reshape((N_curr, *var_shapes[j]))  # noqa: E501
                        start_idx = end_idx
                    k += 1

        # Return all component outputs; samples that didn't converge during FPI are left as np.nan
        return format_outputs(y, loop_shape)
1281 def __call__(self, *args, **kwargs):
1282 """Convenience wrapper to allow calling as `ret = System(x)`."""
1283 return self.predict(*args, **kwargs)
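# Usage sketch (`system` and `xtest` are illustrative names): calling the system directly is
# equivalent to `predict`, e.g.
#
#     xtest = system.sample_inputs((100,))
#     ytest = system(xtest)  # same as system.predict(xtest)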
1285 def __eq__(self, other):
1286 if not isinstance(other, System):
1287 return False
1288 return (self.components == other.components and
1289 self.name == other.name and
1290 self.train_history == other.train_history)
1292 def __getitem__(self, component: str) -> Component:
1293 """Convenience method to get a `Component` object from the `System`.
1295 :param component: the name of the component to get
1296 :returns: the `Component` object
1297 """
1298 return self.get_component(component)
1300 def get_component(self, comp_name: str) -> Component:
1301 """Return the `Component` object for this component.
1303 :param comp_name: name of the component to return
1304 :raises KeyError: if the component does not exist
1305 :returns: the `Component` object
1306 """
1307 if comp_name.lower() == 'system':
1308 return self
1309 else:
1310 for comp in self.components:
1311 if comp.name == comp_name:
1312 return comp
1313 raise KeyError(f"Component '{comp_name}' not found in system.")
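# Usage sketch: `system['comp_name']` and `system.get_component('comp_name')` are equivalent,
# and the special name 'system' returns the `System` object itself.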
1315 def _print_title_str(self, title_str: str):
1316 """Log an important message."""
1317 self.logger.info('-' * int(len(title_str)/2) + title_str + '-' * int(len(title_str)/2))
1319 def _remove_unpickleable(self) -> dict:
1320 """Remove and return unpickleable attributes before pickling (just the logger)."""
1321 stdout = False
1322 log_file = None
1323 if self._logger is not None:
1324 for handler in self._logger.handlers:
1325 if isinstance(handler, logging.StreamHandler):
1326 stdout = True
1327 break
1328 for handler in self._logger.handlers:
1329 if isinstance(handler, logging.FileHandler):
1330 log_file = handler.baseFilename
1331 break
1333 buffer = {'log_stdout': stdout, 'log_file': log_file}
1334 self.logger = None
1335 return buffer
1337 def _restore_unpickleable(self, buffer: dict):
1338 """Restore the unpickleable attributes from `buffer` after unpickling."""
1339 self.set_logger(log_file=buffer.get('log_file', None), stdout=buffer.get('log_stdout', None))
1341 def _get_test_set(self, test_set: str | Path | tuple = None) -> tuple:
1342 """Try to load a test set from the root directory if it exists."""
1343 if isinstance(test_set, tuple):
1344 return test_set # (xtest, ytest)
1345 else:
1346 ret = (None, None)
1347 if test_set is not None:
1348 test_set = Path(test_set)
1349 elif self.root_dir is not None:
1350 test_set = self.root_dir / 'test_set.pkl'
1352 if test_set is not None:
1353 if test_set.exists():
1354 with open(test_set, 'rb') as fd:
1355 data = pickle.load(fd)
1356 ret = data['test_set']
1358 return ret
1360 def _save_test_set(self, test_set: tuple = None):
1361 """Save the test set to the root directory if possible."""
1362 if self.root_dir is not None and test_set is not None:
1363 test_file = self.root_dir / 'test_set.pkl'
1364 if not test_file.exists():
1365 with open(test_file, 'wb') as fd:
1366 pickle.dump({'test_set': test_set}, fd)
1368 def save_to_file(self, filename: str, save_dir: str | Path = None, dumper=None):
1369 """Save surrogate to file. Defaults to `root/surrogates/filename.yml` with the default yaml encoder.
1371 :param filename: the name of the save file
1372 :param save_dir: the directory to save the file to (defaults to `root/surrogates` or `cwd()`)
1373 :param dumper: the encoder to use (defaults to the `amisc` yaml encoder)
1374 """
1375 from amisc import YamlLoader
1376 encoder = dumper or YamlLoader
1377 save_dir = save_dir or self.root_dir or Path.cwd()
1378 if Path(save_dir) == self.root_dir:
1379 save_dir = self.root_dir / 'surrogates'
1380 encoder.dump(self, Path(save_dir) / filename)
1382 @staticmethod
1383 def load_from_file(filename: str | Path, root_dir: str | Path = None, loader=None):
1384 """Load surrogate from file. Defaults to yaml loading. Tries to infer `amisc` directory structure.
1386 :param filename: the name of the load file
1387 :param root_dir: set this as the surrogate's root directory (by default, tries to infer it from the `amisc_` directory format)
1388 :param loader: the loader to use (defaults to the `amisc` yaml loader)
1389 """
1390 from amisc import YamlLoader
1391 loader = loader or YamlLoader
1392 system = loader.load(filename)
1393 root_dir = root_dir or system.root_dir
1395 # Try to infer amisc_root/surrogates/iter/filename structure
1396 if root_dir is None:
1397 parts = Path(filename).resolve().parts
1398 if len(parts) > 1 and parts[-2].startswith('amisc_'):
1399 root_dir = Path(filename).resolve().parent
1400 elif len(parts) > 2 and parts[-3].startswith('amisc_'):
1401 root_dir = Path(filename).resolve().parent.parent
1402 elif len(parts) > 3 and parts[-4].startswith('amisc_'):
1403 root_dir = Path(filename).resolve().parent.parent.parent
1405 system.root_dir = root_dir
1406 return system
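# Usage sketch (filename is illustrative): round-trip a surrogate through YAML, e.g.
#
#     system.save_to_file('surrogate.yml')  # -> root/surrogates/surrogate.yml when root_dir is set
#     system = System.load_from_file(system.root_dir / 'surrogates' / 'surrogate.yml')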
1408 def clear(self):
1409 """Clear all surrogate model data and reset the system."""
1410 for comp in self.components:
1411 comp.clear()
1412 self.train_history.clear()
1414 def plot_slice(self, inputs: list[str] = None,
1415 outputs: list[str] = None,
1416 num_steps: int = 20,
1417 show_surr: bool = True,
1418 show_model: str | tuple | list = None,
1419 save_dir: str | Path = None,
1420 executor: Executor = None,
1421 nominal: dict[str, float] = None,
1422 random_walk: bool = False,
1423 from_file: str | Path = None,
1424 subplot_size_in: float = 3.):
1425 """Helper function to plot 1d slices of the surrogate and/or model outputs over the inputs. A single
1426 "slice" works by smoothly stepping from the lower bound of an input to its upper bound, while holding all other
1427 inputs constant at their nominal values (or smoothly varying them if `random_walk=True`).
1428 This function is useful for visualizing the behavior of the system surrogate and/or model(s) over a
1429 single input variable at a time.
1431 :param inputs: list of input variables to take 1d slices of (defaults to first 3 in `System.inputs`)
1432 :param outputs: list of model output variables to plot 1d slices of (defaults to first 3 in `System.outputs`)
1433 :param num_steps: the number of points to take in the 1d slice for each input variable; this amounts to a total
1434 of `num_steps*len(inputs)` model/surrogate evaluations
1435 :param show_surr: whether to show the surrogate prediction
1436 :param show_model: also compute and plot model predictions, `list` of ['best', 'worst', tuple(alpha), etc.]
1437 :param save_dir: base directory to save model outputs and plots (if specified)
1438 :param executor: a `concurrent.futures.Executor` object to parallelize model or surrogate evaluations
1439 :param nominal: `dict` of `var->nominal` to use as constant values for all non-sliced variables (use
1440 unnormalized values only; use `var_LATENT0` to specify nominal latent values)
1441 :param random_walk: whether to slice in a random d-dimensional direction instead of holding all non-slice
1442 variables constant at `nominal`
1443 :param from_file: path to a `.pkl` file to load a saved slice from disk
1444 :param subplot_size_in: side length size of each square subplot in inches
1445 :returns: `fig, ax` with `len(inputs)` by `len(outputs)` subplots
1446 """
1447 # Manage loading important quantities from file (if provided)
1448 input_slices, output_slices_model, output_slices_surr = None, None, None
1449 if from_file is not None:
1450 with open(Path(from_file), 'rb') as fd:
1451 slice_data = pickle.load(fd)
1452 inputs = slice_data['inputs'] # Must use same input slices as save file
1453 show_model = slice_data['show_model'] # Must use same model data as save file
1454 outputs = slice_data.get('outputs') if outputs is None else outputs
1455 input_slices = slice_data['input_slices']
1456 save_dir = None # Don't run or save any models if loading from file
1458 # Set default values (take up to the first 3 inputs by default)
1459 all_inputs = self.inputs()
1460 all_outputs = self.outputs()
1461 rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
1462 if save_dir is not None:
1463 os.mkdir(Path(save_dir) / f'slice_{rand_id}')
1464 if nominal is None:
1465 nominal = dict()
1466 inputs = all_inputs[:3] if inputs is None else inputs
1467 outputs = all_outputs[:3] if outputs is None else outputs
1469 if show_model is not None and not isinstance(show_model, list):
1470 show_model = [show_model]
1472 # Handle field quantities (directly use latent variables or only the first one)
1473 for i, var in enumerate(list(inputs)):
1474 if LATENT_STR_ID not in str(var) and all_inputs[var].compression is not None:
1475 inputs[i] = f'{var}{LATENT_STR_ID}0'
1476 for i, var in enumerate(list(outputs)):
1477 if LATENT_STR_ID not in str(var) and all_outputs[var].compression is not None:
1478 outputs[i] = f'{var}{LATENT_STR_ID}0'
1480 bds = all_inputs.get_domains()
1481 xlabels = [all_inputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else
1482 all_inputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) +
1483 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in inputs]
1485 ylabels = [all_outputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else
1486 all_outputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) +
1487 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in outputs]
1489 # Construct slices of model inputs (if not provided)
1490 if input_slices is None:
1491 input_slices = {} # Each input variable with shape (num_steps, num_slice)
1492 for i in range(len(inputs)):
1493 if random_walk:
1494 # Take a random straight-line walk across the d-dimensional input cube
1495 r0 = self.sample_inputs((1,), use_pdf=False)
1496 rf = self.sample_inputs((1,), use_pdf=False)
1498 for var, bd in bds.items():
1499 if var == inputs[i]:
1500 r0[var] = np.atleast_1d(bd[0]) # Start slice at this lower bound
1501 rf[var] = np.atleast_1d(bd[1]) # Slice up to this upper bound
1503 step_size = (rf[var] - r0[var]) / (num_steps - 1)
1504 arr = r0[var] + step_size * np.arange(num_steps)
1506 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else (
1507 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1))
1508 else:
1509 # Otherwise, only slice one variable
1510 for var, bd in bds.items():
1511 nom = nominal.get(var, np.mean(bd)) if LATENT_STR_ID in str(var) else (
1512 all_inputs[var].normalize(nominal.get(var, all_inputs[var].get_nominal())))
1513 arr = np.linspace(bd[0], bd[1], num_steps) if var == inputs[i] else np.full(num_steps, nom)
1515 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else (
1516 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1))
1518 # Walk through each model that is requested by show_model
1519 if show_model is not None:
1520 if from_file is not None:
1521 output_slices_model = slice_data['output_slices_model']
1522 else:
1523 output_slices_model = list()
1524 for model in show_model:
1525 output_dir = None
1526 if save_dir is not None:
1527 output_dir = (Path(save_dir) / f'slice_{rand_id}' /
1528 str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", ''))
1529 os.mkdir(output_dir)
1530 output_slices_model.append(self.predict(input_slices, use_model=model, model_dir=output_dir,
1531 executor=executor))
1532 if show_surr:
1533 output_slices_surr = self.predict(input_slices, executor=executor) \
1534 if from_file is None else slice_data['output_slices_surr']
1536 # Make len(outputs) by len(inputs) grid of subplots
1537 fig, axs = plt.subplots(len(outputs), len(inputs), sharex='col', sharey='row', squeeze=False)
1538 for i, output_var in enumerate(outputs):
1539 for j, input_var in enumerate(inputs):
1540 ax = axs[i, j]
1541 x = input_slices[input_var][:, j]
1543 if show_model is not None:
1544 c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else (
1545 plt.get_cmap('jet')(np.linspace(0, 1, len(show_model))))
1546 for k in range(len(show_model)):
1547 model_str = (str(show_model[k]).replace('{', '').replace('}', '')
1548 .replace(':', '=').replace("'", ''))
1549 model_ret = to_surrogate_dataset(output_slices_model[k], all_outputs)[0]
1550 y_model = model_ret[output_var][:, j]
1551 label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model',
1552 'worst': 'Low-fidelity'}.get(model_str, model_str)
1553 ax.plot(x, y_model, ls='-', c=c[k, :], label=label)
1555 if show_surr:
1556 y_surr = output_slices_surr[output_var][:, j]
1557 ax.plot(x, y_surr, '--r', label='Surrogate')
1559 ax.set_xlabel(xlabels[j] if i == len(outputs) - 1 else '')
1560 ax.set_ylabel(ylabels[i] if j == 0 else '')
1561 if i == 0 and j == len(inputs) - 1:
1562 ax.legend()
1563 fig.set_size_inches(subplot_size_in * len(inputs), subplot_size_in * len(outputs))
1564 fig.tight_layout()
1566 # Save results (unless we were already loading from a save file)
1567 if from_file is None and save_dir is not None:
1568 fname = f'in={",".join([str(v) for v in inputs])}_out={",".join([str(v) for v in outputs])}'
1569 fname = f'slice_rand{rand_id}_' + fname if random_walk else f'slice_nom{rand_id}_' + fname
1570 fdir = Path(save_dir) / f'slice_{rand_id}'
1571 fig.savefig(fdir / f'{fname}.pdf', bbox_inches='tight', format='pdf')
1572 save_dict = {'inputs': inputs, 'outputs': outputs, 'show_model': show_model, 'show_surr': show_surr,
1573 'nominal': nominal, 'random_walk': random_walk, 'input_slices': input_slices,
1574 'output_slices_model': output_slices_model, 'output_slices_surr': output_slices_surr}
1575 with open(fdir / f'{fname}.pkl', 'wb') as fd:
1576 pickle.dump(save_dict, fd)
1578 return fig, axs
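# Usage sketch (variable names are illustrative): compare surrogate and model predictions
# along 1d slices of two inputs, e.g.
#
#     fig, axs = system.plot_slice(inputs=['x1', 'x2'], outputs=['y1'],
#                                  show_model=['best', 'worst'], num_steps=50)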
1580 def get_allocation(self):
1581 """Get a breakdown of cost allocation during training.
1583 :returns: `cost_alloc, model_cost, overhead_cost, model_evals` - the cost allocation per model/fidelity,
1584 the model evaluation cost per iteration (in seconds of CPU time), the algorithmic overhead cost per
1585 iteration, and the total number of model evaluations at each training iteration
1586 """
1587 cost_alloc = dict() # Cost allocation (cpu time in s) per node and model fidelity
1588 model_cost = [] # Cost of model evaluations (CPU time in s) per iteration
1589 overhead_cost = [] # Algorithm overhead costs (CPU time in s) per iteration
1590 model_evals = [] # Number of model evaluations at each training iteration
1592 prev_cands = {comp.name: IndexSet() for comp in self.components} # empty candidate sets
1594 # Add cumulative training costs
1595 for train_res, active_sets, cand_sets, misc_coeff_train, misc_coeff_test in self.simulate_fit():
1596 comp = train_res['component']
1597 alpha = train_res['alpha']
1598 beta = train_res['beta']
1599 overhead = train_res['overhead_s']
1601 cost_alloc.setdefault(comp, dict())
1603 new_cands = cand_sets[comp].union({(alpha, beta)}) - prev_cands[comp] # newly computed candidates
1605 iter_cost = 0.
1606 iter_eval = 0
1607 for alpha_new, beta_new in new_cands:
1608 cost_alloc[comp].setdefault(alpha_new, 0.)
1610 added_eval = self[comp].get_cost(alpha_new, beta_new)
1611 single_cost = self[comp].model_costs.get(alpha_new, 1.)
1613 iter_cost += added_eval * single_cost
1614 iter_eval += added_eval
1616 cost_alloc[comp][alpha_new] += added_eval * single_cost
1618 overhead_cost.append(overhead)
1619 model_cost.append(iter_cost)
1620 model_evals.append(iter_eval)
1621 prev_cands[comp] = cand_sets[comp].union({(alpha, beta)})
1623 return cost_alloc, np.atleast_1d(model_cost), np.atleast_1d(overhead_cost), np.atleast_1d(model_evals)
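# Usage sketch: summarize total training cost from the returned arrays, e.g.
#
#     cost_alloc, model_cost, overhead_cost, model_evals = system.get_allocation()
#     total_cpu_s = np.sum(model_cost) + np.sum(overhead_cost)  # total training cost (CPU s)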
1625 def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02):
1626 """Plot bar charts showing cost allocation during training.
1628 !!! Warning "Beta feature"
1629 This has reasonable default settings, but it may not render well for every use case. It is provided
1630 mainly as a template for making cost allocation bar charts; feel free to copy and adapt it in your own code.
1632 :param cmap: the colormap string identifier for `plt`
1633 :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text
1634 :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow;
1635 below this amount, the bar is too skinny and won't print any text
1636 :returns: `fig, ax`, Figure and Axes objects
1637 """
1638 # Get total cost
1639 cost_alloc, model_cost, _, _ = self.get_allocation()
1640 total_cost = np.cumsum(model_cost)[-1]
1642 # Remove nodes with cost=0 from alloc dicts (i.e. analytical models)
1643 remove_nodes = []
1644 for node, alpha_dict in cost_alloc.items():
1645 if len(alpha_dict) == 0:
1646 remove_nodes.append(node)
1647 for node in remove_nodes:
1648 del cost_alloc[node]
1650 # Bar chart showing cost allocation breakdown for MF system at final iteration
1651 fig, ax = plt.subplots(figsize=(6, 5), layout='tight')
1652 width = 0.7
1653 x = np.arange(len(cost_alloc))
1654 xlabels = list(cost_alloc.keys()) # One bar for each component
1655 cmap = plt.get_cmap(cmap)
1657 for j, (node, alpha_dict) in enumerate(cost_alloc.items()):
1658 bottom = 0
1659 c_intervals = np.linspace(0, 1, len(alpha_dict))
1660 bars = [(alpha, cost, cost / total_cost) for alpha, cost in alpha_dict.items()]
1661 bars = sorted(bars, key=lambda ele: ele[2], reverse=True)
1662 for i, (alpha, cost, frac) in enumerate(bars):
1663 p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1,
1664 edgecolor=[0, 0, 0], bottom=bottom)
1665 bottom += frac
1666 num_evals = round(cost / self[node].model_costs.get(alpha, 1.))
1667 if frac > text_bar_width:
1668 ax.bar_label(p, labels=[f'{alpha}, {num_evals}'], label_type='center')
1669 elif frac > arrow_bar_width:
1670 xy = (x[j] + width / 2, bottom - frac / 2) # Label smaller bars with a text off to the side
1671 ax.annotate(f'{alpha}, {num_evals}', xy, xytext=(xy[0] + 0.2, xy[1]),
1672 arrowprops={'arrowstyle': '->', 'linewidth': 1})
1673 else:
1674 pass # Don't label really small bars
1675 ax.set_xlabel('')
1676 ax.set_ylabel('Fraction of total cost')
1677 ax.set_xticks(x, xlabels)
1678 ax.set_xlim(left=-1, right=x[-1] + 1)
1680 if self.root_dir is not None:
1681 fig.savefig(Path(self.root_dir) / 'mf_allocation.pdf', bbox_inches='tight', format='pdf')
1683 return fig, ax
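# Usage sketch: `fig, ax = system.plot_allocation()`; the figure is also written to
# `mf_allocation.pdf` in the root directory when `root_dir` is set.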
1685 def serialize(self, keep_components=False, serialize_args=None, serialize_kwargs=None) -> dict:
1686 """Convert to a `dict` with only standard Python types for fields.
1688 :param keep_components: whether to serialize the components as well (defaults to False)
1689 :param serialize_args: `dict` of arguments to pass to each component's serialize method
1690 :param serialize_kwargs: `dict` of keyword arguments to pass to each component's serialize method
1691 :returns: a `dict` representation of the `System` object
1692 """
1693 serialize_args = serialize_args or dict()
1694 serialize_kwargs = serialize_kwargs or dict()
1695 d = {}
1696 for key, value in self.__dict__.items():
1697 if value is not None and not key.startswith('_'):
1698 if key == 'components' and not keep_components:
1699 d[key] = [comp.serialize(keep_yaml_objects=False,
1700 serialize_args=serialize_args.get(comp.name),
1701 serialize_kwargs=serialize_kwargs.get(comp.name))
1702 for comp in value]
1703 elif key == 'train_history':
1704 if len(value) > 0:
1705 d[key] = value.serialize()
1706 else:
1707 if not isinstance(value, _builtin):
1708 self.logger.warning(f"Attribute '{key}' of type '{type(value)}' may not be a builtin "
1709 f"Python type. This may cause issues when saving/loading from file.")
1710 d[key] = value
1712 for key, value in self.model_extra.items():
1713 if isinstance(value, _builtin):
1714 d[key] = value
1716 return d
1718 @classmethod
1719 def deserialize(cls, serialized_data: dict) -> System:
1720 """Construct a `System` object from serialized data."""
1721 return cls(**serialized_data)
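# Usage sketch: a `System` round-trips through its serialized `dict` representation, e.g.
#
#     config = system.serialize()             # only builtin Python types
#     same_system = System.deserialize(config)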
1723 @staticmethod
1724 def _yaml_representer(dumper: yaml.Dumper, system: System):
1725 """Serialize a `System` object to a YAML representation."""
1726 return dumper.represent_mapping(System.yaml_tag, system.serialize(keep_components=True))
1728 @staticmethod
1729 def _yaml_constructor(loader: yaml.Loader, node):
1730 """Convert the `!SystemSurrogate` tag in yaml to a [`System`][amisc.system.System] object."""
1731 if isinstance(node, yaml.SequenceNode):
1732 return [ele if isinstance(ele, System) else System.deserialize(ele)
1733 for ele in loader.construct_sequence(node, deep=True)]
1734 elif isinstance(node, yaml.MappingNode):
1735 return System.deserialize(loader.construct_mapping(node, deep=True))
1736 else:
1737 raise NotImplementedError(f'The "{System.yaml_tag}" yaml tag can only be used on a yaml sequence or '
1738 f'mapping, not a "{type(node)}".')