Coverage for src/amisc/system.py: 83%
813 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-29 21:38 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-29 21:38 +0000
1"""The `SystemSurrogate` is a framework for multidisciplinary models. It manages multiple single discipline component
2models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates.
4Features
5--------
6- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections
7- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver
8- FPI uses Anderson acceleration and surrogate evaluations for speed-up
9- Top-level interface for training and using surrogates of each component model
10- Adaptive experimental design for choosing training data efficiently
11- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates
12- Detailed logging and traceback information
13- Supports parallel execution with OpenMP and MPI protocols
14- Abstract and flexible interfacing with component models
16!!! Info "Model specification"
17 Models are callable Python wrapper functions of the form `ret = model(x, *args, **kwargs)`, where `x` is an
18 `np.ndarray` of model inputs (and `*args, **kwargs` allow passing any other required configurations for your model).
19 The return value is a Python dictionary of the form `ret = {'y': y, 'files': files, 'cost': cost, etc.}`. In the
20 return dictionary, you specify the raw model output `y` as an `np.ndarray` at a _minimum_. Optionally, you can
21 specify paths to output files and the average model cost (in seconds of cpu time), and anything else you want. Your
22 `model()` function can do anything it wants in order to go from `x` → `y`. Python has the flexibility to call
23 virtually any external codes, or to implement the function natively with `numpy`.
25!!! Info "Component specification"
26 A component adds some extra configuration around a callable `model`. These configurations are defined in a Python
27 dictionary, which we give the custom type `ComponentSpec`. At a bare _minimum_, you must specify a callable
28 `model` and its connections to other models within the multidisciplinary system. The limiting case is a single
29 component model, for which the configuration is simply `component = ComponentSpec(model)`.
30"""
31# ruff: noqa: E702
32import copy
33import datetime
34import functools
35import os
36import pickle
37import random
38import shutil
39import string
40import time
41from collections import UserDict
42from concurrent.futures import Executor
43from datetime import timezone
44from pathlib import Path
46import dill
47import matplotlib.pyplot as plt
48import networkx as nx
49import numpy as np
50from joblib import Parallel, delayed
51from joblib.externals.loky import set_loky_pickler
52from uqtils import ax_default
54from amisc import IndexSet, IndicesRV
55from amisc.component import AnalyticalSurrogate, ComponentSurrogate, SparseGridSurrogate
56from amisc.rv import BaseRV
57from amisc.utils import get_logger
class ComponentSpec(UserDict):
    """Provides a simple extension class of a Python dictionary, used to configure a component model.

    !!! Info "Specifying a list of random variables"
        The three fields: `exo_in`, `coupling_in`, and `coupling_out` fully determine how a component fits within a
        multidisciplinary system. For each, you must specify a list of variables in the same order as the model uses
        them. The model will use all exogenous inputs first, and then all coupling inputs. You can use a variable's
        global integer index into the system `exo_vars` or `coupling_vars`, or you can use the `str` id of the variable
        or the variable itself. This is summarized in the `amisc.IndicesRV` custom type.

    !!! Example
        Let's say you have a model:
        ```python
        def my_model(x, *args, **kwargs):
            print(x.shape)  # (3,), so a total of 3 inputs
            G = 6.674e-11
            m1 = x[0]       # System-level input
            m2 = x[1]       # System-level input
            r = x[2]        # Coupling input
            F = G*m1*m2 / r**2
            return {'y': F}
        ```
        Let's say this model is part of a larger system where `m1` and `m2` are specified by the system, and `r` comes
        from a different model that predicts the distance between two objects. You would set the configuration as:
        ```python
        component = ComponentSpec(my_model, exo_in=['m1', 'm2'], coupling_in=['r'], coupling_out=['F'])
        ```
    """
    # The only keys a ComponentSpec may contain (enforced in __setitem__)
    Options = ['model', 'name', 'exo_in', 'coupling_in', 'coupling_out', 'truth_alpha', 'max_alpha', 'max_beta',
               'surrogate', 'model_args', 'model_kwargs', 'save_output']

    def __init__(self, model: callable, name: str = '', exo_in: IndicesRV = None,
                 coupling_in: IndicesRV | dict[str, IndicesRV] = None, coupling_out: IndicesRV = None,
                 truth_alpha: tuple | int = (), max_alpha: tuple | int = (), max_beta: tuple | int = (),
                 surrogate: str | ComponentSurrogate = 'lagrange', model_args: tuple = (), model_kwargs: dict = None,
                 save_output: bool = False):
        """Construct the configuration for this component model.

        !!! Warning
            Always specify the model at a _global_ scope, i.e. don't use `lambda` or nested functions. When saving to
            file, only a symbolic reference to the function signature will be saved, which must be globally defined
            when loading back from that save file.

        :param model: the component model, must be defined in a global scope (i.e. in a module or top-level of a script)
        :param name: the name used to identify this component model
        :param exo_in: specifies the global, system-level (i.e. exogenous/external) inputs to this model
        :param coupling_in: specifies the coupling inputs received from other models
        :param coupling_out: specifies all outputs of this model (which may couple later to downstream models)
        :param truth_alpha: the model fidelity indices to treat as a "ground truth" reference
        :param max_alpha: the maximum model fidelity indices to allow for refinement purposes
        :param max_beta: the maximum surrogate fidelity indices to allow for refinement purposes
        :param surrogate: one of ('lagrange', 'analytical'), or the `ComponentSurrogate` class to use directly
        :param model_args: optional arguments to pass to the component model
        :param model_kwargs: optional keyword arguments to pass to the component model
        :param save_output: whether this model will be saving outputs to file
        """
        # Build the underlying dict explicitly (rather than filtering locals()) so that the stored keys always
        # match ComponentSpec.Options exactly and never depend on incidental local variable names.
        super().__init__(dict(model=model, name=name, exo_in=exo_in, coupling_in=coupling_in,
                              coupling_out=coupling_out, truth_alpha=truth_alpha, max_alpha=max_alpha,
                              max_beta=max_beta, surrogate=surrogate, model_args=model_args,
                              model_kwargs=model_kwargs, save_output=save_output))

    def __setitem__(self, key, value):
        """Only allow setting keys that are valid `ComponentSpec.Options`."""
        if key in ComponentSpec.Options:
            super().__setitem__(key, value)
        else:
            raise ValueError(f'"{key}" is not applicable for a ComponentSpec. Try one of {ComponentSpec.Options}.')

    def __delitem__(self, key):
        """Deleting configuration options is not supported."""
        raise TypeError("Not allowed to delete items from a ComponentSpec.")
130class SystemSurrogate:
131 """Multidisciplinary (MD) surrogate framework top-level class.
133 !!! Note "Accessing individual components"
134 The `ComponentSurrogate` objects that compose `SystemSurrogate` are internally stored in the `self.graph.nodes`
135 data structure. You can access them with `get_component(comp_name)`.
137 :ivar exo_vars: global list of exogenous/external inputs for the MD system
138 :ivar coupling_vars: global list of coupling variables for the MD system (including all system-level outputs)
139 :ivar refine_level: the total number of refinement steps that have been made
140 :ivar build_metrics: contains data that summarizes surrogate training progress
141 :ivar root_dir: root directory where all surrogate build products are saved to file
142 :ivar log_file: log file where all logs are written to by default
143 :ivar executor: manages parallel execution for the system
144 :ivar graph: the internal graph data structure of the MD system
146 :vartype exo_vars: list[BaseRV]
147 :vartype coupling_vars: list[BaseRV]
148 :vartype refine_level: int
149 :vartype build_metrics: dict
150 :vartype root_dir: str
151 :vartype log_file: str
152 :vartype executor: Executor
153 :vartype graph: nx.DiGraph
154 """
    def __init__(self, components: list[ComponentSpec] | ComponentSpec, exo_vars: list[BaseRV] | BaseRV,
                 coupling_vars: list[BaseRV] | BaseRV, est_bds: int = 0, save_dir: str | Path = None,
                 executor: Executor = None, stdout: bool = True, init_surr: bool = True, logger_name: str = None):
        """Construct the MD system surrogate.

        !!! Warning
            Component models should always use coupling variables in the order they appear in the system-level
            `coupling_vars`.

        :param components: list of components in the MD system (using the ComponentSpec class)
        :param exo_vars: list of system-level exogenous/external inputs
        :param coupling_vars: list of all coupling variables (including all system-level outputs)
        :param est_bds: number of samples to estimate coupling variable bounds, do nothing if 0
        :param save_dir: root directory for all build products (.log, .pkl, .json, etc.), won't save if None
        :param executor: an instance of a `concurrent.futures.Executor`, use to iterate new candidates in parallel
        :param stdout: whether to log to console
        :param init_surr: whether to initialize the surrogate immediately when constructing
        :param logger_name: the name of the logger to use, if None then uses class name by default
        """
        # Setup root save directory (UTC-timestamped subdirectory so repeated runs never collide)
        if save_dir is not None:
            timestamp = datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
            save_dir = Path(save_dir) / ('amisc_' + timestamp)
            os.mkdir(save_dir)
        # Placeholders; set_root_directory() below is expected to populate these -- TODO confirm (not visible here)
        self.root_dir = None
        self.log_file = None
        self.logger = None
        self.executor = executor
        self.graph = nx.DiGraph()
        self.set_root_directory(save_dir, stdout=stdout, logger_name=logger_name)

        # Store system info in a directed graph data structure
        self.exo_vars = copy.deepcopy(exo_vars) if isinstance(exo_vars, list) else [exo_vars]
        self.x_vars = self.exo_vars  # Create an alias to be consistent with components
        self.coupling_vars = copy.deepcopy(coupling_vars) if isinstance(coupling_vars, list) else [coupling_vars]
        self.refine_level = 0
        self.build_metrics = dict()  # Save refinement error metrics during training

        # Construct graph nodes (auto-name any components the user left unnamed)
        components = [components] if not isinstance(components, list) else components
        for k, comp in enumerate(components):
            if comp['name'] == '':
                comp['name'] = f'Component {k}'
        Nk = len(components)
        nodes = {comp['name']: comp for comp in components}  # work-around since self.graph.nodes is not built yet
        for k in range(Nk):
            # Add the component as a str() node, with attributes specifying details of the surrogate
            comp_dict = components[k]
            indices, surr = self._build_component(comp_dict, nodes=nodes)
            self.graph.add_node(comp_dict['name'], surrogate=surr, is_computed=False, **indices)

        # Connect all neighbor nodes (edge direction: upstream neighbor -> this node)
        for node, node_obj in self.graph.nodes.items():
            for neighbor in node_obj['local_in']:
                self.graph.add_edge(neighbor, node)

        self.set_logger(logger_name, stdout=stdout)  # Need to update component loggers

        # Estimate coupling variable bounds by sampling the true models (skipped if est_bds == 0)
        if est_bds > 0:
            self._estimate_coupling_bds(est_bds)

        # Init system with most coarse fidelity indices in each component
        if init_surr:
            self.init_system()
        self._save_progress('sys_init.pkl')
    def _build_component(self, component: ComponentSpec, nodes=None) -> tuple[dict, ComponentSurrogate]:
        """Build and return a `ComponentSurrogate` from a `dict` that describes the component model/connections.

        :param component: specifies details of a component (see `ComponentSpec`)
        :param nodes: `dict` of `{node: node_attributes}`, defaults to `self.graph.nodes`
        :returns: `connections, surr`: a `dict` of all connection indices and the `ComponentSurrogate` object
        """
        nodes = self.graph.nodes if nodes is None else nodes
        kwargs = component.get('model_kwargs', {})
        kwargs = {} if kwargs is None else kwargs

        # Set up defaults if this is a trivial one component system
        exo_in = component.get('exo_in', None)
        coupling_in = component.get('coupling_in', None)
        coupling_out = component.get('coupling_out', None)
        if len(nodes) == 1:
            # Single-component system: all exogenous vars are inputs, all coupling vars are outputs
            exo_in = list(np.arange(0, len(self.exo_vars)))
            coupling_in = []
            coupling_out = list(np.arange(0, len(self.coupling_vars)))
        else:
            exo_in = [] if exo_in is None else exo_in
            coupling_in = [] if coupling_in is None else coupling_in
            coupling_out = [] if coupling_out is None else coupling_out
        # Normalize scalars into lists (coupling_in may also legitimately be a dict of {neighbor: indices})
        exo_in = [exo_in] if not isinstance(exo_in, list) else exo_in
        coupling_in = [coupling_in] if not isinstance(coupling_in, list | dict) else coupling_in
        component['coupling_out'] = [coupling_out] if not isinstance(coupling_out, list) else coupling_out

        # Raise an error if all inputs or all outputs are empty
        if len(exo_in) + len(coupling_in) == 0:
            raise ValueError(f'Component {component["name"]} has no inputs! Please specify inputs in '
                             f'"exo_in" or "coupling_in".')
        if len(component['coupling_out']) == 0:
            raise ValueError(f'Component {component["name"]} has no outputs! Please specify outputs in '
                             f'"coupling_out".')

        # Get exogenous input indices (might already be a list of ints, otherwise convert list of vars to indices)
        if len(exo_in) > 0:
            if isinstance(exo_in[0], str | BaseRV):
                exo_in = [self.exo_vars.index(var) for var in exo_in]

        # Get global coupling output indices for all nodes (convert list of vars to list of indices if necessary)
        global_out = {}
        for node, node_obj in nodes.items():
            # Use the (possibly updated) spec for *this* component rather than its stale node attributes
            node_use = node_obj if node != component.get('name') else component
            coupling_out = node_use.get('coupling_out', None)
            coupling_out = [] if coupling_out is None else coupling_out
            coupling_out = [coupling_out] if not isinstance(coupling_out, list) else coupling_out
            # NOTE(review): indexes coupling_out[0] -- assumes every node declares at least one output; confirm
            global_out[node] = [self.coupling_vars.index(var) for var in coupling_out] if isinstance(
                coupling_out[0], str | BaseRV) else coupling_out

        # Refactor coupling inputs into both local and global index formats
        local_in = dict()    # e.g. {'Cathode': [0, 1, 2], 'Thruster': [0,], etc...}
        global_in = list()   # e.g. [0, 2, 4, 5, 6]
        if isinstance(coupling_in, dict):
            # If already a dict, get local connection indices from each neighbor
            for node, connections in coupling_in.items():
                conn_list = [connections] if not isinstance(connections, list) else connections
                if isinstance(conn_list[0], str | BaseRV):
                    global_ind = [self.coupling_vars.index(var) for var in conn_list]
                    local_in[node] = sorted([global_out[node].index(i) for i in global_ind])
                else:
                    local_in[node] = sorted(conn_list)

            # Convert to global coupling indices
            for node, local_idx in local_in.items():
                global_in.extend([global_out[node][i] for i in local_idx])
            global_in = sorted(global_in)
        else:
            # Otherwise, convert a list of global indices or vars into a dict of local indices
            if len(coupling_in) > 0:
                if isinstance(coupling_in[0], str | BaseRV):
                    coupling_in = [self.coupling_vars.index(var) for var in coupling_in]
            global_in = sorted(coupling_in)
            for node, node_obj in nodes.items():
                if node != component['name']:
                    l = list()  # noqa: E741
                    for i in global_in:
                        try:
                            l.append(global_out[node].index(i))
                        except ValueError:
                            pass  # this global input is not produced by this neighbor node
                    if l:
                        local_in[node] = sorted(l)

        # Store all connection indices for this component
        connections = dict(exo_in=exo_in, local_in=local_in, global_in=global_in,
                           global_out=global_out.get(component.get('name')))

        # Set up a component output save directory (only if the system itself has a save directory)
        if component.get('save_output', False) and self.root_dir is not None:
            output_dir = str((Path(self.root_dir) / 'components' / component['name']).resolve())
            if not Path(output_dir).is_dir():
                os.mkdir(output_dir)
            kwargs['output_dir'] = output_dir
        else:
            if kwargs.get('output_dir', None) is not None:
                kwargs['output_dir'] = None

        # Initialize a new component surrogate (string names map to concrete surrogate classes)
        surr_class = component.get('surrogate', 'lagrange')
        if isinstance(surr_class, str):
            match surr_class:
                case 'lagrange':
                    surr_class = SparseGridSurrogate
                case 'analytical':
                    surr_class = AnalyticalSurrogate
                case other:
                    raise NotImplementedError(f"Surrogate type '{other}' is not known at this time.")

        # Check for an override of model fidelity indices (to enable just single-fidelity evaluation)
        if kwargs.get('hf_override', False):
            truth_alpha, max_alpha = (), ()
            kwargs['hf_override'] = component['truth_alpha']  # Pass in the truth alpha indices as a kwarg to model
        else:
            truth_alpha, max_alpha = component['truth_alpha'], component['max_alpha']
        max_beta = component.get('max_beta')
        # Normalize bare ints into 1-tuples of fidelity indices
        truth_alpha = (truth_alpha,) if isinstance(truth_alpha, int) else truth_alpha
        max_alpha = (max_alpha,) if isinstance(max_alpha, int) else max_alpha
        max_beta = (max_beta,) if isinstance(max_beta, int) else max_beta

        # Assumes input ordering is exogenous vars + sorted coupling vars
        x_vars = [self.exo_vars[i] for i in exo_in] + [self.coupling_vars[i] for i in global_in]
        surr = surr_class(x_vars, component['model'], truth_alpha=truth_alpha, max_alpha=max_alpha,
                          max_beta=max_beta, executor=self.executor, log_file=self.log_file,
                          model_args=component.get('model_args'), model_kwargs=kwargs)
        return connections, surr
350 def swap_component(self, component: ComponentSpec, exo_add: BaseRV | list[BaseRV] = None,
351 exo_remove: IndicesRV = None, qoi_add: BaseRV | list[BaseRV] = None,
352 qoi_remove: IndicesRV = None):
353 """Swap a new component into the system, updating all connections/inputs.
355 !!! Warning "Beta feature, proceed with caution"
356 If you are swapping a new component in, you cannot remove any inputs that are expected by other components,
357 including the coupling variables output by the current model.
359 :param component: specs of new component model (must replace an existing component with matching `name`)
360 :param exo_add: variables to add to system exogenous inputs (will be appended to end)
361 :param exo_remove: indices of system exogenous inputs to delete (can't be shared by other components)
362 :param qoi_add: system output QoIs to add (will be appended to end of `coupling_vars`)
363 :param qoi_remove: indices of system `coupling_vars` to delete (can't be shared by other components)
364 """
365 # Delete system exogenous inputs
366 if exo_remove is None:
367 exo_remove = []
368 exo_remove = [exo_remove] if not isinstance(exo_remove, list) else exo_remove
369 exo_remove = [self.exo_vars.index(var) for var in exo_remove] if exo_remove and isinstance(
370 exo_remove[0], str | BaseRV) else exo_remove
372 exo_remove = sorted(exo_remove)
373 for j, exo_var_idx in enumerate(exo_remove):
374 # Adjust exogenous indices for all components to account for deleted system inputs
375 for node, node_obj in self.graph.nodes.items():
376 if node != component['name']:
377 for i, idx in enumerate(node_obj['exo_in']):
378 if idx == exo_var_idx:
379 error_msg = f"Can't delete system exogenous input at idx {exo_var_idx}, since it is " \
380 f"shared by component '{node}'."
381 self.logger.error(error_msg)
382 raise ValueError(error_msg)
383 if idx > exo_var_idx:
384 node_obj['exo_in'][i] -= 1
386 # Need to update the remaining delete indices by -1 to account for each sequential deletion
387 del self.exo_vars[exo_var_idx]
388 for i in range(j+1, len(exo_remove)):
389 exo_remove[i] -= 1
391 # Append any new exogenous inputs to the end
392 if exo_add is not None:
393 exo_add = [exo_add] if not isinstance(exo_add, list) else exo_add
394 self.exo_vars.extend(exo_add)
396 # Delete system qoi outputs (if not shared by other components)
397 qoi_remove = sorted(self._get_qoi_ind(qoi_remove))
398 for j, qoi_idx in enumerate(qoi_remove):
399 # Adjust coupling indices for all components to account for deleted system outputs
400 for node, node_obj in self.graph.nodes.items():
401 if node != component['name']:
402 for i, idx in enumerate(node_obj['global_in']):
403 if idx == qoi_idx:
404 error_msg = f"Can't delete system QoI at idx {qoi_idx}, since it is an input to " \
405 f"component '{node}'."
406 self.logger.error(error_msg)
407 raise ValueError(error_msg)
408 if idx > qoi_idx:
409 node_obj['global_in'][i] -= 1
411 for i, idx in enumerate(node_obj['global_out']):
412 if idx > qoi_idx:
413 node_obj['global_out'][i] -= 1
415 # Need to update the remaining delete indices by -1 to account for each sequential deletion
416 del self.coupling_vars[qoi_idx]
417 for i in range(j+1, len(qoi_remove)):
418 qoi_remove[i] -= 1
420 # Append any new system QoI outputs to the end
421 if qoi_add is not None:
422 qoi_add = [qoi_add] if not isinstance(qoi_add, list) else qoi_add
423 self.coupling_vars.extend(qoi_add)
425 # Build and initialize the new component surrogate
426 indices, surr = self._build_component(component)
427 surr.init_coarse()
429 # Make changes to adj matrix if coupling inputs changed
430 prev_neighbors = list(self.graph.nodes[component['name']]['local_in'].keys())
431 new_neighbors = list(indices['local_in'].keys())
432 for neighbor in new_neighbors:
433 if neighbor not in prev_neighbors:
434 self.graph.add_edge(neighbor, component['name'])
435 else:
436 prev_neighbors.remove(neighbor)
437 for neighbor in prev_neighbors:
438 self.graph.remove_edge(neighbor, component['name'])
440 self.logger.info(f"Swapped component '{component['name']}'.")
441 nx.set_node_attributes(self.graph, {component['name']: {'exo_in': indices['exo_in'], 'local_in':
442 indices['local_in'], 'global_in': indices['global_in'],
443 'global_out': indices['global_out'],
444 'surrogate': surr, 'is_computed': False}})
446 def insert_component(self, component: ComponentSpec, exo_add: BaseRV | list[BaseRV] = None,
447 qoi_add: BaseRV | list[BaseRV] = None):
448 """Insert a new component into the system.
450 :param component: specs of new component model
451 :param exo_add: variables to add to system exogenous inputs (will be appended to end of `exo_vars`)
452 :param qoi_add: system output QoIs to add (will be appended to end of `coupling_vars`)
453 """
454 if exo_add is not None:
455 exo_add = [exo_add] if not isinstance(exo_add, list) else exo_add
456 self.exo_vars.extend(exo_add)
457 if qoi_add is not None:
458 qoi_add = [qoi_add] if not isinstance(qoi_add, list) else qoi_add
459 self.coupling_vars.extend(qoi_add)
461 indices, surr = self._build_component(component)
462 surr.init_coarse()
463 self.graph.add_node(component['name'], surrogate=surr, is_computed=False, **indices)
465 # Add graph edges
466 neighbors = list(indices['local_in'].keys())
467 for neighbor in neighbors:
468 self.graph.add_edge(neighbor, component['name'])
469 self.logger.info(f"Inserted component '{component['name']}'.")
471 def _save_on_error(func):
472 """Gracefully exit and save `SystemSurrogate` on any errors."""
473 @functools.wraps(func)
474 def wrap(self, *args, **kwargs):
475 try:
476 return func(self, *args, **kwargs)
477 except:
478 self._save_progress('sys_error.pkl')
479 self.logger.critical(f'An error occurred during execution of {func.__name__}. Saving '
480 f'SystemSurrogate object to sys_error.pkl', exc_info=True)
481 self.logger.info(f'Final system surrogate on exit: \n {self}')
482 raise
483 return wrap
484 _save_on_error = staticmethod(_save_on_error)
486 @_save_on_error
487 def init_system(self):
488 """Add the coarsest multi-index to each component surrogate."""
489 self._print_title_str('Initializing all component surrogates')
490 for node, node_obj in self.graph.nodes.items():
491 node_obj['surrogate'].init_coarse()
492 # for alpha, beta in list(node_obj['surrogate'].candidate_set):
493 # # Add one refinement in each input dimension to initialize
494 # node_obj['surrogate'].activate_index(alpha, beta)
495 self.logger.info(f"Initialized component '{node}'.")
    @_save_on_error
    def fit(self, qoi_ind: IndicesRV = None, num_refine: int = 100, max_iter: int = 20, max_tol: float = 1e-3,
            max_runtime: float = 1, save_interval: int = 0, update_bounds: bool = True, test_set: dict = None,
            n_jobs: int = 1):
        """Train the system surrogate adaptively by iterative refinement until an end condition is met.

        Terminates when any of the following is reached: max iterations, error tolerance, max runtime, or no
        candidates left to refine. Convergence plots are written to `root_dir` (if set) after every iteration.

        :param qoi_ind: list of system QoI variables to focus refinement on, use all QoI if not specified
        :param num_refine: number of samples of exogenous inputs to compute error indicators on
        :param max_iter: the maximum number of refinement steps to take
        :param max_tol: the max allowable value in relative L2 error to achieve
        :param max_runtime: the maximum wall clock time (hr) to run refinement for (will go until all models finish)
        :param save_interval: number of refinement steps between each progress save, none if 0
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param test_set: `dict(xt=(Nt, x_dim), yt=(Nt, y_dim)` to show convergence of surrogate to the truth model
        :param n_jobs: number of cpu workers for computing error indicators (on master MPI task), 1=sequential
        """
        qoi_ind = self._get_qoi_ind(qoi_ind)
        Nqoi = len(qoi_ind)
        max_iter = self.refine_level + max_iter  # max_iter counts *additional* steps past the current refine level
        curr_error = np.inf
        t_start = time.time()
        test_stats, xt, yt, t_fig, t_ax = None, None, None, None, None

        # Record of (error indicator, component, alpha, beta, num_evals, total added cost (s)) for each iteration
        train_record = self.build_metrics.get('train_record', [])
        if test_set is not None:
            xt, yt = test_set['xt'], test_set['yt']
        xt, yt = self.build_metrics.get('xt', xt), self.build_metrics.get('yt', yt)  # Overrides test set param

        # Track convergence progress on a test set and on the max error indicator
        err_fig, err_ax = plt.subplots()
        if xt is not None and yt is not None:
            self.build_metrics['xt'] = xt
            self.build_metrics['yt'] = yt
            if self.build_metrics.get('test_stats') is not None:
                test_stats = self.build_metrics.get('test_stats')
            else:
                # Get initial perf metrics, (2, Nqoi)
                test_stats = np.expand_dims(self.get_test_metrics(xt, yt, qoi_ind=qoi_ind), axis=0)
            t_fig, t_ax = plt.subplots(1, Nqoi) if Nqoi > 1 else plt.subplots()

        # Set up a parallel pool of workers, sequential if n_jobs=1
        with Parallel(n_jobs=n_jobs, verbose=0) as ppool:
            while True:
                # Check all end conditions
                if self.refine_level >= max_iter:
                    self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}')
                    break
                if curr_error == -np.inf:
                    # -inf is the sentinel returned by refine() when nothing is left to refine
                    self._print_title_str('Termination criteria reached: No candidates left to refine')
                    break
                if curr_error < max_tol:
                    self._print_title_str(f'Termination criteria reached: L2 error {curr_error} < tol {max_tol}')
                    break
                if ((time.time() - t_start)/3600.0) >= max_runtime:
                    actual = datetime.timedelta(seconds=time.time()-t_start)
                    target = datetime.timedelta(seconds=max_runtime*3600)
                    self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}')
                    break

                # Refine surrogate and save progress
                refine_res = self.refine(qoi_ind=qoi_ind, num_refine=num_refine, update_bounds=update_bounds,
                                         ppool=ppool)
                curr_error = refine_res[0]
                if save_interval > 0 and self.refine_level % save_interval == 0:
                    self._save_progress(f'sys_iter_{self.refine_level}.pkl')

                # Plot progress of error indicator
                train_record.append(refine_res)
                error_record = [res[0] for res in train_record]
                self.build_metrics['train_record'] = train_record
                err_ax.clear(); err_ax.grid(); err_ax.plot(error_record, '-k')
                ax_default(err_ax, 'Iteration', r'Relative $L_2$ error indicator', legend=False)
                err_ax.set_yscale('log')
                if self.root_dir is not None:
                    err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.png'), dpi=300, format='png')

                # Plot progress on test set
                if xt is not None and yt is not None:
                    stats = self.get_test_metrics(xt, yt, qoi_ind=qoi_ind)
                    test_stats = np.concatenate((test_stats, stats[np.newaxis, ...]), axis=0)
                    for i in range(Nqoi):
                        ax = t_ax if Nqoi == 1 else t_ax[i]
                        ax.clear(); ax.grid(); ax.set_yscale('log')
                        ax.plot(test_stats[:, 1, i], '-k')  # index 1 is the relative L2 error row
                        ax.set_title(self.coupling_vars[qoi_ind[i]].to_tex(units=True))
                        ax_default(ax, 'Iteration', r'Relative $L_2$ error', legend=False)
                    t_fig.set_size_inches(3.5*Nqoi, 3.5)
                    t_fig.tight_layout()
                    if self.root_dir is not None:
                        t_fig.savefig(str(Path(self.root_dir) / 'test_set.png'), dpi=300, format='png')
                    self.build_metrics['test_stats'] = test_stats

        self._save_progress('sys_final.pkl')
        self.logger.info(f'Final system surrogate: \n {self}')
    def get_allocation(self, idx: int = None):
        """Get a breakdown of cost allocation up to a certain iteration number during training (starting at 1).

        :param idx: the iteration number to get allocation results for (defaults to last refinement step)
        :returns: `cost_alloc, offline_alloc, cost_cum` - the cost alloc per node/fidelity and cumulative training cost
        """
        if idx is None:
            idx = self.refine_level
        if idx > self.refine_level:
            raise ValueError(f'Specified index: {idx} is greater than the max training level of {self.refine_level}')

        cost_alloc = dict()   # Cost allocation per node and model fidelity
        cost_cum = [0.0]      # Cumulative cost allocation during training

        # Add initialization costs for each node (the base/coarsest multi-index of each surrogate)
        for node, node_obj in self.graph.nodes.items():
            surr = node_obj['surrogate']
            base_alpha = (0,) * len(surr.truth_alpha)
            base_beta = (0,) * (len(surr.max_refine) - len(surr.truth_alpha))
            base_cost = surr.get_cost(base_alpha, base_beta)
            cost_alloc[node] = dict()
            if base_cost > 0:
                cost_alloc[node][str(base_alpha)] = np.array([1, float(base_cost)])
                cost_cum[0] += float(base_cost)

        # Add cumulative training costs (one train_record entry per past refinement iteration)
        for i in range(idx):
            err_indicator, node, alpha, beta, num_evals, cost = self.build_metrics['train_record'][i]
            if cost_alloc[node].get(str(alpha), None) is None:
                cost_alloc[node][str(alpha)] = np.zeros(2)  # (num model evals, total cpu_time cost)
            cost_alloc[node][str(alpha)] += [round(num_evals), float(cost)]
            cost_cum.append(float(cost))

        # Get summary of total offline costs spent building search candidates (i.e. training overhead)
        offline_alloc = dict()
        for node, node_obj in self.graph.nodes.items():
            surr = node_obj['surrogate']
            offline_alloc[node] = dict()
            for alpha, beta in surr.candidate_set:
                if offline_alloc[node].get(str(alpha), None) is None:
                    offline_alloc[node][str(alpha)] = np.zeros(2)  # (num model evals, total cpu_time cost)
                added_cost = surr.get_cost(alpha, beta)
                # Number of evals is inferred as total added cost / per-eval model cost of this sub-surrogate
                base_cost = surr.get_sub_surrogate(alpha, beta).model_cost
                offline_alloc[node][str(alpha)] += [round(added_cost/base_cost), float(added_cost)]

        return cost_alloc, offline_alloc, np.cumsum(cost_cum)
def get_test_metrics(self, xt: np.ndarray, yt: np.ndarray, qoi_ind: IndicesRV = None,
                     training: bool = True) -> np.ndarray:
    """Get relative L2 error metric over a test set.

    :param xt: `(Nt, x_dim)` random test set of inputs
    :param yt: `(Nt, y_dim)` random test set outputs
    :param qoi_ind: list of indices of QoIs to get metrics for
    :param training: whether to evaluate the surrogate in training or evaluation mode
    :returns: `stats` - `(2, Nqoi)` array → `[num_candidates, rel_L2_error]` for each QoI
    """
    qoi_ind = self._get_qoi_ind(qoi_ind)
    ysurr = self(xt, training=training)[:, qoi_ind]
    yt = yt[:, qoi_ind]

    # Relative L2 error per QoI; silence divide warnings and map any inf results to nan
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_l2_err = np.sqrt(np.mean((yt - ysurr) ** 2, axis=0)) / np.sqrt(np.mean(yt ** 2, axis=0))
        rel_l2_err = np.nan_to_num(rel_l2_err, posinf=np.nan, neginf=np.nan, nan=np.nan)

    # Total number of active + candidate multi-indices over all components (a proxy for training effort)
    num_cands = 0
    for _, node_obj in self.graph.nodes.items():
        comp_surr = node_obj['surrogate']
        num_cands += len(comp_surr.index_set) + len(comp_surr.candidate_set)

    # Assemble the (num_candidates, rel_L2) stats for each QoI and log them
    num_qoi = yt.shape[-1]
    stats = np.zeros((2, num_qoi))
    self.logger.debug(f'{"QoI idx": >10} {"Iteration": >10} {"len(I_k)": >10} {"Relative L2": >15}')
    for i in range(num_qoi):
        stats[:, i] = np.array([num_cands, rel_l2_err[i]])
        self.logger.debug(f'{i: 10d} {self.refine_level: 10d} {num_cands: 10d} {rel_l2_err[i]: 15.5f}')

    return stats
def _get_qoi_ind(self, qoi_ind: IndicesRV) -> list[int]:
    """Normalize a QoI specifier into a list of integer indices into `self.coupling_vars`.

    Accepts `None` (all QoIs), a single item, or a list; items may be integer indices
    or variables/variable ids, which are converted via `self.coupling_vars.index`.
    """
    if qoi_ind is None:
        qoi_ind = list(np.arange(0, len(self.coupling_vars)))
    if not isinstance(qoi_ind, list):
        qoi_ind = [qoi_ind]
    if qoi_ind and isinstance(qoi_ind[0], str | BaseRV):
        qoi_ind = [self.coupling_vars.index(var) for var in qoi_ind]
    return qoi_ind
def refine(self, qoi_ind: IndicesRV = None, num_refine: int = 100, update_bounds: bool = True,
           ppool: Parallel = None) -> tuple:
    """Find and refine the component surrogate with the largest error on system-level QoI.

    The error indicator for a candidate multi-index is its induced change in system QoI (relative L2 over a
    random test set) divided by its cpu cost, i.e. "error reduction per unit work". The single best candidate
    over all components is activated.

    :param qoi_ind: indices of system QoI to focus surrogate refinement on, use all QoI if not specified
    :param num_refine: number of samples of exogenous inputs to compute error indicators on
    :param update_bounds: whether to continuously update coupling variable bounds
    :param ppool: a `Parallel` instance from `joblib` to compute error indicators in parallel, None=sequential
    :returns refine_res: a tuple of `(l2_star, node_star, alpha_star, beta_star, num_evals, cost_star)` --
        the error of the chosen candidate, the chosen component and multi-index, and the number of model
        evaluations and cpu cost incurred by activating it (`node_star=None` if no candidates remain)
    """
    self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}')
    set_loky_pickler('dill')  # Dill can serialize 'self' for parallel workers
    temp_exc = self.executor  # It can't serialize an executor though, so must save this temporarily
    self.set_executor(None)
    qoi_ind = self._get_qoi_ind(qoi_ind)

    # Compute entire integrated-surrogate on a random test set for global system QoI error estimation
    x_exo = self.sample_inputs((num_refine,))
    y_curr = self(x_exo, training=True)
    y_min, y_max = None, None
    if update_bounds:
        y_min = np.min(y_curr, axis=0, keepdims=True)  # (1, ydim)
        y_max = np.max(y_curr, axis=0, keepdims=True)  # (1, ydim)

    # Find the candidate surrogate with the largest error indicator
    error_max, error_indicator = -np.inf, -np.inf
    node_star, alpha_star, beta_star, l2_star, cost_star = None, None, None, -np.inf, 0
    for node, node_obj in self.graph.nodes.items():
        self.logger.info(f"Estimating error for component '{node}'...")
        candidates = node_obj['surrogate'].candidate_set.copy()

        def compute_error(alpha, beta):
            # Helper function for computing error indicators for a given candidate (alpha, beta).
            # Re-evaluates the whole system with only this node's index set temporarily extended.
            index_set = node_obj['surrogate'].index_set.copy()
            index_set.append((alpha, beta))
            y_cand = self(x_exo, training=True, index_set={node: index_set})
            ymin = np.min(y_cand, axis=0, keepdims=True)
            ymax = np.max(y_cand, axis=0, keepdims=True)
            error = y_cand[:, qoi_ind] - y_curr[:, qoi_ind]
            rel_l2 = np.sqrt(np.nanmean(error ** 2, axis=0)) / np.sqrt(np.nanmean(y_cand[:, qoi_ind] ** 2, axis=0))
            rel_l2 = np.nan_to_num(rel_l2, nan=np.nan, posinf=np.nan, neginf=np.nan)
            delta_error = np.nanmax(rel_l2)  # Max relative L2 error over all system QoIs
            delta_work = max(1, node_obj['surrogate'].get_cost(alpha, beta))  # Cpu time (s)

            return ymin, ymax, delta_error, delta_work

        if len(candidates) > 0:
            # NOTE(review): compute_error closes over the loop vars node/node_obj -- safe here because the
            # results are consumed before the next loop iteration, but do not hoist this call out of the loop
            ret = ppool(delayed(compute_error)(alpha, beta) for alpha, beta in candidates) if ppool is not None \
                else [compute_error(alpha, beta) for alpha, beta in candidates]

            for i, (ymin, ymax, d_error, d_work) in enumerate(ret):
                if update_bounds:
                    y_min = np.min(np.concatenate((y_min, ymin), axis=0), axis=0, keepdims=True)
                    y_max = np.max(np.concatenate((y_max, ymax), axis=0), axis=0, keepdims=True)
                alpha, beta = candidates[i]
                error_indicator = d_error / d_work
                self.logger.info(f"Candidate multi-index: {(alpha, beta)}. L2 error: {d_error}. Error indicator: "
                                 f"{error_indicator}.")

                if error_indicator > error_max:
                    error_max = error_indicator
                    node_star, alpha_star, beta_star, l2_star, cost_star = node, alpha, beta, d_error, d_work
        else:
            self.logger.info(f"Component '{node}' has no available candidates left!")

    # Update all coupling variable ranges
    if update_bounds:
        for i in range(y_curr.shape[-1]):
            self._update_coupling_bds(i, (y_min[0, i], y_max[0, i]))

    # Add the chosen multi-index to the chosen component (restore the executor first)
    self.set_executor(temp_exc)
    if node_star is not None:
        self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{node_star}'")
        self.graph.nodes[node_star]['surrogate'].activate_index(alpha_star, beta_star)
        self.refine_level += 1
        num_evals = round(cost_star / self[node_star].get_sub_surrogate(alpha_star, beta_star).model_cost)
    else:
        self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}")
        num_evals = 0

    return l2_star, node_star, alpha_star, beta_star, num_evals, cost_star
def predict(self, x: np.ndarray | float, max_fpi_iter: int = 100, anderson_mem: int = 10, fpi_tol: float = 1e-10,
            use_model: str | tuple | dict = None, model_dir: str | Path = None, verbose: bool = False,
            training: bool = False, index_set: dict[str: IndexSet] = None, qoi_ind: IndicesRV = None,
            ppool=None) -> np.ndarray:
    """Evaluate the system surrogate at exogenous inputs `x`.

    !!! Warning
        You can use this function to predict outputs for your MD system using the full-order models rather than the
        surrogate, by specifying `use_model`. This is convenient because `SystemSurrogate` manages all the
        coupled information flow between models automatically. However, it is *highly* recommended to not use
        the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using
        anything more computationally demanding than the surrogate.

    :param x: `(..., x_dim)` the points to get surrogate predictions for
    :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine
    :param anderson_mem: hyperparameter for tuning the convergence of FPI with anderson acceleration
    :param fpi_tol: tolerance limit for convergence of fixed-point iteration
    :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate,
                      specify a `dict` of the above to assign different model fidelities for diff components
    :param model_dir: directory to save model outputs if `use_model` is specified
    :param verbose: whether to print out iteration progress during execution
    :param training: whether to call the system surrogate in training or evaluation mode, ignored if `use_model`
    :param index_set: `dict(node=[indices])` to override default index set for a node (only useful for parallel)
    :param qoi_ind: list of qoi indices to return, defaults to returning all system `coupling_vars`
    :param ppool: a joblib `Parallel` instance to pass to each component to loop over multi-indices in parallel
    :returns y: `(..., y_dim)` the surrogate approximation of the system QoIs
    """
    # Allocate space for all system outputs (just save all coupling vars)
    x = np.atleast_1d(x)
    ydim = len(self.coupling_vars)
    y = np.zeros(x.shape[:-1] + (ydim,), dtype=x.dtype)
    valid_idx = ~np.isnan(x[..., 0])  # Keep track of valid samples (set to False if FPI fails)
    t1 = 0
    output_dir = None
    qoi_ind = self._get_qoi_ind(qoi_ind)
    is_computed = np.full(ydim, False)  # Per-coupling-var flag for early exit once requested QoIs are done

    # Interpret which model fidelities to use for each component (if specified)
    if use_model is not None:
        if not isinstance(use_model, dict):
            use_model = {node: use_model for node in self.graph.nodes}  # use same for each component
    else:
        use_model = {node: None for node in self.graph.nodes}
    use_model = {node: use_model.get(node, None) for node in self.graph.nodes}

    # Initialize all components
    for node, node_obj in self.graph.nodes.items():
        node_obj['is_computed'] = False

    # Convert system into DAG by grouping strongly-connected-components
    dag = nx.condensation(self.graph)

    # Compute component models in topological order
    for supernode in nx.topological_sort(dag):
        if np.all(is_computed[qoi_ind]):
            break  # Exit early if all qois of interest are computed

        scc = [n for n in dag.nodes[supernode]['members']]

        # Compute single component feedforward output (no FPI needed)
        if len(scc) == 1:
            if verbose:
                self.logger.info(f"Running component '{scc[0]}'...")
                t1 = time.time()

            # Gather inputs: exogenous vars first, then coupling vars from upstream components
            node_obj = self.graph.nodes[scc[0]]
            exo_inputs = x[..., node_obj['exo_in']]
            # for comp_name in node_obj['local_in']:
            #     assert self.graph.nodes[comp_name]['is_computed']
            coupling_inputs = y[..., node_obj['global_in']]
            comp_input = np.concatenate((exo_inputs, coupling_inputs), axis=-1)  # (..., xdim)

            # Compute outputs (only for samples still marked valid)
            indices = index_set.get(scc[0], None) if index_set is not None else None
            if model_dir is not None:
                output_dir = Path(model_dir) / scc[0]
                os.mkdir(output_dir)
            comp_output = node_obj['surrogate'](comp_input[valid_idx, :], use_model=use_model.get(scc[0]),
                                                model_dir=output_dir, training=training, index_set=indices,
                                                ppool=ppool)
            for local_i, global_i in enumerate(node_obj['global_out']):
                y[valid_idx, global_i] = comp_output[..., local_i]
                is_computed[global_i] = True
            node_obj['is_computed'] = True

            if verbose:
                self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s")

        # Handle FPI for SCCs with more than one component
        else:
            # Set the initial guess for all coupling vars (middle of domain)
            coupling_bds = [rv.bounds() for rv in self.coupling_vars]
            x_couple = np.array([(bds[0] + bds[1]) / 2 for bds in coupling_bds]).astype(x.dtype)
            x_couple = np.broadcast_to(x_couple, x.shape[:-1] + x_couple.shape).copy()

            adj_nodes = []
            fpi_idx = set()
            for node in scc:
                for comp_name, local_idx in self.graph.nodes[node]['local_in'].items():
                    # Track the global idx of all coupling vars that need FPI
                    if comp_name in scc:
                        fpi_idx.update([self.graph.nodes[comp_name]['global_out'][idx] for idx in local_idx])

                    # Override coupling vars from components outside the scc (should already be computed)
                    if comp_name not in scc and comp_name not in adj_nodes:
                        # assert self.graph.nodes[comp_name]['is_computed']
                        global_idx = self.graph.nodes[comp_name]['global_out']
                        x_couple[..., global_idx] = y[..., global_idx]
                        adj_nodes.append(comp_name)  # Only need to do this once for each adj component
            x_couple_next = x_couple.copy()
            fpi_idx = sorted(fpi_idx)

            # Main FPI loop
            if verbose:
                self.logger.info(f"Initializing FPI for SCC {scc} ...")
                t1 = time.time()
            k = 0
            residual_hist = None
            x_hist = None
            while True:
                # One Gauss-Seidel-style sweep over all components in the SCC
                for node in scc:
                    # Gather inputs from exogenous and coupling sources
                    node_obj = self.graph.nodes[node]
                    exo_inputs = x[..., node_obj['exo_in']]
                    coupling_inputs = x_couple[..., node_obj['global_in']]
                    comp_input = np.concatenate((exo_inputs, coupling_inputs), axis=-1)  # (..., xdim)

                    # Compute component outputs (just don't do this FPI with the real models, please..)
                    indices = index_set.get(node, None) if index_set is not None else None
                    comp_output = node_obj['surrogate'](comp_input[valid_idx, :], use_model=use_model.get(node),
                                                        model_dir=None, training=training, index_set=indices,
                                                        ppool=ppool)
                    global_idx = node_obj['global_out']
                    for local_i, global_i in enumerate(global_idx):
                        x_couple_next[valid_idx, global_i] = comp_output[..., local_i]
                        # Can't splice valid_idx with global_idx for some reason, have to loop over global_idx here

                # Compute residual and check end conditions
                residual = np.expand_dims(x_couple_next[..., fpi_idx] - x_couple[..., fpi_idx], axis=-1)
                max_error = np.max(np.abs(residual[valid_idx, :, :]))
                if verbose:
                    self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s')
                if max_error <= fpi_tol:
                    if verbose:
                        self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with {max_error} < tol '
                                         f'{fpi_tol}. Final time: {time.time() - t1} s')
                    break
                if k >= max_fpi_iter:
                    # Non-converged samples are nan-ed out and excluded from further computation downstream
                    self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: '
                                        f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.')
                    converged_idx = np.max(np.abs(residual), axis=(-1, -2)) <= fpi_tol
                    for idx in fpi_idx:
                        y[~converged_idx, idx] = np.nan
                    valid_idx = np.logical_and(valid_idx, converged_idx)
                    break

                # Keep track of residual and x_couple histories
                if k == 0:
                    residual_hist = residual.copy()  # (..., xdim, 1)
                    x_hist = np.expand_dims(x_couple_next[..., fpi_idx], axis=-1)  # (..., xdim, 1)
                    x_couple[:] = x_couple_next[:]
                    k += 1
                    continue  # skip anderson accel on first iteration

                # Iterate with anderson acceleration (only iterate on samples that are not yet converged).
                # The next iterate is a convex combination of the last mk+1 iterates, with weights chosen by a
                # constrained least-squares fit on the residual history (see self._constrained_lls).
                converged_idx = np.max(np.abs(residual), axis=(-1, -2)) <= fpi_tol
                curr_idx = np.logical_and(valid_idx, ~converged_idx)
                residual_hist = np.concatenate((residual_hist, residual), axis=-1)
                x_hist = np.concatenate((x_hist, np.expand_dims(x_couple_next[..., fpi_idx], axis=-1)), axis=-1)
                mk = min(anderson_mem, k)
                Fk = residual_hist[curr_idx, :, k-mk:]  # (..., xdim, mk+1)
                C = np.ones(Fk.shape[:-2] + (1, mk + 1))
                b = np.zeros(Fk.shape[:-2] + (len(fpi_idx), 1))
                d = np.ones(Fk.shape[:-2] + (1, 1))
                alpha = np.expand_dims(self._constrained_lls(Fk, b, C, d), axis=-3)  # (..., 1, mk+1, 1)
                x_new = np.squeeze(x_hist[curr_idx, :, np.newaxis, -(mk+1):] @ alpha, axis=(-1, -2))
                for local_i, global_i in enumerate(fpi_idx):
                    x_couple[curr_idx, global_i] = x_new[..., local_i]
                k += 1

            # Save outputs of each component in SCC after convergence of FPI
            for node in scc:
                global_idx = self.graph.nodes[node]['global_out']
                for global_i in global_idx:
                    y[valid_idx, global_i] = x_couple_next[valid_idx, global_i]
                    is_computed[global_i] = True
                self.graph.nodes[node]['is_computed'] = True

    # Return all component outputs (..., Nqoi); samples that didn't converge during FPI are left as np.nan
    return y[..., qoi_ind]
956 def __call__(self, *args, **kwargs):
957 """Convenience wrapper to allow calling as `ret = SystemSurrogate(x)`."""
958 return self.predict(*args, **kwargs)
def _estimate_coupling_bds(self, num_est: int, anderson_mem: int = 10, fpi_tol: float = 1e-10,
                           max_fpi_iter: int = 100):
    """Estimate and set the coupling variable bounds by running the best available models on random inputs.

    :param num_est: the number of samples of exogenous inputs to use
    :param anderson_mem: FPI hyperparameter (default is usually good)
    :param fpi_tol: floating point tolerance for FPI convergence
    :param max_fpi_iter: maximum number of FPI iterations
    """
    self._print_title_str('Estimating coupling variable bounds')
    x_exo = self.sample_inputs((num_est,))
    y_est = self(x_exo, use_model='best', verbose=True, anderson_mem=anderson_mem, fpi_tol=fpi_tol,
                 max_fpi_iter=max_fpi_iter)
    # Initialize each coupling variable's bounds from the (nan-aware) observed output range
    for i, _ in enumerate(self.coupling_vars):
        self._update_coupling_bds(i, (np.nanmin(y_est[:, i]), np.nanmax(y_est[:, i])), init=True)
978 def _update_coupling_bds(self, global_idx: int, bds: tuple, init: bool = False, buffer: float = 0.05):
979 """Update coupling variable bounds.
981 :param global_idx: global index of coupling variable to update
982 :param bds: new bounds to update the current bounds with
983 :param init: whether to set new bounds or update existing (default)
984 :param buffer: fraction of domain length to buffer upper/lower bounds
985 """
986 offset = buffer * (bds[1] - bds[0])
987 offset_bds = (bds[0] - offset, bds[1] + offset)
988 coupling_bds = [rv.bounds() for rv in self.coupling_vars]
989 new_bds = offset_bds if init else (min(coupling_bds[global_idx][0], offset_bds[0]),
990 max(coupling_bds[global_idx][1], offset_bds[1]))
991 self.coupling_vars[global_idx].update_bounds(*new_bds)
993 # Iterate over all components and update internal coupling variable bounds
994 for node_name, node_obj in self.graph.nodes.items():
995 if global_idx in node_obj['global_in']:
996 # Get the local index for this coupling variable within each component's inputs
997 local_idx = len(node_obj['exo_in']) + node_obj['global_in'].index(global_idx)
998 node_obj['surrogate'].update_input_bds(local_idx, new_bds)
1000 def sample_inputs(self, size: tuple | int, comp: str = 'System', use_pdf: bool = False,
1001 nominal: dict[str: float] = None, constants: set[str] = None) -> np.ndarray:
1002 """Return samples of the inputs according to provided options.
1004 :param size: tuple or integer specifying shape or number of samples to obtain
1005 :param comp: which component to sample inputs for (defaults to full system exogenous inputs)
1006 :param use_pdf: whether to sample from each variable's pdf, defaults to random samples over input domain instead
1007 :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty, also can use
1008 to specify constant values for a variable listed in `constants`
1009 :param constants: set of param types to hold constant while sampling (i.e. calibration, design, etc.),
1010 can also put a `var_id` string in here to specify a single variable to hold constant
1011 :returns x: `(*size, x_dim)` samples of the inputs for the given component/system
1012 """
1013 size = (size, ) if isinstance(size, int) else size
1014 if nominal is None:
1015 nominal = dict()
1016 if constants is None:
1017 constants = set()
1018 x_vars = self.exo_vars if comp == 'System' else self[comp].x_vars
1019 x = np.empty((*size, len(x_vars)))
1020 for i, var in enumerate(x_vars):
1021 # Set a constant value for this variable
1022 if var.param_type in constants or var in constants:
1023 x[..., i] = nominal.get(var, var.nominal) # Defaults to variable's nominal value if not specified
1025 # Sample from this variable's pdf or randomly within its domain bounds (reject if outside bounds)
1026 else:
1027 lb, ub = var.bounds()
1028 x_sample = var.sample(size, nominal=nominal.get(var, None)) if use_pdf \
1029 else var.sample_domain(size)
1030 good_idx = (x_sample < ub) & (x_sample > lb)
1031 num_reject = np.sum(~good_idx)
1033 while num_reject > 0:
1034 new_sample = var.sample((num_reject,), nominal=nominal.get(var, None)) if use_pdf \
1035 else var.sample_domain((num_reject,))
1036 x_sample[~good_idx] = new_sample
1037 good_idx = (x_sample < ub) & (x_sample > lb)
1038 num_reject = np.sum(~good_idx)
1040 x[..., i] = x_sample
1042 return x
def plot_slice(self, slice_idx: IndicesRV = None, qoi_idx: IndicesRV = None, show_surr: bool = True,
               show_model: list = None, model_dir: str | Path = None, N: int = 50, nominal: dict[str: float] = None,
               random_walk: bool = False, from_file: str | Path = None):
    """Helper function to plot 1d slices of the surrogate and/or model(s) over the inputs.

    :param slice_idx: list of exogenous input variables or indices to take 1d slices of
    :param qoi_idx: list of model output variables or indices to plot 1d slices of
    :param show_surr: whether to show the surrogate prediction
    :param show_model: also plot model predictions, list() of ['best', 'worst', tuple(alpha), etc.]
    :param model_dir: base directory to save model outputs (if specified)
    :param N: the number of points to take in the 1d slice
    :param nominal: `dict` of `str(var)->nominal` to use as constant value for all non-sliced variables
    :param random_walk: whether to slice in a random d-dimensional direction or hold all params const while slicing
    :param from_file: path to a .pkl file to load a saved slice from disk
    :returns: `fig, ax` with `num_slice` by `num_qoi` subplots
    """
    # Manage loading important quantities from file (if provided)
    xs, ys_model, ys_surr = None, None, None
    if from_file is not None:
        with open(Path(from_file), 'rb') as fd:
            slice_data = pickle.load(fd)
            slice_idx = slice_data['slice_idx']  # Must use same input slices as save file
            show_model = slice_data['show_model']  # Must use same model data as save file
            qoi_idx = slice_data['qoi_idx'] if qoi_idx is None else qoi_idx
            xs = slice_data['xs']
        model_dir = None  # Don't run or save any models if loading from file

    # Set default values (take up to the first 3 slices by default)
    rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))  # unique id for save files
    if model_dir is not None:
        os.mkdir(Path(model_dir) / f'sweep_{rand_id}')
    if nominal is None:
        nominal = dict()
    slice_idx = list(np.arange(0, min(3, len(self.exo_vars)))) if slice_idx is None else slice_idx
    qoi_idx = list(np.arange(0, min(3, len(self.coupling_vars)))) if qoi_idx is None else qoi_idx
    # Allow variables/ids in place of integer indices
    if isinstance(slice_idx[0], str | BaseRV):
        slice_idx = [self.exo_vars.index(var) for var in slice_idx]
    if isinstance(qoi_idx[0], str | BaseRV):
        qoi_idx = [self.coupling_vars.index(var) for var in qoi_idx]

    exo_bds = [var.bounds() for var in self.exo_vars]
    xlabels = [self.exo_vars[idx].to_tex(units=True) for idx in slice_idx]
    ylabels = [self.coupling_vars[idx].to_tex(units=True) for idx in qoi_idx]

    # Construct slice model inputs (if not provided); xs is (N, num_slices, x_dim)
    if xs is None:
        xs = np.zeros((N, len(slice_idx), len(self.exo_vars)))
        for i in range(len(slice_idx)):
            if random_walk:
                # Make a random straight-line walk across d-cube
                r0 = np.squeeze(self.sample_inputs((1,), use_pdf=False), axis=0)
                r0[slice_idx[i]] = exo_bds[slice_idx[i]][0]  # Start slice at this lower bound
                rf = np.squeeze(self.sample_inputs((1,), use_pdf=False), axis=0)
                rf[slice_idx[i]] = exo_bds[slice_idx[i]][1]  # Slice up to this upper bound
                xs[0, i, :] = r0
                for k in range(1, N):
                    xs[k, i, :] = xs[k-1, i, :] + (rf-r0)/(N-1)
            else:
                # Otherwise, only slice one variable and hold all others at their nominal values
                for j in range(len(self.exo_vars)):
                    if j == slice_idx[i]:
                        xs[:, i, j] = np.linspace(exo_bds[slice_idx[i]][0], exo_bds[slice_idx[i]][1], N)
                    else:
                        xs[:, i, j] = nominal.get(self.exo_vars[j], self.exo_vars[j].nominal)

    # Walk through each model that is requested by show_model
    if show_model is not None:
        if from_file is not None:
            ys_model = slice_data['ys_model']
        else:
            ys_model = list()
            for model in show_model:
                output_dir = None
                if model_dir is not None:
                    # Sanitize the model spec into a usable directory name
                    output_dir = (Path(model_dir) / f'sweep_{rand_id}' /
                                  str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", ''))
                    os.mkdir(output_dir)
                ys_model.append(self(xs, use_model=model, model_dir=output_dir))
    if show_surr:
        ys_surr = self(xs) if from_file is None else slice_data['ys_surr']

    # Make len(qoi) by len(inputs) grid of subplots
    fig, axs = plt.subplots(len(qoi_idx), len(slice_idx), sharex='col', sharey='row')
    for i in range(len(qoi_idx)):
        for j in range(len(slice_idx)):
            # plt.subplots squeezes singleton dimensions, so index axs accordingly
            if len(qoi_idx) == 1:
                ax = axs if len(slice_idx) == 1 else axs[j]
            elif len(slice_idx) == 1:
                ax = axs if len(qoi_idx) == 1 else axs[i]
            else:
                ax = axs[i, j]
            x = xs[:, j, slice_idx[j]]
            if show_model is not None:
                c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else (
                    plt.get_cmap('jet')(np.linspace(0, 1, len(show_model))))
                for k in range(len(show_model)):
                    model_str = (str(show_model[k]).replace('{', '').replace('}', '')
                                 .replace(':', '=').replace("'", ''))
                    model_ret = ys_model[k]
                    y_model = model_ret[:, j, qoi_idx[i]]
                    label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model',
                             'worst': 'Low-fidelity'}.get(model_str, model_str)
                    ax.plot(x, y_model, ls='-', c=c[k, :], label=label)
            if show_surr:
                y_surr = ys_surr[:, j, qoi_idx[i]]
                ax.plot(x, y_surr, '--r', label='Surrogate')
            ylabel = ylabels[i] if j == 0 else ''
            xlabel = xlabels[j] if i == len(qoi_idx) - 1 else ''
            legend = (i == 0 and j == len(slice_idx) - 1)
            ax_default(ax, xlabel, ylabel, legend=legend)
    fig.set_size_inches(3 * len(slice_idx), 3 * len(qoi_idx))
    fig.tight_layout()

    # Save results (unless we were already loading from a save file)
    if from_file is None and self.root_dir is not None:
        fname = f's{",".join([str(i) for i in slice_idx])}_q{",".join([str(i) for i in qoi_idx])}'
        fname = f'sweep_rand{rand_id}_' + fname if random_walk else f'sweep_nom{rand_id}_' + fname
        fdir = Path(self.root_dir) if model_dir is None else Path(model_dir) / f'sweep_{rand_id}'
        fig.savefig(fdir / f'{fname}.png', dpi=300, format='png')
        save_dict = {'slice_idx': slice_idx, 'qoi_idx': qoi_idx, 'show_model': show_model, 'show_surr': show_surr,
                     'nominal': nominal, 'random_walk': random_walk, 'xs': xs, 'ys_model': ys_model,
                     'ys_surr': ys_surr}
        with open(fdir / f'{fname}.pkl', 'wb') as fd:
            pickle.dump(save_dict, fd)

    return fig, axs
def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02):
    """Plot bar charts showing cost allocation during training.

    !!! Warning "Beta feature"
        This has pretty good default settings, but it might look terrible for your use. Mostly provided here as
        a template for making cost allocation bar charts. Please feel free to copy and edit in your own code.

    :param cmap: the colormap string identifier for `plt`
    :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text
    :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow;
                            below this amount, the bar is too skinny and won't print any text
    :returns: `fig, ax`, Figure and Axes objects
    """
    # Get total cost (including offline overhead); cost entries are (num model evals, total cpu_time cost)
    train_alloc, offline_alloc, cost_cum = self.get_allocation()
    total_cost = cost_cum[-1]
    for node, alpha_dict in offline_alloc.items():
        for alpha, cost in alpha_dict.items():
            total_cost += cost[1]

    # Remove nodes with cost=0 from alloc dicts (i.e. analytical models)
    remove_nodes = []
    for node, alpha_dict in train_alloc.items():
        if len(alpha_dict) == 0:
            remove_nodes.append(node)
    for node in remove_nodes:
        del train_alloc[node]
        del offline_alloc[node]

    # Bar chart showing cost allocation breakdown for MF system at end
    fig, axs = plt.subplots(1, 2, sharey='row')
    width = 0.7
    x = np.arange(len(train_alloc))
    xlabels = list(train_alloc.keys())
    cmap = plt.get_cmap(cmap)
    for k in range(2):
        # Left axis: online training cost; right axis: offline (search candidate) overhead
        ax = axs[k]
        alloc = train_alloc if k == 0 else offline_alloc
        ax.set_title('Online training' if k == 0 else 'Overhead')
        for j, (node, alpha_dict) in enumerate(alloc.items()):
            # One stacked bar per component, one segment per model fidelity (alpha), largest fraction first
            bottom = 0
            c_intervals = np.linspace(0, 1, len(alpha_dict))
            bars = [(alpha, cost, cost[1] / total_cost) for alpha, cost in alpha_dict.items()]
            bars = sorted(bars, key=lambda ele: ele[2], reverse=True)
            for i, (alpha, cost, frac) in enumerate(bars):
                p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1,
                           edgecolor=[0, 0, 0], bottom=bottom)
                bottom += frac
                if frac > text_bar_width:
                    ax.bar_label(p, labels=[f'{alpha}, {round(cost[0])}'], label_type='center')
                elif frac > arrow_bar_width:
                    xy = (x[j] + width / 2, bottom - frac / 2)  # Label smaller bars with a text off to the side
                    ax.annotate(f'{alpha}, {round(cost[0])}', xy, xytext=(xy[0] + 0.2, xy[1]),
                                arrowprops={'arrowstyle': '->', 'linewidth': 1})
                else:
                    pass  # Don't label really small bars
        ax_default(ax, '', "Fraction of total cost" if k == 0 else '', legend=False)
        ax.set_xticks(x, xlabels)
        ax.set_xlim(left=-1, right=x[-1] + 1)
    fig.set_size_inches(8, 4)
    fig.tight_layout()

    if self.root_dir is not None:
        fig.savefig(Path(self.root_dir) / 'mf_allocation.png', dpi=300, format='png')

    return fig, axs
def get_component(self, comp_name: str) -> ComponentSurrogate:
    """Return the `ComponentSurrogate` object for this component.

    :param comp_name: name of the component to return ('System' returns this `SystemSurrogate` itself)
    :returns: the `ComponentSurrogate` object
    """
    if comp_name == 'System':
        return self
    return self.graph.nodes[comp_name]['surrogate']
1247 def _print_title_str(self, title_str: str):
1248 """Log an important message."""
1249 self.logger.info('-' * int(len(title_str)/2) + title_str + '-' * int(len(title_str)/2))
1251 def _save_progress(self, filename: str):
1252 """Internal helper to save surrogate training progress (only if `root_dir` exists)
1254 :param filename: the name of the save file to `root/sys/filename.pkl`
1255 """
1256 if self.root_dir is not None:
1257 self.save_to_file(filename)
def save_to_file(self, filename: str, save_dir: str | Path = None):
    """Save the `SystemSurrogate` object to a `.pkl` file with `dill`.

    :param filename: filename of the `.pkl` file to save to
    :param save_dir: overrides existing surrogate root directory if provided, otherwise defaults to '.'
    """
    if save_dir is None:
        # Default to root/sys, falling back to the cwd when no (valid) root directory exists
        save_dir = '.' if self.root_dir is None else str(Path(self.root_dir) / 'sys')
        if not Path(save_dir).is_dir():
            save_dir = '.'

    saved_executor = self.executor  # Executors are not picklable, so detach before dumping
    self.set_executor(None)
    with open(Path(save_dir) / filename, 'wb') as dill_file:
        dill.dump(self, dill_file)
    self.set_executor(saved_executor)
    self.logger.info(f'SystemSurrogate saved to {(Path(save_dir) / filename).resolve()}')
1277 def _set_output_dir(self, set_dict: dict[str: str | Path]):
1278 """Set the output directory for each component in `set_dict`.
1280 :param set_dict: a `dict` of component names (`str`) to their new output directories
1281 """
1282 for node, node_obj in self.graph.nodes.items():
1283 if node in set_dict:
1284 node_obj['surrogate']._set_output_dir(set_dict.get(node))
def set_root_directory(self, root_dir: str | Path = None, stdout: bool = True, logger_name: str = None):
    """Set the root to a new directory, for example if you move to a new filesystem.

    Creates `root/sys` and `root/components` if missing, reuses an existing `*.log` file in the root (or
    creates a timestamped one), resets the logger, and points each file-saving component at its new
    `root/components/<node>` output directory.

    :param root_dir: new root directory, don't save build products if None
    :param stdout: whether to connect the logger to console (default)
    :param logger_name: the logger name to use, defaults to class name
    """
    if root_dir is None:
        self.root_dir = None
        self.log_file = None
    else:
        self.root_dir = str(Path(root_dir).resolve())
        log_file = None
        if not (Path(self.root_dir) / 'sys').is_dir():
            os.mkdir(Path(self.root_dir) / 'sys')
        if not (Path(self.root_dir) / 'components').is_dir():
            os.mkdir(Path(self.root_dir) / 'components')
        # Reuse the first existing log file found in the new root (if any)
        for f in os.listdir(self.root_dir):
            if f.endswith('.log'):
                log_file = str((Path(self.root_dir) / f).resolve())
                break
        if log_file is None:
            # Timestamped filename, e.g. '2024-08-29T21.38.00UTC_sys.log' (':' is not filename-safe)
            fname = (datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.') +
                     'UTC_sys.log')
            log_file = str((Path(self.root_dir) / fname).resolve())
        self.log_file = log_file

    self.set_logger(logger_name, stdout=stdout)

    # Update model output directories
    for node, node_obj in self.graph.nodes.items():
        surr = node_obj['surrogate']
        if self.root_dir is not None and surr.save_enabled():
            output_dir = str((Path(self.root_dir) / 'components' / node).resolve())
            if not Path(output_dir).is_dir():
                os.mkdir(output_dir)
            surr._set_output_dir(output_dir)
    def __getitem__(self, component: str) -> ComponentSurrogate:
        """Convenience method to get the `ComponentSurrogate` object from the `SystemSurrogate`.

        Allows subscript access, i.e. `system['component_name']` is shorthand for
        `system.get_component('component_name')`.

        :param component: the name of the component to get
        :returns: the `ComponentSurrogate` object
        """
        return self.get_component(component)
1332 def __repr__(self):
1333 s = f'----SystemSurrogate----\nAdjacency: \n{nx.to_numpy_array(self.graph, dtype=int)}\n' \
1334 f'Exogenous inputs: {[str(var) for var in self.exo_vars]}\n'
1335 for node, node_obj in self.graph.nodes.items():
1336 s += f'Component: {node}\n{node_obj["surrogate"]}'
1337 return s
    def __str__(self):
        """Return the same summary string as `__repr__`."""
        return self.__repr__()
1342 def set_executor(self, executor: Executor | None):
1343 """Set a new `concurrent.futures.Executor` object for parallel calls.
1345 :param executor: the new `Executor` object
1346 """
1347 self.executor = executor
1348 for node, node_obj in self.graph.nodes.items():
1349 node_obj['surrogate'].executor = executor
1351 def set_logger(self, name: str = None, log_file: str | Path = None, stdout: bool = True):
1352 """Set a new `logging.Logger` object with the given unique `name`.
1354 :param name: the name of the new logger object
1355 :param stdout: whether to connect the logger to console (default)
1356 :param log_file: log file (if provided)
1357 """
1358 if log_file is None:
1359 log_file = self.log_file
1360 if name is None:
1361 name = self.__class__.__name__
1362 self.log_file = log_file
1363 self.logger = get_logger(name, log_file=log_file, stdout=stdout)
1365 for node, node_obj in self.graph.nodes.items():
1366 surr = node_obj['surrogate']
1367 surr.logger = self.logger.getChild('Component')
1368 surr.log_file = self.log_file
    @staticmethod
    def load_from_file(filename: str | Path, root_dir: str | Path = None, executor: Executor = None,
                       stdout: bool = True, logger_name: str = None):
        """Load a `SystemSurrogate` object from file.

        :param filename: the .pkl file to load
        :param root_dir: if provided, an `amisc_timestamp` directory will be created at `root_dir`. Ignored if the
                         `.pkl` file already resides in an `amisc`-like directory. If none, then the surrogate object
                         is only loaded into memory and is not given a file directory for any save artifacts.
        :param executor: a `concurrent.futures.Executor` object to set; clears it if None
        :param stdout: whether to log to console
        :param logger_name: the name of the logger to use, if None then uses class name by default
        :returns: the `SystemSurrogate` object
        """
        # SECURITY NOTE(review): `dill.load` (like `pickle`) can execute arbitrary code during
        # deserialization -- only load .pkl files from trusted sources.
        with open(Path(filename), 'rb') as dill_file:
            sys_surr = dill.load(dill_file)
            sys_surr.set_executor(executor)  # executors are not picklable, so restore (or clear) here
            sys_surr.x_vars = sys_surr.exo_vars  # backwards compatible v0.2.0

        # Resolve the root directory: reuse an existing amisc_* layout if the file already lives in one,
        # otherwise create a fresh timestamped directory under the provided root_dir.
        copy_flag = False  # whether to copy the .pkl into the newly created root directory
        if root_dir is None:
            parts = Path(filename).resolve().parts
            if len(parts) > 2:
                if parts[-3].startswith('amisc_'):
                    root_dir = Path(filename).parent.parent  # Assumes amisc_root/sys/filename.pkl default structure
        else:
            if not Path(root_dir).is_dir():
                root_dir = '.'
            # Directory name is amisc_<UTC ISO timestamp> (':' replaced for filesystem compatibility)
            timestamp = datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
            root_dir = Path(root_dir) / ('amisc_' + timestamp)
            os.mkdir(root_dir)
            copy_flag = True

        # root_dir may still be None here, in which case no build products are saved
        sys_surr.set_root_directory(root_dir, stdout=stdout, logger_name=logger_name)

        if copy_flag:
            # Keep a copy of the loaded .pkl alongside the new directory's other save artifacts
            shutil.copyfile(Path(filename), root_dir / 'sys' / Path(filename).name)

        return sys_surr
1410 @staticmethod
1411 def _constrained_lls(A: np.ndarray, b: np.ndarray, C: np.ndarray, d: np.ndarray) -> np.ndarray:
1412 """Minimize $||Ax-b||_2$, subject to $Cx=d$, i.e. constrained linear least squares.
1414 !!! Note
1415 See http://www.seas.ucla.edu/~vandenbe/133A/lectures/cls.pdf for more detail.
1417 :param A: `(..., M, N)`, vandermonde matrix
1418 :param b: `(..., M, 1)`, data
1419 :param C: `(..., P, N)`, constraint operator
1420 :param d: `(..., P, 1)`, constraint condition
1421 :returns: `(..., N, 1)`, the solution parameter vector `x`
1422 """
1423 M = A.shape[-2]
1424 dims = len(A.shape[:-2])
1425 T_axes = tuple(np.arange(0, dims)) + (-1, -2)
1426 Q, R = np.linalg.qr(np.concatenate((A, C), axis=-2))
1427 Q1 = Q[..., :M, :]
1428 Q2 = Q[..., M:, :]
1429 Q1_T = np.transpose(Q1, axes=T_axes)
1430 Q2_T = np.transpose(Q2, axes=T_axes)
1431 Qtilde, Rtilde = np.linalg.qr(Q2_T)
1432 Qtilde_T = np.transpose(Qtilde, axes=T_axes)
1433 Rtilde_T_inv = np.linalg.pinv(np.transpose(Rtilde, axes=T_axes))
1434 w = np.linalg.pinv(Rtilde) @ (Qtilde_T @ Q1_T @ b - Rtilde_T_inv @ d)
1436 return np.linalg.pinv(R) @ (Q1_T @ b - Q2_T @ w)