Coverage for src/amisc/system.py: 87%

1029 statements  

coverage.py v7.6.1, created at 2025-03-10 15:12 +0000

"""The `System` object is a framework for multidisciplinary models. It manages multiple single-discipline component
models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates.

Features:

- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections
- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver with Anderson acceleration
- Top-level interface for training and using surrogates of each component model
- Adaptive experimental design for choosing training data efficiently
- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates
- Detailed logging and traceback information
- Supports parallel or vectorized execution of component models
- Abstract and flexible interfacing with component models
- Easy serialization and deserialization to/from YAML files
- Supports approximating field quantities via compression

Includes:

- `TrainHistory` — a history of training iterations for the system surrogate
- `System` — the top-level object for managing multidisciplinary models
"""
# ruff: noqa: E702
from __future__ import annotations

import copy
import datetime
import functools
import logging
import os
import pickle
import random
import string
import time
import warnings
from collections import ChainMap, UserList, deque
from concurrent.futures import ALL_COMPLETED, Executor, wait
from datetime import timezone
from pathlib import Path
from typing import Annotated, Callable, ClassVar, Literal, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import yaml
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from amisc.component import Component, IndexSet, MiscTree
from amisc.serialize import Serializable, _builtin
from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex, TrainIteration
from amisc.utils import (
    _combine_latent_arrays,
    constrained_lls,
    format_inputs,
    format_outputs,
    get_logger,
    relative_error,
    to_model_dataset,
    to_surrogate_dataset,
)
from amisc.variable import VariableList

__all__ = ['TrainHistory', 'System']


class TrainHistory(UserList, Serializable):
    """Stores the training history of a system surrogate as a list of `TrainIteration` objects.
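
    !!! Example
        A hedged sketch of the serialize/deserialize round trip (assumes `MultiIndex` accepts its own
        string form, which this round trip relies on):
        ```python
        history = TrainHistory([{'component': 'f1', 'alpha': (0,), 'beta': (0, 1),
                                 'num_evals': 10, 'added_cost': 1.0, 'added_error': 0.1}])
        as_dicts = history.serialize()                 # 'alpha'/'beta' stringified for YAML
        restored = TrainHistory.deserialize(as_dicts)  # round trip back to a TrainHistory
        ```
    """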

    def __init__(self, data: list = None):
        data = data or []
        super().__init__(self._validate_data(data))

    def serialize(self) -> list[dict]:
        """Return a list of each result in the history serialized to a `dict`."""
        ret_list = []
        for res in self:
            new_res = res.copy()
            new_res['alpha'] = str(res['alpha'])
            new_res['beta'] = str(res['beta'])
            ret_list.append(new_res)
        return ret_list

    @classmethod
    def deserialize(cls, serialized_data: list[dict]) -> TrainHistory:
        """Deserialize a list of `dict` objects into a `TrainHistory` object."""
        return TrainHistory(serialized_data)

    @classmethod
    def _validate_data(cls, data: list[dict]) -> list[TrainIteration]:
        return [cls._validate_item(item) for item in data]

    @classmethod
    def _validate_item(cls, item: dict):
        """Format a `TrainIteration` `dict` item before appending to the history."""
        item.setdefault('test_error', None)
        item.setdefault('overhead_s', 0.0)
        item.setdefault('model_s', 0.0)
        item['alpha'] = MultiIndex(item['alpha'])
        item['beta'] = MultiIndex(item['beta'])
        item['num_evals'] = int(item['num_evals'])
        item['added_cost'] = float(item['added_cost'])
        item['added_error'] = float(item['added_error'])
        item['overhead_s'] = float(item['overhead_s'])
        item['model_s'] = float(item['model_s'])
        return item

    def append(self, item: dict):
        super().append(self._validate_item(item))

    def __add__(self, other):
        other_list = other.data if isinstance(other, TrainHistory) else other
        return TrainHistory(data=self.data + other_list)

    def extend(self, items):
        super().extend([self._validate_item(item) for item in items])

    def insert(self, index, item):
        super().insert(index, self._validate_item(item))

    def __setitem__(self, key, value):
        super().__setitem__(key, self._validate_item(value))

    def __eq__(self, other):
        """Two `TrainHistory` objects are equal if they have the same length and all items are equal, excluding NaNs."""
        if not isinstance(other, TrainHistory):
            return False
        if len(self) != len(other):
            return False
        for item_self, item_other in zip(self, other):
            for key in item_self:
                if key in item_other:
                    val_self = item_self[key]
                    val_other = item_other[key]
                    if isinstance(val_self, float) and isinstance(val_other, float):
                        if not (np.isnan(val_self) and np.isnan(val_other)) and val_self != val_other:
                            return False
                    elif isinstance(val_self, dict) and isinstance(val_other, dict):
                        for v, err in val_self.items():
                            if v in val_other:
                                err_other = val_other[v]
                                if not (np.isnan(err) and np.isnan(err_other)) and err != err_other:
                                    return False
                            else:
                                return False
                    elif val_self != val_other:
                        return False
                else:
                    return False
        return True


class _Converged:
    """Helper class to track which samples have converged during `System.predict()`."""
    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.valid_idx = np.full(num_samples, True)  # All samples are valid by default
        self.converged_idx = np.full(num_samples, False)  # For FPI convergence

    def reset_convergence(self):
        self.converged_idx = np.full(self.num_samples, False)

    @property
    def curr_idx(self):
        return np.logical_and(self.valid_idx, ~self.converged_idx)


def _merge_shapes(target_shape, arr):
    """Helper to merge an array into the target shape by aligning dimensions and broadcasting."""
    shape1, shape2 = target_shape, arr.shape
    if len(shape2) > len(shape1):
        shape1, shape2 = shape2, shape1
    # Build an intermediate shape: where the two shapes overlap, prefer the non-singleton size;
    # pad with trailing singleton dimensions so `arr` can be broadcast to `target_shape`
    result = []
    for i in range(len(shape1)):
        if i < len(shape2):
            if shape1[i] == 1:
                result.append(shape2[i])
            elif shape2[i] == 1:
                result.append(shape1[i])
            else:
                result.append(shape1[i])
        else:
            result.append(1)
    arr = arr.reshape(tuple(result))
    return np.broadcast_to(arr, target_shape).copy()


class System(BaseModel, Serializable):
    """
    Multidisciplinary (MD) surrogate framework top-level class. Construct a `System` from a list of
    `Component` models.

    !!! Example
        ```python
        def f1(x):
            y = x ** 2
            return y

        def f2(y):
            z = y + 1
            return z

        system = System(f1, f2)
        ```

    A `System` object can be saved/loaded from `.yml` files using the `!System` yaml tag.
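
    !!! Example "Save to YAML"
        A hedged sketch (`save_to_file` is used throughout this module; saving defaults to the root
        directory, which is an assumption here):
        ```python
        system = System(f1, f2, root_dir='.')
        system.save_to_file('my_system.yml')  # writes a !System-tagged YAML file
        ```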

    :ivar name: the name of the system
    :ivar components: list of `Component` models that make up the MD system
    :ivar train_history: history of training iterations for the system surrogate (filled in during training)

    :ivar _root_dir: root directory where all surrogate build products are saved to file
    :ivar _logger: logger object for the system
    """
    yaml_tag: ClassVar[str] = '!System'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True,
                              extra='allow')

    name: Annotated[str, Field(default_factory=lambda: "System_" + "".join(random.choices(string.digits, k=3)))]
    components: Callable | Component | list[Callable | Component]
    train_history: list[dict] | TrainHistory = TrainHistory()
    amisc_version: str = None

    _root_dir: Optional[str]
    _logger: Optional[logging.Logger] = None

    def __init__(self, /, *args, components=None, root_dir=None, **kwargs):
        """Construct a `System` object from a list of `Component` models in `*args` or `components`. If
        a `root_dir` is provided, then a new directory will be created under `root_dir` with the name
        `amisc_{timestamp}`. This directory will be used to save all build products and log files.

        :param components: list of `Component` models that make up the MD system
        :param root_dir: root directory where all surrogate build products are saved to file (optional)
        """
        if components is None:
            components = []
            for a in args:
                if isinstance(a, Component) or callable(a):
                    components.append(a)
                else:
                    try:
                        components.extend(a)
                    except TypeError as e:
                        raise ValueError(f"Invalid component: {a}") from e

        import amisc
        amisc_version = kwargs.pop('amisc_version', amisc.__version__)
        super().__init__(components=components, amisc_version=amisc_version, **kwargs)
        self.root_dir = root_dir

    def __repr__(self):
        s = f'---- {self.name} ----\n'
        s += f'amisc version: {self.amisc_version}\n'
        s += f'Refinement level: {self.refine_level}\n'
        s += f'Components: {", ".join([comp.name for comp in self.components])}\n'
        s += f'Inputs: {", ".join([var.name for var in self.inputs()])}\n'
        s += f'Outputs: {", ".join([var.name for var in self.outputs()])}'
        return s

    def __str__(self):
        return self.__repr__()

    @field_validator('components')
    @classmethod
    def _validate_components(cls, comps) -> list[Component]:
        if not isinstance(comps, list):
            comps = [comps]
        comps = [Component.deserialize(c) for c in comps]

        # Merge all variables to avoid name conflicts
        merged_vars = VariableList.merge(*[comp.inputs for comp in comps], *[comp.outputs for comp in comps])
        for comp in comps:
            comp.inputs.update({var.name: var for var in merged_vars.values() if var in comp.inputs})
            comp.outputs.update({var.name: var for var in merged_vars.values() if var in comp.outputs})

        return comps

    @field_validator('train_history')
    @classmethod
    def _validate_train_history(cls, history) -> TrainHistory:
        if isinstance(history, TrainHistory):
            return history
        else:
            return TrainHistory.deserialize(history)

    def graph(self) -> nx.DiGraph:
        """Build a directed graph of the system components based on their input-output relationships."""
        graph = nx.DiGraph()
        model_deps = {}
        for comp in self.components:
            graph.add_node(comp.name)
            for output in comp.outputs:
                model_deps[output] = comp.name
        for comp in self.components:
            for in_var in comp.inputs:
                if in_var in model_deps:
                    graph.add_edge(model_deps[in_var], comp.name)

        return graph
    def _save_on_error(func):
        """Gracefully exit and save the `System` object on any errors."""
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except:
                if self.root_dir is not None:
                    self.save_to_file(f'{self.name}_error.yml')
                self.logger.critical(f'An error occurred during execution of "{func.__name__}". Saving '
                                     f'System object to {self.name}_error.yml', exc_info=True)
                self.logger.info(f'Final system surrogate on exit: \n {self}')
                raise  # Always re-raise after saving and logging
        return wrap
    _save_on_error = staticmethod(_save_on_error)  # Expose as a staticmethod; still usable as a decorator below
    def insert_components(self, components: list | Callable | Component):
        """Insert new components into the system."""
        components = components if isinstance(components, list) else [components]
        self.components = self.components + components

    def swap_component(self, old_component: str | Component, new_component: Callable | Component):
        """Replace an old component with a new component."""
        old_name = old_component if isinstance(old_component, str) else old_component.name
        comps = [comp if comp.name != old_name else new_component for comp in self.components]
        self.components = comps

    def remove_component(self, component: str | Component):
        """Remove a component from the system."""
        comp_name = component if isinstance(component, str) else component.name
        self.components = [comp for comp in self.components if comp.name != comp_name]

    def inputs(self) -> VariableList:
        """Collect all inputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object, excluding variables that are also outputs of
        any component.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all inputs from the components.
        """
        all_inputs = ChainMap(*[comp.inputs for comp in self.components])
        return VariableList({k: all_inputs[k] for k in all_inputs.keys() - self.outputs().keys()})

    def outputs(self) -> VariableList:
        """Collect all outputs from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all outputs from the components.
        """
        return VariableList({k: v for k, v in ChainMap(*[comp.outputs for comp in self.components]).items()})

    def coupling_variables(self) -> VariableList:
        """Collect all coupling variables from each component in the `System` and combine them into a
        single [`VariableList`][amisc.variable.VariableList] object.

        :returns: A [`VariableList`][amisc.variable.VariableList] containing all coupling variables from the
                  components.
        """
        all_outputs = self.outputs()
        return VariableList({k: all_outputs[k] for k in (all_outputs.keys() &
                                                         ChainMap(*[comp.inputs for comp in self.components]).keys())})

    def variables(self):
        """Iterator over all variables in the system (inputs and outputs)."""
        yield from ChainMap(self.inputs(), self.outputs()).values()

    @property
    def refine_level(self) -> int:
        """The total number of training iterations."""
        return len(self.train_history)

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: logging.Logger):
        self._logger = logger
        for comp in self.components:
            comp.logger = logger

    @staticmethod
    def timestamp() -> str:
        """Return a UTC timestamp string in the isoformat `YYYY-MM-DDTHH.MM.SS`."""
        return datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
    @property
    def root_dir(self):
        """Return the root directory of the surrogate (if available), otherwise `None`."""
        return Path(self._root_dir) if self._root_dir is not None else None

    @root_dir.setter
    def root_dir(self, root_dir: str | Path):
        """Set the root directory for all build products. If `root_dir` is `None`, then no products will be saved.
        Otherwise, log files, model outputs, surrogate files, etc. will be saved under this directory.

        !!! Note "`amisc` root directory"
            If `root_dir` is not `None`, then a new directory will be created under `root_dir` with the name
            `amisc_{timestamp}`. This directory will be used to save all build products. If `root_dir` matches the
            `amisc_*` format, then it will be used directly.
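
        !!! Example
            A hedged sketch of the naming convention described above (the parent directory must
            already exist):
            ```python
            system.root_dir = 'results'              # creates results/amisc_{timestamp}/
            system.root_dir = 'results/amisc_run1'   # matches 'amisc_*', so used directly
            ```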

        :param root_dir: the root directory for all build products
        """
        if root_dir is not None:
            parts = Path(root_dir).resolve().parts
            if parts[-1].startswith('amisc_'):
                self._root_dir = Path(root_dir).resolve().as_posix()
                if not self.root_dir.is_dir():
                    os.mkdir(self.root_dir)
            else:
                root_dir = Path(root_dir) / ('amisc_' + self.timestamp())
                os.mkdir(root_dir)
                self._root_dir = Path(root_dir).resolve().as_posix()

            log_file = None
            if not (pth := self.root_dir / 'surrogates').is_dir():
                os.mkdir(pth)
            if not (pth := self.root_dir / 'components').is_dir():
                os.mkdir(pth)
            for comp in self.components:
                if comp.model_kwarg_requested('output_path'):
                    if not (comp_pth := pth / comp.name).is_dir():
                        os.mkdir(comp_pth)
            for f in os.listdir(self.root_dir):
                if f.endswith('.log'):
                    log_file = (self.root_dir / f).resolve().as_posix()
                    break
            if log_file is None:
                log_file = (self.root_dir / f'amisc_{self.timestamp()}.log').resolve().as_posix()
            self.set_logger(log_file=log_file)

        else:
            self._root_dir = None
            self.set_logger(log_file=None)
    def set_logger(self, log_file: str | Path | bool = None, stdout: bool = None, logger: logging.Logger = None,
                   level: int = logging.INFO):
        """Set a new `logging.Logger` object.

        :param log_file: log to this file if str or Path (defaults to whatever is currently set or empty);
                         set `False` to remove file logging or set `True` to create a default log file in the root dir
        :param stdout: whether to connect the logger to console (defaults to whatever is currently set or `False`)
        :param logger: the logging object to use (this will override the `log_file` and `stdout` arguments if set);
                       if `None`, then a new logger is created according to `log_file` and `stdout`
        :param level: the logging level to set the logger to (defaults to `logging.INFO`)
        """
        # Decide whether to use stdout
        if stdout is None:
            stdout = False
            if self._logger is not None:
                for handler in self._logger.handlers:
                    if isinstance(handler, logging.StreamHandler):
                        stdout = True
                        break

        # Decide what log_file to use (if any)
        if log_file is True:
            log_file = pth / f'amisc_{self.timestamp()}.log' if (pth := self.root_dir) is not None else (
                f'amisc_{self.timestamp()}.log')
        elif log_file is None:
            if self._logger is not None:
                for handler in self._logger.handlers:
                    if isinstance(handler, logging.FileHandler):
                        log_file = handler.baseFilename
                        break
        elif log_file is False:
            log_file = None

        self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level)

        for comp in self.components:
            comp.set_logger(log_file=log_file, stdout=stdout, logger=logger, level=level)
    def sample_inputs(self, size: tuple | int,
                      component: str = 'System',
                      normalize: bool = True,
                      use_pdf: bool | str | list[str] = False,
                      include: str | list[str] = None,
                      exclude: str | list[str] = None,
                      nominal: dict[str, float] = None) -> Dataset:
        """Return samples of the inputs according to the provided options. Will return samples in the
        normalized/compressed space of the surrogate by default. See [`to_model_dataset`][amisc.utils.to_model_dataset]
        to convert the samples to be usable by the true model directly.

        :param size: tuple or integer specifying the shape or number of samples to obtain
        :param component: which component to sample inputs for (defaults to full system exogenous inputs)
        :param normalize: whether to normalize the samples (defaults to True)
        :param use_pdf: whether to sample from variable pdfs (defaults to False, which will instead sample from the
                        variable domain bounds). If a string or list of strings is provided, then only those variables
                        or variable categories will be sampled using their pdfs.
        :param include: a list of variables or variable categories to include in the sampling. Defaults to using all
                        input variables.
        :param exclude: a list of variables or variable categories to exclude from the sampling. Empty by default.
        :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty. Specify nominal
                        values as unnormalized (will be normalized if `normalize=True`)
        :returns: `dict` of `(*size,)` samples for each selected input variable
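
        !!! Example
            A hedged sketch (assumes the system has a scalar input named 'x'):
            ```python
            samples = system.sample_inputs(100, use_pdf=True)
            samples['x'].shape  # (100,)
            ```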

        """
        size = (size, ) if isinstance(size, int) else size
        nominal = nominal or dict()
        inputs = self.inputs() if component == 'System' else self[component].inputs
        if include is None:
            include = []
        if not isinstance(include, list):
            include = [include]
        if exclude is None:
            exclude = []
        if not isinstance(exclude, list):
            exclude = [exclude]
        if isinstance(use_pdf, str):
            use_pdf = [use_pdf]

        selected_inputs = []
        for var in inputs:
            if len(include) == 0 or var.name in include or var.category in include:
                if var.name not in exclude and var.category not in exclude:
                    selected_inputs.append(var)

        samples = {}
        for var in selected_inputs:
            # Sample from latent variable domains for field quantities
            if var.compression is not None:
                latent = var.sample_domain(size)
                for i in range(latent.shape[-1]):
                    samples[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i]

            # Sample scalars normally
            else:
                if (domain := var.get_domain()) is None:
                    raise RuntimeError(f"Trying to sample variable '{var}' with empty domain. Please set a domain "
                                       f"for this variable. Samples outside the provided domain will be rejected.")
                lb, ub = domain
                pdf = (var.name in use_pdf or var.category in use_pdf) if isinstance(use_pdf, list) else use_pdf
                nom = nominal.get(var.name, None)

                x_sample = var.sample(size, nominal=nom) if pdf else var.sample_domain(size)
                good_idx = (x_sample < ub) & (x_sample > lb)
                num_reject = np.sum(~good_idx)

                while num_reject > 0:
                    new_sample = var.sample((num_reject,), nominal=nom) if pdf else var.sample_domain((num_reject,))
                    x_sample[~good_idx] = new_sample
                    good_idx = (x_sample < ub) & (x_sample > lb)
                    num_reject = np.sum(~good_idx)

                samples[var.name] = var.normalize(x_sample) if normalize else x_sample

        return samples
    def simulate_fit(self):
        """Loop back through the training history and simulate each iteration. Will yield the internal data structures
        of each `Component` surrogate after each iteration of training (without needing to call `fit()` or any
        of the underlying models). This might be useful, for example, for computing the surrogate predictions on
        a new test set or viewing cumulative training costs.

        !!! Example
            Say you have a new test set: `(new_xtest, new_ytest)`, and you want to compute the accuracy of the
            surrogate fit at each iteration of the training history:

            ```python
            for train_iter, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test in system.simulate_fit():
                # Do something with the surrogate data structures
                new_ysurr = system.predict(new_xtest, index_set=active_sets, misc_coeff=misc_coeff_train)
                train_error = relative_error(new_ysurr, new_ytest)
            ```

        :return: a generator of the active index sets, candidate index sets, and MISC coefficients
                 of each component model at each iteration of the training history
        """
        # "Simulated" data structures for each component
        active_sets = {comp.name: IndexSet() for comp in self.components}  # active index sets for each component
        candidate_sets = {comp.name: IndexSet() for comp in self.components}  # candidate sets for each component
        misc_coeff_train = {comp.name: MiscTree() for comp in self.components}  # MISC coeff for active sets
        misc_coeff_test = {comp.name: MiscTree() for comp in self.components}  # MISC coeff for active + candidate sets

        for train_result in self.train_history:
            # The selected refinement component and indices
            comp_star = train_result['component']
            alpha_star = train_result['alpha']
            beta_star = train_result['beta']
            comp = self[comp_star]

            # Get forward neighbors for the selected index
            neighbors = comp._neighbors(alpha_star, beta_star, active_set=active_sets[comp_star], forward=True)

            # "Activate" the index in the simulated data structure
            s = {(alpha_star, beta_star)}
            comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star],
                                   misc_coeff=misc_coeff_train[comp_star])

            if (alpha_star, beta_star) in candidate_sets[comp_star]:
                candidate_sets[comp_star].remove((alpha_star, beta_star))
            else:
                # Only for the initial index, which didn't come from the candidate set
                comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                       misc_coeff=misc_coeff_test[comp_star])
            active_sets[comp_star].update(s)

            comp.update_misc_coeff(neighbors, index_set=active_sets[comp_star].union(candidate_sets[comp_star]),
                                   misc_coeff=misc_coeff_test[comp_star])  # neighbors will only ever pass here once
            candidate_sets[comp_star].update(neighbors)

            # The caller can now act as if the system surrogate were at this training iteration
            # (see the "index_set" and "misc_coeff" overrides for `System.predict()`, for example)
            yield train_result, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test

    def add_output(self):
        """Add an output variable retroactively to a component surrogate. The user should provide a callable that
        takes a save path and extracts the model output data for a given training point/location.
        """
        # TODO
        # Loop back through the surrogate training history
        # Simulate activate_index and extract the model output from file rather than calling the model
        # Update all interpolator states
        raise NotImplementedError
    @_save_on_error
    def fit(self, targets: list = None,
            num_refine: int = 100,
            max_iter: int = 20,
            max_tol: float = 1e-3,
            runtime_hr: float = 1.,
            estimate_bounds: bool = False,
            update_bounds: bool = True,
            test_set: tuple | str | Path = None,
            start_test_check: int = None,
            save_interval: int = 0,
            plot_interval: int = 1,
            cache_interval: int = 0,
            executor: Executor = None,
            weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'):
        """Train the system surrogate adaptively by iterative refinement until an end condition is met.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param max_iter: the maximum number of refinement steps to take (in this call)
        :param max_tol: the max allowable value in relative L2 error to achieve
        :param runtime_hr: the threshold wall clock time (hr) at which to stop further refinement (will go
                           until all models finish the current iteration)
        :param estimate_bounds: whether to estimate bounds for the coupling variables; will only try to estimate from
                                the `test_set` if provided (defaults to `False`). Otherwise, you should manually
                                provide domains for all coupling variables.
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param test_set: `tuple` of `(xtest, ytest)` to show convergence of surrogate to the true model. The test set
                         inputs and outputs are specified as `dicts` of `np.ndarrays` with keys corresponding to the
                         variable names. Can also pass a path to a `.pkl` file that has the test set data as
                         `{'test_set': (xtest, ytest)}`.
        :param start_test_check: the iteration to start checking the test set error (defaults to the number
                                 of components); surrogate evaluation isn't useful during initialization, so you
                                 should allow at least one iteration per component before checking test set error
        :param save_interval: number of refinement steps between each progress save, none if 0; `System.root_dir`
                              must be specified to save to file
        :param plot_interval: how often to plot the error indicator and test set error (defaults to every iteration);
                              will only plot and save to file if a root directory is set
        :param cache_interval: how often to cache component data in order to speed up future training iterations (at
                               the cost of additional memory usage); defaults to 0 (no caching)
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations (optional, but
                         recommended for expensive models)
        :param weight_fcns: a `dict` of weight functions to apply to each input variable for training data selection;
                            defaults to using the pdf of each variable. If None, then no weighting is applied.
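
        !!! Example
            A hedged sketch of a typical setup: build an unnormalized test set with the true models,
            then train against it:
            ```python
            xtest = system.sample_inputs(1000, use_pdf=True, normalize=False)
            ytest = system.predict(xtest, use_model='best', normalized_inputs=False)
            system.fit(max_iter=50, test_set=(xtest, ytest), estimate_bounds=True)
            ```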

        """
        start_test_check = start_test_check or sum(1 for comp in self.components if comp.has_surrogate)
        targets = targets or self.outputs()
        xtest, ytest = self._get_test_set(test_set)
        max_iter = self.refine_level + max_iter

        # Estimate bounds from test set if provided (override current bounds if they are set)
        if estimate_bounds:
            if ytest is not None:
                y_samples = to_surrogate_dataset(ytest, self.outputs(), del_fields=True)[0]  # normalize/compress
                _combine_latent_arrays(y_samples)
                coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_samples}
                y_min, y_max = {}, {}
                for var in coupling_vars.values():
                    y_min[var] = np.nanmin(y_samples[var], axis=0)
                    y_max[var] = np.nanmax(y_samples[var], axis=0)
                    if var.compression is not None:
                        new_domain = list(zip(y_min[var].tolist(), y_max[var].tolist()))
                        var.update_domain(new_domain, override=True)
                    else:
                        new_domain = (float(y_min[var]), float(y_max[var]))
                        var.update_domain(var.denormalize(new_domain), override=True)
                del y_samples
            else:
                self.logger.warning('Could not estimate bounds for coupling variables: no test set provided. '
                                    'Make sure you manually provide (good) coupling variable domains.')

        # Track convergence progress on the error indicator and test set (plot to file)
        if self.root_dir is not None:
            err_record = [res['added_error'] for res in self.train_history]
            err_fig, err_ax = plt.subplots(figsize=(6, 5), layout='tight')

            if xtest is not None and ytest is not None:
                num_plot = min(len(targets), 3)
                test_record = np.full((self.refine_level, num_plot), np.nan)
                t_fig, t_ax = plt.subplots(1, num_plot, figsize=(3.5 * num_plot, 4), layout='tight', squeeze=False,
                                           sharey='row')
                for j, res in enumerate(self.train_history):
                    for i, var in enumerate(targets[:num_plot]):
                        if (perf := res.get('test_error')) is not None:
                            test_record[j, i] = perf[var]

        total_overhead = 0.0
        total_model_wall_time = 0.0
        t_start = time.time()
        while True:
            # Adaptive refinement step
            t_iter_start = time.time()
            train_result = self.refine(targets=targets, num_refine=num_refine, update_bounds=update_bounds,
                                       executor=executor, weight_fcns=weight_fcns)
            if train_result['component'] is None:
                self._print_title_str('Termination criteria reached: No candidates left to refine')
                break

            # Keep track of algorithmic overhead (before and after call_model for this iteration)
            m_start, m_end = self[train_result['component']].get_model_timestamps()  # Start and end of call_model
            if m_start is not None and m_end is not None:
                train_result['overhead_s'] = (m_start - t_iter_start) + (time.time() - m_end)
                train_result['model_s'] = m_end - m_start
            else:
                train_result['overhead_s'] = time.time() - t_iter_start
                train_result['model_s'] = 0.0
            total_overhead += train_result['overhead_s']
            total_model_wall_time += train_result['model_s']

            curr_error = train_result['added_error']

            # Plot progress of error indicator
            if self.root_dir is not None:
                err_record.append(curr_error)

                if plot_interval > 0 and self.refine_level % plot_interval == 0:
                    err_ax.clear(); err_ax.set_yscale('log'); err_ax.grid()
                    err_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    err_ax.plot(err_record, '-k')
                    err_ax.set_xlabel('Iteration'); err_ax.set_ylabel('Relative error indicator')
                    err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.pdf'), format='pdf', bbox_inches='tight')

            # Save performance on a test set
            if xtest is not None and ytest is not None:
                # don't compute if components are uninitialized
                perf = self.test_set_performance(xtest, ytest) if self.refine_level + 1 >= start_test_check else (
                    {str(var): np.nan for var in ytest if COORDS_STR_ID not in var})
                train_result['test_error'] = perf.copy()

                if self.root_dir is not None:
                    test_record = np.vstack((test_record, np.array([perf[var] for var in targets[:num_plot]])))

                    if plot_interval > 0 and self.refine_level % plot_interval == 0:
                        for i in range(num_plot):
                            with warnings.catch_warnings():
                                warnings.simplefilter("ignore", UserWarning)
                                t_ax[0, i].clear(); t_ax[0, i].set_yscale('log'); t_ax[0, i].grid()
                                t_ax[0, i].xaxis.set_major_locator(MaxNLocator(integer=True))
                                t_ax[0, i].plot(test_record[:, i], '-k')
                                t_ax[0, i].set_title(self.outputs()[targets[i]].get_tex(units=True))
                                t_ax[0, i].set_xlabel('Iteration')
                                t_ax[0, i].set_ylabel('Test set relative error' if i == 0 else '')
                        t_fig.savefig(str(Path(self.root_dir) / 'test_set_error.pdf'), format='pdf', bbox_inches='tight')

            self.train_history.append(train_result)

            if self.root_dir is not None and save_interval > 0 and self.refine_level % save_interval == 0:
                iter_name = f'{self.name}_iter{self.refine_level}'
                if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                    os.mkdir(pth)
                self.save_to_file(f'{iter_name}.yml', save_dir=pth)  # Save to an iteration-specific directory

            if cache_interval > 0 and self.refine_level % cache_interval == 0:
                for comp in self.components:
                    comp.cache()

            # Check all end conditions
            if self.refine_level >= max_iter:
                self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}')
                break
            if curr_error < max_tol:
                self._print_title_str(f'Termination criteria reached: relative error {curr_error} < tol {max_tol}')
                break
            if ((time.time() - t_start) / 3600.0) >= runtime_hr:
                t_end = time.time()
                actual = datetime.timedelta(seconds=t_end - t_start)
                target = datetime.timedelta(seconds=runtime_hr * 3600)
                train_surplus = ((t_end - t_start) - runtime_hr * 3600) / 3600
                self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}')
                self.logger.info(f'Surplus wall time: {train_surplus:.3f}/{runtime_hr:.3f} hours '
                                 f'(+{100 * train_surplus / runtime_hr:.2f}%)')
                break

        self.logger.info(f'Model evaluation algorithm efficiency: '
                         f'{100 * total_model_wall_time / (total_model_wall_time + total_overhead):.2f}%')

        if self.root_dir is not None:
            iter_name = f'{self.name}_iter{self.refine_level}'
            if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir():
                os.mkdir(pth)
            self.save_to_file(f'{iter_name}.yml', save_dir=pth)

            if xtest is not None and ytest is not None:
                self._save_test_set((xtest, ytest))

        self.logger.info(f'Final system surrogate: \n {self}')
    def test_set_performance(self, xtest: Dataset, ytest: Dataset, index_set='test') -> Dataset:
        """Compute the relative L2 error on a test set for the given target output variables.

        :param xtest: `dict` of test set input samples (unnormalized)
        :param ytest: `dict` of test set output samples (unnormalized)
        :param index_set: index set to use for prediction (defaults to 'test')
        :returns: `dict` of relative L2 errors for each target output variable
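
        !!! Example
            ```python
            perf = system.test_set_performance(xtest, ytest)
            # e.g. {'y': 0.01, 'z': 0.02} -- one relative L2 error per output variable
            ```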

        """
        targets = [var for var in ytest.keys() if COORDS_STR_ID not in var and var in self.outputs()]
        coords = {var: ytest[var] for var in ytest if COORDS_STR_ID in var}
        xtest = to_surrogate_dataset(xtest, self.inputs(), del_fields=True)[0]
        ysurr = self.predict(xtest, index_set=index_set, targets=targets)
        ysurr = to_model_dataset(ysurr, self.outputs(), del_latent=True, **coords)[0]
        perf = {}
        for var in targets:
            # Handle relative error for object arrays (field qtys)
            ytest_obj = np.issubdtype(ytest[var].dtype, np.object_)
            ysurr_obj = np.issubdtype(ysurr[var].dtype, np.object_)
            if ytest_obj or ysurr_obj:
                _iterable = np.ndindex(ytest[var].shape) if ytest_obj else np.ndindex(ysurr[var].shape)
                num, den = [], []
                for index in _iterable:
                    pred, targ = ysurr[var][index], ytest[var][index]
                    num.append((pred - targ) ** 2)
                    den.append(targ ** 2)
                perf[var] = float(np.sqrt(sum([np.sum(n) for n in num]) / sum([np.sum(d) for d in den])))
            else:
                perf[var] = float(relative_error(ysurr[var], ytest[var]))

        return perf
    def refine(self, targets: list = None, num_refine: int = 100, update_bounds: bool = True, executor: Executor = None,
               weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf') -> TrainIteration:
        """Perform a single adaptive refinement step on the system surrogate.

        :param targets: list of system output variables to focus refinement on, use all outputs if not specified
        :param num_refine: number of input samples to compute error indicators on
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param weight_fcns: weight functions for choosing new training data for each input variable; defaults to
                            the PDFs of each variable. If None, then no weighting is applied.
        :returns: `dict` of the refinement results indicating the chosen component and candidate index
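
        !!! Example
            ```python
            res = system.refine(num_refine=500)
            # res['component'], res['alpha'], res['beta'] -> the chosen candidate multi-index
            ```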

        """
        self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}')
        targets = targets or self.outputs()

        # Check for uninitialized components and refine those first
        for comp in self.components:
            if len(comp.active_set) == 0 and comp.has_surrogate:
                alpha_star = (0,) * len(comp.model_fidelity)
                beta_star = (0,) * len(comp.max_beta)
                self.logger.info(f"Initializing component {comp.name}: adding {(alpha_star, beta_star)} to active set")
                model_dir = (pth / 'components' / comp.name) if (pth := self.root_dir) is not None else None
                comp.activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
                                    weight_fcns=weight_fcns)
                num_evals = comp.get_cost(alpha_star, beta_star)
                cost_star = max(1., comp.model_costs.get(alpha_star, 1.) * num_evals)  # Cpu time (s)
                err_star = np.nan
                return {'component': comp.name, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
                        'added_cost': float(cost_star), 'added_error': float(err_star)}

        # Evaluate the full integrated surrogate on a random test set for global system QoI error estimation
        x_samples = self.sample_inputs(num_refine)
        y_curr = self.predict(x_samples, index_set='train', targets=targets)
        _combine_latent_arrays(y_curr)
        coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_curr}

        y_min, y_max = None, None
        if update_bounds:
            y_min = {var: np.nanmin(y_curr[var], axis=0, keepdims=True) for var in coupling_vars}  # (1, ydim)
            y_max = {var: np.nanmax(y_curr[var], axis=0, keepdims=True) for var in coupling_vars}  # (1, ydim)

        # Find the candidate surrogate with the largest error indicator
        error_max, error_indicator = -np.inf, -np.inf
        comp_star, alpha_star, beta_star, err_star, cost_star = None, None, None, -np.inf, 0
        for comp in self.components:
            if not comp.has_surrogate:  # Skip analytic models that don't need a surrogate
                continue

            self.logger.info(f"Estimating error for component '{comp.name}'...")

            if len(comp.candidate_set) > 0:
                candidates = list(comp.candidate_set)
                if executor is None:
                    ret = [self.predict(x_samples, targets=targets, index_set={comp.name: {(alpha, beta)}},
                                        incremental={comp.name: True})
                           for alpha, beta in candidates]
                else:
                    temp_buffer = self._remove_unpickleable()
                    futures = [executor.submit(self.predict, x_samples, targets=targets,
                                               index_set={comp.name: {(alpha, beta)}}, incremental={comp.name: True})
                               for alpha, beta in candidates]
                    wait(futures, timeout=None, return_when=ALL_COMPLETED)
                    ret = [f.result() for f in futures]
                    self._restore_unpickleable(temp_buffer)

                for i, y_cand in enumerate(ret):
                    alpha, beta = candidates[i]
                    _combine_latent_arrays(y_cand)
                    error = {}
                    for var, arr in y_cand.items():
                        if var in targets:
                            error[var] = relative_error(arr, y_curr[var], skip_nan=True)

                        if update_bounds and var in coupling_vars:
                            y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True)
                            y_max[var] = np.nanmax(np.concatenate((y_max[var], arr), axis=0), axis=0, keepdims=True)

                    delta_error = np.nanmax([np.nanmax(error[var]) for var in error])  # Max error over all target QoIs
                    num_evals = comp.get_cost(alpha, beta)
                    delta_work = max(1., comp.model_costs.get(alpha, 1.) * num_evals)  # Cpu time (s)
                    error_indicator = delta_error / delta_work

                    self.logger.info(f"Candidate multi-index: {(alpha, beta)}. Relative error: {delta_error}. "
                                     f"Error indicator: {error_indicator}.")

                    if error_indicator > error_max:
                        error_max = error_indicator
                        comp_star, alpha_star, beta_star, err_star, cost_star = (
                            comp.name, alpha, beta, delta_error, delta_work)
            else:
                self.logger.info(f"Component '{comp.name}' has no available candidates left!")

        # Update all coupling variable ranges
        if update_bounds:
            for var in coupling_vars.values():
                if np.all(~np.isnan(y_min[var])) and np.all(~np.isnan(y_max[var])):
                    if var.compression is not None:
                        new_domain = list(zip(np.squeeze(y_min[var], axis=0).tolist(),
                                              np.squeeze(y_max[var], axis=0).tolist()))
                        var.update_domain(new_domain)
                    else:
                        new_domain = (y_min[var][0], y_max[var][0])
                        var.update_domain(var.denormalize(new_domain))  # bds will be in norm space from predict() call

        # Add the chosen multi-index to the chosen component
        if comp_star is not None:
            self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{comp_star}'.")
            model_dir = (pth / 'components' / comp_star) if (pth := self.root_dir) is not None else None
            self[comp_star].activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor,
                                           weight_fcns=weight_fcns)
            num_evals = self[comp_star].get_cost(alpha_star, beta_star)
        else:
            self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}")
            num_evals = 0

        # Return the results of the refinement step
        return {'component': comp_star, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals),
                'added_cost': float(cost_star), 'added_error': float(err_star)}
    def predict(self, x: dict | Dataset,
                max_fpi_iter: int = 100,
                anderson_mem: int = 10,
                fpi_tol: float = 1e-10,
                use_model: str | tuple | dict = None,
                model_dir: str | Path = None,
                verbose: bool = False,
                index_set: dict[str, IndexSet | Literal['train', 'test']] = 'test',
                misc_coeff: dict[str, MiscTree] = None,
                normalized_inputs: bool = True,
                incremental: dict[str, bool] = False,
                targets: list[str] = None,
                executor: Executor = None,
                var_shape: dict[str, tuple] = None) -> Dataset:
        """Evaluate the system surrogate at inputs `x`. Return `y = system(x)`.

        !!! Warning "Computing the true model with feedback loops"
            You can use this function to predict outputs for your MD system using the full-order models rather than
            the surrogate, by specifying `use_model`. This is convenient because the `System` manages all the
            coupled information flow between models automatically. However, it is *highly* recommended not to use
            the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using
            anything more computationally demanding than the surrogate.

        :param x: `dict` of input samples for each variable in the system
        :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine
        :param anderson_mem: hyperparameter for tuning the convergence of FPI with Anderson acceleration
        :param fpi_tol: tolerance limit for convergence of fixed-point iteration
        :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate,
                          specify a `dict` of the above to assign different model fidelities for different components
        :param model_dir: directory to save model outputs if `use_model` is specified
        :param verbose: whether to print out iteration progress during execution
        :param index_set: `dict(comp=[indices])` to override the active set for a component, defaults to using the
                          `test` set for every component. Can also specify `train` for any component or a valid
                          `IndexSet` object. If `incremental` is specified, will be overwritten with `train`.
        :param misc_coeff: `dict(comp=MiscTree)` to override the default coefficients for a component, passes through
                           along with `index_set` and `incremental` to `comp.predict()`.
        :param normalized_inputs: true if the passed inputs are compressed/normalized for surrogate evaluation
                                  (default), such as inputs returned by `sample_inputs`. Set to `False` if you are
                                  passing inputs as the true models would expect them instead (i.e. not normalized).
        :param incremental: whether to add `index_set` to the current active set for each component (temporarily);
                            this will set `index_set='train'` for all other components (since incremental will
                            augment the "training" active sets, not the "testing" candidate sets)
        :param targets: list of output variables to return, defaults to returning all system outputs
        :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations
        :param var_shape: (Optional) `dict` of shapes for field quantity inputs in `x` -- you would only specify this
                          if passing field qtys directly to the models (i.e. not using `sample_inputs`)
        :returns: `dict` of output variables - the surrogate approximation of the system outputs (or the true model)
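
        !!! Example
            ```python
            xs = system.sample_inputs(10)
            ys = system.predict(xs)                    # surrogate evaluation ('test' index set)
            ym = system.predict(xs, use_model='best')  # full-order models instead
            ```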

        """
        # Format inputs and allocate space
        var_shape = var_shape or {}
        x, loop_shape = format_inputs(x, var_shape=var_shape)  # {'x': (N, *var_shape)}
        y = {}
        all_inputs = ChainMap(x, y)  # track all inputs (including coupling vars in y)
        N = int(np.prod(loop_shape))
        t1 = 0
        output_dir = None
        norm_status = {var: normalized_inputs for var in x}  # keep track of whether inputs are normalized or not
        graph = self.graph()

        # Keep track of what outputs are computed
        is_computed = {}
        for var in (targets or self.outputs()):
            if (v := self.outputs().get(var, None)) is not None:
                if v.compression is not None:
                    for field in v.compression.fields:
                        is_computed[field] = False
                else:
                    is_computed[var] = False

        def _set_default(struct: dict, default=None):
            """Helper to set a default value for each component key in a `dict`. Ensures all components have a value."""
            if struct is not None:
                if not isinstance(struct, dict):
                    struct = {node: struct for node in graph.nodes}  # use same for each component
            else:
                struct = {node: default for node in graph.nodes}
            return {node: struct.get(node, default) for node in graph.nodes}

        # Ensure use_model, index_set, and incremental are specified for each component model
        use_model = _set_default(use_model, None)
        incremental = _set_default(incremental, False)
        # index_set defaults to 'train' if incremental is requested anywhere
        index_set = _set_default(index_set, 'train' if any([incremental[node] for node in graph.nodes]) else 'test')
        misc_coeff = _set_default(misc_coeff, None)

        samples = _Converged(N)  # track convergence of samples
        def _gather_comp_inputs(comp, coupling=None):
            """Helper to gather inputs for a component, making sure they are normalized correctly. Any coupling
            variables passed in will be used in preference over `all_inputs`.
            """
            # Will access but not modify: all_inputs, use_model, norm_status
            field_coords = {}
            comp_input = {}
            coupling = coupling or {}

            # Take coupling variables as a priority
            comp_input.update({var: np.copy(arr[samples.curr_idx, ...]) for var, arr in
                               coupling.items() if str(var).split(LATENT_STR_ID)[0] in comp.inputs})
            # Gather all other inputs
            for var, arr in all_inputs.items():
                var_id = str(var).split(LATENT_STR_ID)[0]
                if var_id in comp.inputs and var not in coupling:
                    comp_input[var] = np.copy(arr[samples.curr_idx, ...])

            # Gather field coordinates
            for var in comp.inputs:
                coords_str = f'{var}{COORDS_STR_ID}'
                if (coords := all_inputs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords[samples.curr_idx, ...]
                elif (coords := comp.model_kwargs.get(coords_str)) is not None:
                    field_coords[coords_str] = coords

            # Gather extra fields (will never be in coupling since field couplings should always be latent coeff)
            for var in comp.inputs:
                if var not in comp_input and var.compression is not None:
                    for field in var.compression.fields:
                        if field in all_inputs:
                            comp_input[field] = np.copy(all_inputs[field][samples.curr_idx, ...])

            call_model = use_model.get(comp.name, None) is not None

            # Make sure we format all inputs for model evaluation (i.e. denormalize)
            if call_model:
                norm_inputs = {var: arr for var, arr in comp_input.items() if norm_status[var]}
                if len(norm_inputs) > 0:
                    denorm_inputs, fc = to_model_dataset(norm_inputs, comp.inputs, del_latent=True, **field_coords)
                    for var in norm_inputs:
                        del comp_input[var]
                    field_coords.update(fc)
                    comp_input.update(denorm_inputs)

            # Otherwise, make sure we format inputs for surrogate evaluation (i.e. normalize)
            else:
                denorm_inputs = {var: arr for var, arr in comp_input.items() if not norm_status[var]}
                if len(denorm_inputs) > 0:
                    norm_inputs, _ = to_surrogate_dataset(denorm_inputs, comp.inputs, del_fields=True, **field_coords)

                    for var in denorm_inputs:
                        del comp_input[var]
                    comp_input.update(norm_inputs)

            return comp_input, field_coords, call_model
        # Convert system into DAG by grouping strongly-connected-components
        dag = nx.condensation(graph)

        # Compute component models in topological order
        for supernode in nx.topological_sort(dag):
            if np.all(list(is_computed.values())):
                break  # Exit early if all selected return qois are computed

            scc = [n for n in dag.nodes[supernode]['members']]
            samples.reset_convergence()

            # Compute single component feedforward output (no FPI needed)
            if len(scc) == 1:
                if verbose:
                    self.logger.info(f"Running component '{scc[0]}'...")
                    t1 = time.time()

                # Gather inputs
                comp = self[scc[0]]
                comp_input, field_coords, call_model = _gather_comp_inputs(comp)

                # Compute outputs
                if model_dir is not None:
                    output_dir = Path(model_dir) / scc[0]
                    if not output_dir.exists():
                        os.mkdir(output_dir)
                comp_output = comp.predict(comp_input, use_model=use_model.get(scc[0]), model_dir=output_dir,
                                           index_set=index_set.get(scc[0]), incremental=incremental.get(scc[0]),
                                           misc_coeff=misc_coeff.get(scc[0]), executor=executor, **field_coords)

                for var, arr in comp_output.items():
                    if var == 'errors':
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N,), None, dtype=object))
                        global_indices = np.arange(N)[samples.curr_idx]

                        for local_idx, err_info in arr.items():
                            global_idx = int(global_indices[local_idx])
                            err_info['index'] = global_idx
                            y[var][global_idx] = err_info
                        continue

                    is_numeric = np.issubdtype(arr.dtype, np.number)
                    if is_numeric:  # for scalars or vectorized field quantities
                        output_shape = arr.shape[1:]
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N, *output_shape), np.nan))
                        y[var][samples.curr_idx, ...] = arr

                    else:  # for fields returned as object arrays
                        if y.get(var) is None:
                            y.setdefault(var, np.full((N,), None, dtype=object))
                        y[var][samples.curr_idx] = arr

                # Update valid indices and status for component outputs
                for var in comp_output:
                    if str(var).split(LATENT_STR_ID)[0] in comp.outputs:
                        is_numeric = np.issubdtype(y[var].dtype, np.number)
                        new_valid = ~np.any(np.isnan(y[var]), axis=tuple(range(1, y[var].ndim))) if is_numeric else (
                            [False if arr is None else ~np.any(np.isnan(arr)) for arr in y[var]]
                        )
                        samples.valid_idx = np.logical_and(samples.valid_idx, new_valid)

                        is_computed[str(var).split(LATENT_STR_ID)[0]] = True
                        norm_status[var] = not call_model

                if verbose:
                    self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s")
1161 

1162 # Handle FPI for SCCs with more than one component 

1163 else: 

1164 # Set the initial guess for all coupling vars (middle of domain) 

1165 scc_inputs = ChainMap(*[self[comp].inputs for comp in scc]) 

1166 scc_outputs = ChainMap(*[self[comp].outputs for comp in scc]) 

1167 coupling_vars = [scc_inputs.get(var) for var in (scc_inputs.keys() - x.keys()) if var in scc_outputs] 

1168 coupling_prev = {} 

1169 for var in coupling_vars: 

1170 if (domain := var.get_domain()) is None: 

1171 raise RuntimeError(f"Coupling variable '{var}' has an empty domain. All coupling variables " 

1172 f"require a domain for the fixed-point iteration (FPI) solver.") 

1173 

1174 if isinstance(domain, list): # Latent coefficients are the coupling variables 

1175 for i, d in enumerate(domain): 

1176 lb, ub = d 

1177 coupling_prev[f'{var.name}{LATENT_STR_ID}{i}'] = np.broadcast_to((lb + ub) / 2, (N,)).copy() 

1178 norm_status[f'{var.name}{LATENT_STR_ID}{i}'] = True 

1179 else: 

1180 lb, ub = var.normalize(domain) 

1181 shape = (N,) + (1,) * len(var_shape.get(var, ())) 

1182 coupling_prev[var] = np.broadcast_to((lb + ub) / 2, shape).copy() 

1183 norm_status[var] = True 

1184 

1185 residual_hist = deque(maxlen=anderson_mem) 

1186 coupling_hist = deque(maxlen=anderson_mem) 

1187 

1188 def _end_conditions_met(): 

1189 """Helper to compute residual, update history, and check end conditions.""" 

1190 residual = {} 

1191 converged_idx = np.full(N, True) 

1192 for var in coupling_prev: 

1193 residual[var] = y[var] - coupling_prev[var] 

1194 var_conv = np.all(np.abs(residual[var]) <= fpi_tol, axis=tuple(range(1, residual[var].ndim))) 

1195 converged_idx = np.logical_and(converged_idx, var_conv) 

1196 samples.valid_idx = np.logical_and(samples.valid_idx, ~np.any(np.isnan(coupling_prev[var]), axis=tuple(range(1, coupling_prev[var].ndim)))) # reduce over non-sample axes  # noqa: E501 

1197 samples.converged_idx = np.logical_or(samples.converged_idx, converged_idx) 

1198 

1199 for var in coupling_prev: 

1200 coupling_prev[var][samples.curr_idx, ...] = y[var][samples.curr_idx, ...] 

1201 residual_hist.append(copy.deepcopy(residual)) 

1202 coupling_hist.append(copy.deepcopy(coupling_prev)) 

1203 

1204 if int(np.sum(samples.curr_idx)) == 0: 

1205 if verbose: 

1206 self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with tol ' 

1207 f'{fpi_tol}. Final time: {time.time() - t1} s') 

1208 return True 

1209 

1210 max_error = np.max([np.max(np.abs(res[samples.curr_idx, ...])) for res in residual.values()]) 

1211 if verbose: 

1212 self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s') 

1213 

1214 if k >= max_fpi_iter: 

1215 self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: ' 

1216 f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.') 

1217 for var in coupling_prev: 

1218 y[var][~samples.converged_idx, ...] = np.nan 

1219 samples.valid_idx = np.logical_and(samples.valid_idx, samples.converged_idx) 

1220 return True 

1221 else: 

1222 return False 

1223 

1224 # Main FPI loop 

1225 if verbose: 

1226 self.logger.info(f"Initializing FPI for SCC {scc} ...") 

1227 t1 = time.time() 

1228 k = 0 

1229 while True: 

1230 for node in scc: 

1231 # Gather inputs from exogenous and coupling sources 

1232 comp = self[node] 

1233 comp_input, kwds, call_model = _gather_comp_inputs(comp, coupling=coupling_prev) 

1234 

1235 # Compute outputs (avoid running the FPI loop with expensive real models; use their surrogates instead) 

1236 comp_output = comp.predict(comp_input, use_model=use_model.get(node), model_dir=None, 

1237 index_set=index_set.get(node), incremental=incremental.get(node), 

1238 misc_coeff=misc_coeff.get(node), executor=executor, **kwds) 

1239 

1240 for var, arr in comp_output.items(): 

1241 if var == 'errors': 

1242 if y.get(var) is None: 

1243 y[var] = np.full((N,), None, dtype=object) 

1244 global_indices = np.arange(N)[samples.curr_idx] 

1245 

1246 for local_idx, err_info in arr.items(): 

1247 global_idx = int(global_indices[local_idx]) 

1248 err_info['index'] = global_idx 

1249 y[var][global_idx] = err_info 

1250 continue 

1251 

1252 if np.issubdtype(arr.dtype, np.number): # scalars and vectorized field quantities 

1253 output_shape = arr.shape[1:] 

1254 if y.get(var) is not None: 

1255 if output_shape != y[var].shape[1:]: 

1256 y[var] = _merge_shapes((N, *output_shape), y[var]) 

1257 else: 

1258 y[var] = np.full((N, *output_shape), np.nan) 

1259 y[var][samples.curr_idx, ...] = arr 

1260 else: # fields returned as object arrays 

1261 if y.get(var) is None: 

1262 y[var] = np.full((N,), None, dtype=object) 

1263 y[var][samples.curr_idx] = arr 

1264 

1265 if str(var).split(LATENT_STR_ID)[0] in comp.outputs: 

1266 norm_status[var] = not call_model 

1267 is_computed[str(var).split(LATENT_STR_ID)[0]] = True 

1268 

1269 # Compute residual and check end conditions 

1270 if _end_conditions_met(): 

1271 break 

1272 

1273 # Skip Anderson acceleration on the first iteration 

1274 if k == 0: 

1275 k += 1 

1276 continue 

1277 

1278 # Iterate with Anderson acceleration (only update samples that have not yet converged) 

1279 N_curr = int(np.sum(samples.curr_idx)) 

1280 mk = len(residual_hist) # Current history length (at most anderson_mem) 

1281 var_shapes = [] 

1282 xdims = [] 

1283 for var in coupling_prev: 

1284 shape = coupling_prev[var].shape[1:] 

1285 var_shapes.append(shape) 

1286 xdims.append(int(np.prod(shape))) 

1287 N_couple = int(np.sum(xdims)) 

1288 res_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of residual history 

1289 coupling_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of coupling history 

1290 for i, (coupling_iter, residual_iter) in enumerate(zip(coupling_hist, residual_hist)): 

1291 start_idx = 0 

1292 for j, var in enumerate(coupling_prev): 

1293 end_idx = start_idx + xdims[j] 

1294 coupling_snap[:, start_idx:end_idx, i] = coupling_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501 

1295 res_snap[:, start_idx:end_idx, i] = residual_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501 

1296 start_idx = end_idx 

1297 C = np.ones((N_curr, 1, mk)) # Constrain the Anderson weights alpha to sum to one: 

1298 b = np.zeros((N_curr, N_couple, 1)) # minimize ||res_snap @ alpha - b|| 

1299 d = np.ones((N_curr, 1, 1)) # subject to C @ alpha = d 

1300 alpha = np.expand_dims(constrained_lls(res_snap, b, C, d), axis=-3) # (..., 1, mk, 1) 

1301 coupling_new = np.squeeze(coupling_snap[:, :, np.newaxis, :] @ alpha, axis=(-1, -2)) 

1302 start_idx = 0 

1303 for j, var in enumerate(coupling_prev): 

1304 end_idx = start_idx + xdims[j] 

1305 coupling_prev[var][samples.curr_idx, ...] = coupling_new[:, start_idx:end_idx].reshape((N_curr, *var_shapes[j])) # noqa: E501 

1306 start_idx = end_idx 

1307 k += 1 

1308 

1309 # Return all component outputs; samples that didn't converge during FPI are left as np.nan 

1310 return format_outputs(y, loop_shape) 
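
# The Anderson-accelerated update above solves, per sample, a small constrained
# least-squares problem over the last `mk` iterates. A minimal single-sample sketch
# (a hypothetical standalone helper, not part of this class; `constrained_lls` does
# the same math batched over all samples):
#
#     import numpy as np
#
#     def anderson_step(residual_hist, coupling_hist):
#         """Find weights alpha minimizing ||R @ alpha|| s.t. sum(alpha) == 1, where
#         the columns of R (X) are the last m residual (coupling) iterates."""
#         R = np.stack(residual_hist, axis=-1)  # (n_couple, m)
#         X = np.stack(coupling_hist, axis=-1)  # (n_couple, m)
#         m = R.shape[-1]
#         kkt = np.block([[R.T @ R, np.ones((m, 1))],           # stationarity conditions
#                         [np.ones((1, m)), np.zeros((1, 1))]])  # sum-to-one constraint
#         rhs = np.concatenate([np.zeros(m), [1.0]])
#         alpha = np.linalg.solve(kkt, rhs)[:m]
#         return X @ alpha  # the next coupling iterate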

1311 

1312 def __call__(self, *args, **kwargs): 

1313 """Convenience wrapper to allow calling as `ret = System(x)`.""" 

1314 return self.predict(*args, **kwargs) 
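
# e.g. `outputs = system(inputs)` is shorthand for `outputs = system.predict(inputs)`,
# where `system` and `inputs` stand in for a constructed `System` and its input dataset.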

1315 

1316 def __eq__(self, other): 

1317 if not isinstance(other, System): 

1318 return False 

1319 return (self.components == other.components and 

1320 self.name == other.name and 

1321 self.train_history == other.train_history) 

1322 

1323 def __getitem__(self, component: str) -> Component: 

1324 """Convenience method to get a `Component` object from the `System`. 

1325 

1326 :param component: the name of the component to get 

1327 :returns: the `Component` object 

1328 """ 

1329 return self.get_component(component) 

1330 

1331 def get_component(self, comp_name: str) -> Component: 

1332 """Return the `Component` object for this component. 

1333 

1334 :param comp_name: name of the component to return (the special name 'system' returns the `System` itself) 

1335 :raises KeyError: if the component does not exist 

1336 :returns: the `Component` object 

1337 """ 

1338 if comp_name.lower() == 'system': 

1339 return self 

1340 else: 

1341 for comp in self.components: 

1342 if comp.name == comp_name: 

1343 return comp 

1344 raise KeyError(f"Component '{comp_name}' not found in system.") 
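
# A short usage sketch (the component name 'comp1' is hypothetical):
#
#     comp = system['comp1']                 # same as system.get_component('comp1')
#     comp = system.get_component('system')  # the name 'system' returns the System itself
#     comp = system.get_component('missing') # raises KeyError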

1345 

1346 def _print_title_str(self, title_str: str): 

1347 """Log an important message.""" 

1348 self.logger.info('-' * (len(title_str) // 2) + title_str + '-' * (len(title_str) // 2)) 

1349 

1350 def _remove_unpickleable(self) -> dict: 

1351 """Remove and return unpickleable attributes before pickling (just the logger).""" 

1352 stdout = False 

1353 log_file = None 

1354 if self._logger is not None: 

1355 for handler in self._logger.handlers: 

1356 if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler): # FileHandler subclasses StreamHandler 

1357 stdout = True 

1358 break 

1359 for handler in self._logger.handlers: 

1360 if isinstance(handler, logging.FileHandler): 

1361 log_file = handler.baseFilename 

1362 break 

1363 

1364 buffer = {'log_stdout': stdout, 'log_file': log_file} 

1365 self.logger = None 

1366 return buffer 

1367 

1368 def _restore_unpickleable(self, buffer: dict): 

1369 """Restore the unpickleable attributes from `buffer` after unpickling.""" 

1370 self.set_logger(log_file=buffer.get('log_file', None), stdout=buffer.get('log_stdout', None)) 

1371 

1372 def _get_test_set(self, test_set: str | Path | tuple = None) -> tuple: 

1373 """Try to load a test set from the root directory if it exists.""" 

1374 if isinstance(test_set, tuple): 

1375 return test_set # (xtest, ytest) 

1376 else: 

1377 ret = (None, None) 

1378 if test_set is not None: 

1379 test_set = Path(test_set) 

1380 elif self.root_dir is not None: 

1381 test_set = self.root_dir / 'test_set.pkl' 

1382 

1383 if test_set is not None: 

1384 if test_set.exists(): 

1385 with open(test_set, 'rb') as fd: 

1386 data = pickle.load(fd) 

1387 ret = data['test_set'] 

1388 

1389 return ret 

1390 

1391 def _save_test_set(self, test_set: tuple = None): 

1392 """Save the test set to the root directory if possible.""" 

1393 if self.root_dir is not None and test_set is not None: 

1394 test_file = self.root_dir / 'test_set.pkl' 

1395 if not test_file.exists(): 

1396 with open(test_file, 'wb') as fd: 

1397 pickle.dump({'test_set': test_set}, fd) 
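
# The pickle simply wraps the `(xtest, ytest)` tuple under a 'test_set' key, so a saved
# test set can be inspected directly (the path below is illustrative):
#
#     import pickle
#     with open('amisc_root/test_set.pkl', 'rb') as fd:
#         xtest, ytest = pickle.load(fd)['test_set']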

1398 

1399 def save_to_file(self, filename: str, save_dir: str | Path = None, dumper=None): 

1400 """Save surrogate to file. Defaults to `root/surrogates/filename.yml` with the default yaml encoder. 

1401 

1402 :param filename: the name of the save file 

1403 :param save_dir: the directory to save the file to (defaults to `root/surrogates` or `cwd()`) 

1404 :param dumper: the encoder to use (defaults to the `amisc` yaml encoder) 

1405 """ 

1406 from amisc import YamlLoader 

1407 encoder = dumper or YamlLoader 

1408 save_dir = save_dir or self.root_dir or Path.cwd() 

1409 if Path(save_dir) == self.root_dir: 

1410 save_dir = self.root_dir / 'surrogates' 

1411 encoder.dump(self, Path(save_dir) / filename) 

1412 

1413 @staticmethod 

1414 def load_from_file(filename: str | Path, root_dir: str | Path = None, loader=None): 

1415 """Load surrogate from file. Defaults to yaml loading. Tries to infer `amisc` directory structure. 

1416 

1417 :param filename: the name of the load file 

1418 :param root_dir: set this as the surrogate's root directory (by default, inferred from the `amisc_` directory format) 

1419 :param loader: the decoder to use (defaults to the `amisc` yaml loader) 

1420 """ 

1421 from amisc import YamlLoader 

1422 decoder = loader or YamlLoader 

1423 system = decoder.load(filename) 

1424 root_dir = root_dir or system.root_dir 

1425 

1426 # Try to infer amisc_root/surrogates/iter/filename structure 

1427 if root_dir is None: 

1428 parts = Path(filename).resolve().parts 

1429 if len(parts) > 1 and parts[-2].startswith('amisc_'): 

1430 root_dir = Path(filename).resolve().parent 

1431 elif len(parts) > 2 and parts[-3].startswith('amisc_'): 

1432 root_dir = Path(filename).resolve().parent.parent 

1433 elif len(parts) > 3 and parts[-4].startswith('amisc_'): 

1434 root_dir = Path(filename).resolve().parent.parent.parent 

1435 

1436 system.root_dir = root_dir 

1437 return system 
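
# A save/load round-trip sketch (file and directory names are illustrative):
#
#     system.save_to_file('my_surrogate.yml')  # -> root/surrogates/my_surrogate.yml if root_dir is set
#     system = System.load_from_file('amisc_2025-03-10/surrogates/my_surrogate.yml')
#     # root_dir is inferred here from the 'amisc_' prefix two directories up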

1438 

1439 def clear(self): 

1440 """Clear all surrogate model data and reset the system.""" 

1441 for comp in self.components: 

1442 comp.clear() 

1443 self.train_history.clear() 

1444 

1445 def plot_slice(self, inputs: list[str] = None, 

1446 outputs: list[str] = None, 

1447 num_steps: int = 20, 

1448 show_surr: bool = True, 

1449 show_model: str | tuple | list = None, 

1450 save_dir: str | Path = None, 

1451 executor: Executor = None, 

1452 nominal: dict[str, float] = None, 

1453 random_walk: bool = False, 

1454 from_file: str | Path = None, 

1455 subplot_size_in: float = 3.): 

1456 """Helper function to plot 1d slices of the surrogate and/or model outputs over the inputs. A single 

1457 "slice" works by smoothly stepping from the lower bound of an input to its upper bound, while holding all other 

1458 inputs constant at their nominal values (or smoothly varying them if `random_walk=True`). 

1459 This function is useful for visualizing the behavior of the system surrogate and/or model(s) over a 

1460 single input variable at a time. 

1461 

1462 :param inputs: list of input variables to take 1d slices of (defaults to first 3 in `System.inputs`) 

1463 :param outputs: list of model output variables to plot 1d slices of (defaults to first 3 in `System.outputs`) 

1464 :param num_steps: the number of points to take in the 1d slice for each input variable; this amounts to a total 

1465 of `num_steps*len(inputs)` model/surrogate evaluations 

1466 :param show_surr: whether to show the surrogate prediction 

1467 :param show_model: also compute and plot model predictions, `list` of ['best', 'worst', tuple(alpha), etc.] 

1468 :param save_dir: base directory to save model outputs and plots (if specified) 

1469 :param executor: a `concurrent.futures.Executor` object to parallelize model or surrogate evaluations 

1470 :param nominal: `dict` of `var->nominal` to use as constant values for all non-sliced variables (use 

1471 unnormalized values only; use `var_LATENT0` to specify nominal latent values) 

1472 :param random_walk: whether to slice in a random d-dimensional direction instead of holding all non-slice 

1473 variables constant at `nominal` 

1474 :param from_file: path to a `.pkl` file to load a saved slice from disk 

1475 :param subplot_size_in: side length size of each square subplot in inches 

1476 :returns: `fig, ax` with `len(inputs)` by `len(outputs)` subplots 
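
!!! Example
    A minimal sketch, where `system` is a trained `System` and the variable names are placeholders:

    ```python
    fig, ax = system.plot_slice(inputs=['x1', 'x2'], outputs=['y1'], num_steps=15,
                                show_model=['best'], random_walk=True)
    ```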

1477 """ 

1478 # Manage loading important quantities from file (if provided) 

1479 input_slices, output_slices_model, output_slices_surr = None, None, None 

1480 if from_file is not None: 

1481 with open(Path(from_file), 'rb') as fd: 

1482 slice_data = pickle.load(fd) 

1483 inputs = slice_data['inputs'] # Must use same input slices as save file 

1484 show_model = slice_data['show_model'] # Must use same model data as save file 

1485 outputs = slice_data.get('outputs') if outputs is None else outputs 

1486 input_slices = slice_data['input_slices'] 

1487 save_dir = None # Don't run or save any models if loading from file 

1488 

1489 # Set default values (take up to the first 3 inputs by default) 

1490 all_inputs = self.inputs() 

1491 all_outputs = self.outputs() 

1492 rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4)) 

1493 if save_dir is not None: 

1494 os.mkdir(Path(save_dir) / f'slice_{rand_id}') 

1495 if nominal is None: 

1496 nominal = dict() 

1497 inputs = all_inputs[:3] if inputs is None else inputs 

1498 outputs = all_outputs[:3] if outputs is None else outputs 

1499 

1500 if show_model is not None and not isinstance(show_model, list): 

1501 show_model = [show_model] 

1502 

1503 # Handle field quantities (directly use latent variables or only the first one) 

1504 for i, var in enumerate(list(inputs)): 

1505 if LATENT_STR_ID not in str(var) and all_inputs[var].compression is not None: 

1506 inputs[i] = f'{var}{LATENT_STR_ID}0' 

1507 for i, var in enumerate(list(outputs)): 

1508 if LATENT_STR_ID not in str(var) and all_outputs[var].compression is not None: 

1509 outputs[i] = f'{var}{LATENT_STR_ID}0' 

1510 

1511 bds = all_inputs.get_domains() 

1512 xlabels = [all_inputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else 

1513 all_inputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) + 

1514 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in inputs] 

1515 

1516 ylabels = [all_outputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else 

1517 all_outputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) + 

1518 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in outputs] 

1519 

1520 # Construct slices of model inputs (if not provided) 

1521 if input_slices is None: 

1522 input_slices = {} # Each input variable with shape (num_steps, num_slice) 

1523 for i in range(len(inputs)): 

1524 if random_walk: 

1525 # Make a random straight-line walk across the d-dimensional input cube 

1526 r0 = self.sample_inputs((1,), use_pdf=False) 

1527 rf = self.sample_inputs((1,), use_pdf=False) 

1528 

1529 for var, bd in bds.items(): 

1530 if var == inputs[i]: 

1531 r0[var] = np.atleast_1d(bd[0]) # Start slice at this lower bound 

1532 rf[var] = np.atleast_1d(bd[1]) # Slice up to this upper bound 

1533 

1534 step_size = (rf[var] - r0[var]) / (num_steps - 1) 

1535 arr = r0[var] + step_size * np.arange(num_steps) 

1536 

1537 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else ( 

1538 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1)) 

1539 else: 

1540 # Otherwise, only slice one variable 

1541 for var, bd in bds.items(): 

1542 nom = nominal.get(var, np.mean(bd)) if LATENT_STR_ID in str(var) else ( 

1543 all_inputs[var].normalize(nominal.get(var, all_inputs[var].get_nominal()))) 

1544 arr = np.linspace(bd[0], bd[1], num_steps) if var == inputs[i] else np.full(num_steps, nom) 

1545 

1546 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else ( 

1547 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1)) 

1548 

1549 # Walk through each model that is requested by show_model 

1550 if show_model is not None: 

1551 if from_file is not None: 

1552 output_slices_model = slice_data['output_slices_model'] 

1553 else: 

1554 output_slices_model = list() 

1555 for model in show_model: 

1556 output_dir = None 

1557 if save_dir is not None: 

1558 output_dir = (Path(save_dir) / f'slice_{rand_id}' / 

1559 str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", '')) 

1560 os.mkdir(output_dir) 

1561 output_slices_model.append(self.predict(input_slices, use_model=model, model_dir=output_dir, 

1562 executor=executor)) 

1563 if show_surr: 

1564 output_slices_surr = self.predict(input_slices, executor=executor) \ 

1565 if from_file is None else slice_data['output_slices_surr'] 

1566 

1567 # Make len(outputs) by len(inputs) grid of subplots 

1568 fig, axs = plt.subplots(len(outputs), len(inputs), sharex='col', sharey='row', squeeze=False) 

1569 for i, output_var in enumerate(outputs): 

1570 for j, input_var in enumerate(inputs): 

1571 ax = axs[i, j] 

1572 x = input_slices[input_var][:, j] 

1573 

1574 if show_model is not None: 

1575 c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else ( 

1576 plt.get_cmap('jet')(np.linspace(0, 1, len(show_model)))) 

1577 for k in range(len(show_model)): 

1578 model_str = (str(show_model[k]).replace('{', '').replace('}', '') 

1579 .replace(':', '=').replace("'", '')) 

1580 model_ret = to_surrogate_dataset(output_slices_model[k], all_outputs)[0] 

1581 y_model = model_ret[output_var][:, j] 

1582 label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model', 

1583 'worst': 'Low-fidelity'}.get(model_str, model_str) 

1584 ax.plot(x, y_model, ls='-', c=c[k, :], label=label) 

1585 

1586 if show_surr: 

1587 y_surr = output_slices_surr[output_var][:, j] 

1588 ax.plot(x, y_surr, '--r', label='Surrogate') 

1589 

1590 ax.set_xlabel(xlabels[j] if i == len(outputs) - 1 else '') 

1591 ax.set_ylabel(ylabels[i] if j == 0 else '') 

1592 if i == 0 and j == len(inputs) - 1: 

1593 ax.legend() 

1594 fig.set_size_inches(subplot_size_in * len(inputs), subplot_size_in * len(outputs)) 

1595 fig.tight_layout() 

1596 

1597 # Save results (unless we were already loading from a save file) 

1598 if from_file is None and save_dir is not None: 

1599 fname = f'in={",".join([str(v) for v in inputs])}_out={",".join([str(v) for v in outputs])}' 

1600 fname = f'slice_rand{rand_id}_' + fname if random_walk else f'slice_nom{rand_id}_' + fname 

1601 fdir = Path(save_dir) / f'slice_{rand_id}' 

1602 fig.savefig(fdir / f'{fname}.pdf', bbox_inches='tight', format='pdf') 

1603 save_dict = {'inputs': inputs, 'outputs': outputs, 'show_model': show_model, 'show_surr': show_surr, 

1604 'nominal': nominal, 'random_walk': random_walk, 'input_slices': input_slices, 

1605 'output_slices_model': output_slices_model, 'output_slices_surr': output_slices_surr} 

1606 with open(fdir / f'{fname}.pkl', 'wb') as fd: 

1607 pickle.dump(save_dict, fd) 

1608 

1609 return fig, axs 

1610 

1611 def get_allocation(self): 

1612 """Get a breakdown of cost allocation during training. 

1613 

1614 :returns: `cost_alloc, model_cost, overhead_cost, model_evals` - the cost allocation per model/fidelity, 

1615 the model evaluation cost per iteration (in s of CPU time), the algorithmic overhead cost per 

1616 iteration, and the total number of model evaluations at each training iteration 
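
!!! Example
    A usage sketch, assuming a trained `system` and `numpy` imported as `np`:

    ```python
    cost_alloc, model_cost, overhead_cost, model_evals = system.get_allocation()
    total_cpu_s = np.sum(model_cost) + np.sum(overhead_cost)  # total training cost
    ```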

1617 """ 

1618 cost_alloc = dict() # Cost allocation (cpu time in s) per node and model fidelity 

1619 model_cost = [] # Cost of model evaluations (CPU time in s) per iteration 

1620 overhead_cost = [] # Algorithm overhead costs (CPU time in s) per iteration 

1621 model_evals = [] # Number of model evaluations at each training iteration 

1622 

1623 prev_cands = {comp.name: IndexSet() for comp in self.components} # empty candidate sets 

1624 

1625 # Add cumulative training costs 

1626 for train_res, active_sets, cand_sets, misc_coeff_train, misc_coeff_test in self.simulate_fit(): 

1627 comp = train_res['component'] 

1628 alpha = train_res['alpha'] 

1629 beta = train_res['beta'] 

1630 overhead = train_res['overhead_s'] 

1631 

1632 cost_alloc.setdefault(comp, dict()) 

1633 

1634 new_cands = cand_sets[comp].union({(alpha, beta)}) - prev_cands[comp] # newly computed candidates 

1635 

1636 iter_cost = 0. 

1637 iter_eval = 0 

1638 for alpha_new, beta_new in new_cands: 

1639 cost_alloc[comp].setdefault(alpha_new, 0.) 

1640 

1641 added_eval = self[comp].get_cost(alpha_new, beta_new) 

1642 single_cost = self[comp].model_costs.get(alpha_new, 1.) 

1643 

1644 iter_cost += added_eval * single_cost 

1645 iter_eval += added_eval 

1646 

1647 cost_alloc[comp][alpha_new] += added_eval * single_cost 

1648 

1649 overhead_cost.append(overhead) 

1650 model_cost.append(iter_cost) 

1651 model_evals.append(iter_eval) 

1652 prev_cands[comp] = cand_sets[comp].union({(alpha, beta)}) 

1653 

1654 return cost_alloc, np.atleast_1d(model_cost), np.atleast_1d(overhead_cost), np.atleast_1d(model_evals) 

1655 

1656 def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02): 

1657 """Plot bar charts showing cost allocation during training. 

1658 

1659 !!! Warning "Beta feature" 

1660 This has reasonable default settings, but it may not look good for your use case. It is mostly provided 

1661 as a template for making cost allocation bar charts. Feel free to copy and edit it in your own code. 

1662 

1663 :param cmap: the colormap string identifier for `plt` 

1664 :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text 

1665 :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow; 

1666 below this amount, the bar is too skinny and won't print any text 

1667 :returns: `fig, ax`, Figure and Axes objects 

1668 """ 

1669 # Get total cost 

1670 cost_alloc, model_cost, _, _ = self.get_allocation() 

1671 total_cost = np.sum(model_cost) 

1672 

1673 # Remove nodes with cost=0 from alloc dicts (i.e. analytical models) 

1674 remove_nodes = [] 

1675 for node, alpha_dict in cost_alloc.items(): 

1676 if len(alpha_dict) == 0: 

1677 remove_nodes.append(node) 

1678 for node in remove_nodes: 

1679 del cost_alloc[node] 

1680 

1681 # Bar chart showing cost allocation breakdown for MF system at final iteration 

1682 fig, ax = plt.subplots(figsize=(6, 5), layout='tight') 

1683 width = 0.7 

1684 x = np.arange(len(cost_alloc)) 

1685 xlabels = list(cost_alloc.keys()) # One bar for each component 

1686 cmap = plt.get_cmap(cmap) 

1687 

1688 for j, (node, alpha_dict) in enumerate(cost_alloc.items()): 

1689 bottom = 0 

1690 c_intervals = np.linspace(0, 1, len(alpha_dict)) 

1691 bars = [(alpha, cost, cost / total_cost) for alpha, cost in alpha_dict.items()] 

1692 bars = sorted(bars, key=lambda ele: ele[2], reverse=True) 

1693 for i, (alpha, cost, frac) in enumerate(bars): 

1694 p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1, 

1695 edgecolor=[0, 0, 0], bottom=bottom) 

1696 bottom += frac 

1697 num_evals = round(cost / self[node].model_costs.get(alpha, 1.)) 

1698 if frac > text_bar_width: 

1699 ax.bar_label(p, labels=[f'{alpha}, {num_evals}'], label_type='center') 

1700 elif frac > arrow_bar_width: 

1701 xy = (x[j] + width / 2, bottom - frac / 2) # Label smaller bars with text off to the side 

1702 ax.annotate(f'{alpha}, {num_evals}', xy, xytext=(xy[0] + 0.2, xy[1]), 

1703 arrowprops={'arrowstyle': '->', 'linewidth': 1}) 

1704 else: 

1705 pass # Don't label really small bars 

1706 ax.set_xlabel('') 

1707 ax.set_ylabel('Fraction of total cost') 

1708 ax.set_xticks(x, xlabels) 

1709 ax.set_xlim(left=-1, right=x[-1] + 1) 

1710 

1711 if self.root_dir is not None: 

1712 fig.savefig(Path(self.root_dir) / 'mf_allocation.pdf', bbox_inches='tight', format='pdf') 

1713 

1714 return fig, ax 
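
# Typical usage after training (a sketch; per the code above, the figure is also
# written to root_dir/mf_allocation.pdf when a root directory is set):
#
#     fig, ax = system.plot_allocation()
#     fig.savefig('allocation.pdf')  # save an extra copy anywhere else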

1715 

1716 def serialize(self, keep_components=False, serialize_args=None, serialize_kwargs=None) -> dict: 

1717 """Convert to a `dict` with only standard Python types for fields. 

1718 

1719 :param keep_components: whether to serialize the components as well (defaults to False) 

1720 :param serialize_args: `dict` of arguments to pass to each component's serialize method 

1721 :param serialize_kwargs: `dict` of keyword arguments to pass to each component's serialize method 

1722 :returns: a `dict` representation of the `System` object 

1723 """ 

1724 serialize_args = serialize_args or dict() 

1725 serialize_kwargs = serialize_kwargs or dict() 

1726 d = {} 

1727 for key, value in self.__dict__.items(): 

1728 if value is not None and not key.startswith('_'): 

1729 if key == 'components' and not keep_components: 

1730 d[key] = [comp.serialize(keep_yaml_objects=False, 

1731 serialize_args=serialize_args.get(comp.name), 

1732 serialize_kwargs=serialize_kwargs.get(comp.name)) 

1733 for comp in value] 

1734 elif key == 'train_history': 

1735 if len(value) > 0: 

1736 d[key] = value.serialize() 

1737 else: 

1738 if not isinstance(value, _builtin): 

1739 self.logger.warning(f"Attribute '{key}' of type '{type(value)}' may not be a builtin " 

1740 f"Python type. This may cause issues when saving/loading from file.") 

1741 d[key] = value 

1742 

1743 for key, value in self.model_extra.items(): 

1744 if isinstance(value, _builtin): 

1745 d[key] = value 

1746 

1747 return d 
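
# A round-trip sketch: `deserialize` is equivalent to `System(**d)`, so a system can be
# rebuilt from its serialized dict (assuming component fields validate from dict form):
#
#     d = system.serialize()
#     system2 = System.deserialize(d)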

1748 

1749 @classmethod 

1750 def deserialize(cls, serialized_data: dict) -> System: 

1751 """Construct a `System` object from serialized data.""" 

1752 return cls(**serialized_data) 

1753 

1754 @staticmethod 

1755 def _yaml_representer(dumper: yaml.Dumper, system: System): 

1756 """Serialize a `System` object to a YAML representation.""" 

1757 return dumper.represent_mapping(System.yaml_tag, system.serialize(keep_components=True)) 

1758 

1759 @staticmethod 

1760 def _yaml_constructor(loader: yaml.Loader, node): 

1761 """Convert the `!SystemSurrogate` tag in yaml to a [`System`][amisc.system.System] object.""" 

1762 if isinstance(node, yaml.SequenceNode): 

1763 return [ele if isinstance(ele, System) else System.deserialize(ele) 

1764 for ele in loader.construct_sequence(node, deep=True)] 

1765 elif isinstance(node, yaml.MappingNode): 

1766 return System.deserialize(loader.construct_mapping(node, deep=True)) 

1767 else: 

1768 raise NotImplementedError(f'The "{System.yaml_tag}" yaml tag can only be used on a yaml sequence or ' 

1769 f'mapping, not a "{type(node)}".')
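
# In YAML, a serialized system appears as a mapping under `System.yaml_tag` (the
# `!SystemSurrogate` tag named in the constructor docstring above). An illustrative
# snippet; the actual mapping keys come from `serialize()`:
#
#     !SystemSurrogate
#     name: MySystem
#     components:
#       - ...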