Coverage for src/amisc/system.py: 87%

1006 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-01-24 04:51 +0000

1"""The `System` object is a framework for multidisciplinary models. It manages multiple single discipline component 

2models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates. 

3 

4Features: 

5 

6- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections 

7- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver with anderson acceleration 

8- Top-level interface for training and using surrogates of each component model 

9- Adaptive experimental design for choosing training data efficiently 

10- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates 

11- Detailed logging and traceback information 

12- Supports parallel or vectorized execution of component models 

13- Abstract and flexible interfacing with component models 

14- Easy serialization and deserialization to/from YAML files 

15- Supports approximating field quantities via compression 

16 

17Includes: 

18 

19- `TrainHistory` — a history of training iterations for the system surrogate 

20- `System` — the top-level object for managing multidisciplinary models 

21""" 

22# ruff: noqa: E702 

23from __future__ import annotations 

24 

25import copy 

26import datetime 

27import functools 

28import logging 

29import os 

30import pickle 

31import random 

32import string 

33import time 

34import warnings 

35from collections import ChainMap, UserList, deque 

36from concurrent.futures import ALL_COMPLETED, Executor, wait 

37from datetime import timezone 

38from pathlib import Path 

39from typing import Annotated, Callable, ClassVar, Literal, Optional 

40 

41import matplotlib.pyplot as plt 

42import networkx as nx 

43import numpy as np 

44import yaml 

45from matplotlib.ticker import MaxNLocator 

46from pydantic import BaseModel, ConfigDict, Field, field_validator 

47 

48from amisc.component import Component, IndexSet, MiscTree 

49from amisc.serialize import Serializable, _builtin 

50from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex, TrainIteration 

51from amisc.utils import ( 

52 _combine_latent_arrays, 

53 constrained_lls, 

54 format_inputs, 

55 format_outputs, 

56 get_logger, 

57 relative_error, 

58 to_model_dataset, 

59 to_surrogate_dataset, 

60) 

61from amisc.variable import VariableList 

62 

63__all__ = ['TrainHistory', 'System'] 

64 

65 

66class TrainHistory(UserList, Serializable): 

67 """Stores the training history of a system surrogate as a list of `TrainIteration` objects.""" 

68 

69 def __init__(self, data: list = None): 

70 data = data or [] 

71 super().__init__(self._validate_data(data)) 

72 

73 def serialize(self) -> list[dict]: 

74 """Return a list of each result in the history serialized to a `dict`.""" 

75 ret_list = [] 

76 for res in self: 

77 new_res = res.copy() 

78 new_res['alpha'] = str(res['alpha']) 

79 new_res['beta'] = str(res['beta']) 

80 ret_list.append(new_res) 

81 return ret_list 

82 

83 @classmethod 

84 def deserialize(cls, serialized_data: list[dict]) -> TrainHistory: 

85 """Deserialize a list of `dict` objects into a `TrainHistory` object.""" 

86 return TrainHistory(serialized_data) 

87 

88 @classmethod 

89 def _validate_data(cls, data: list[dict]) -> list[TrainIteration]: 

90 return [cls._validate_item(item) for item in data] 

91 

92 @classmethod 

93 def _validate_item(cls, item: dict): 

94 """Format a `TrainIteration` `dict` item before appending to the history.""" 

95 item.setdefault('test_error', None) 

96 item.setdefault('overhead_s', 0.0) 

97 item.setdefault('model_s', 0.0) 

98 item['alpha'] = MultiIndex(item['alpha']) 

99 item['beta'] = MultiIndex(item['beta']) 

100 item['num_evals'] = int(item['num_evals']) 

101 item['added_cost'] = float(item['added_cost']) 

102 item['added_error'] = float(item['added_error']) 

103 item['overhead_s'] = float(item['overhead_s']) 

104 item['model_s'] = float(item['model_s']) 

105 return item 

106 

107 def append(self, item: dict): 

108 super().append(self._validate_item(item)) 

109 

110 def __add__(self, other): 

111 other_list = other.data if isinstance(other, TrainHistory) else other 

112 return TrainHistory(data=self.data + other_list) 

113 

114 def extend(self, items): 

115 super().extend([self._validate_item(item) for item in items]) 

116 

117 def insert(self, index, item): 

118 super().insert(index, self._validate_item(item)) 

119 

120 def __setitem__(self, key, value): 

121 super().__setitem__(key, self._validate_item(value)) 

122 

123 def __eq__(self, other): 

124 """Two `TrainHistory` objects are equal if they have the same length and all items are equal, excluding nans.""" 

125 if not isinstance(other, TrainHistory): 

126 return False 

127 if len(self) != len(other): 

128 return False 

129 for item_self, item_other in zip(self, other): 

130 for key in item_self: 

131 if key in item_other: 

132 val_self = item_self[key] 

133 val_other = item_other[key] 

134 if isinstance(val_self, float) and isinstance(val_other, float): 

135 if not (np.isnan(val_self) and np.isnan(val_other)) and val_self != val_other: 

136 return False 

137 elif isinstance(val_self, dict) and isinstance(val_other, dict): 

138 for v, err in val_self.items(): 

139 if v in val_other: 

140 err_other = val_other[v] 

141 if not (np.isnan(err) and np.isnan(err_other)) and err != err_other: 

142 return False 

143 else: 

144 return False 

145 elif val_self != val_other: 

146 return False 

147 else: 

148 return False 

149 return True 

150 

151 

152class _Converged: 

153 """Helper class to track which samples have converged during `System.predict()`.""" 

154 def __init__(self, num_samples): 

155 self.num_samples = num_samples 

156 self.valid_idx = np.full(num_samples, True) # All samples are valid by default 

157 self.converged_idx = np.full(num_samples, False) # For FPI convergence 

158 

159 def reset_convergence(self): 

160 self.converged_idx = np.full(self.num_samples, False) 

161 

162 @property 

163 def curr_idx(self): 

164 return np.logical_and(self.valid_idx, ~self.converged_idx) 

165 

166 

167def _merge_shapes(target_shape, arr): 

168 """Helper to merge an array into the target shape.""" 

169 shape1, shape2 = target_shape, arr.shape 

170 if len(shape2) > len(shape1): 

171 shape1, shape2 = shape2, shape1 

172 result = [] 

173 for i in range(len(shape1)): 

174 if i < len(shape2): 

175 if shape1[i] == 1: 

176 result.append(shape2[i]) 

177 elif shape2[i] == 1: 

178 result.append(shape1[i]) 

179 else: 

180 result.append(shape1[i]) 

181 else: 

182 result.append(1) 

183 arr = arr.reshape(tuple(result)) 

184 return np.broadcast_to(arr, target_shape).copy() 

185 

186 

187class System(BaseModel, Serializable): 

188 """ 

189 Multidisciplinary (MD) surrogate framework top-level class. Construct a `System` from a list of 

190 `Component` models. 

191 

192 !!! Example 

193 ```python 

194 def f1(x): 

195 y = x ** 2 

196 return y 

197 def f2(y): 

198 z = y + 1 

199 return z 

200 

201 system = System(f1, f2) 

202 ``` 

203 

204 A `System` object can saved/loaded from `.yml` files using the `!System` yaml tag. 

205 

206 :ivar name: the name of the system 

207 :ivar components: list of `Component` models that make up the MD system 

208 :ivar train_history: history of training iterations for the system surrogate (filled in during training) 

209 

210 :ivar _root_dir: root directory where all surrogate build products are saved to file 

211 :ivar _logger: logger object for the system 

212 """ 

213 yaml_tag: ClassVar[str] = u'!System' 

214 model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True, 

215 extra='allow') 

216 

217 name: Annotated[str, Field(default_factory=lambda: "System_" + "".join(random.choices(string.digits, k=3)))] 

218 components: Callable | Component | list[Callable | Component] 

219 train_history: list[dict] | TrainHistory = TrainHistory() 

220 amisc_version: str = None 

221 

222 _root_dir: Optional[str] 

223 _logger: Optional[logging.Logger] = None 

224 

225 def __init__(self, /, *args, components=None, root_dir=None, **kwargs): 

226 """Construct a `System` object from a list of `Component` models in `*args` or `components`. If 

227 a `root_dir` is provided, then a new directory will be created under `root_dir` with the name 

228 `amisc_{timestamp}`. This directory will be used to save all build products and log files. 

229 

230 :param components: list of `Component` models that make up the MD system 

231 :param root_dir: root directory where all surrogate build products are saved to file (optional) 

232 """ 

233 if components is None: 

234 components = [] 

235 for a in args: 

236 if isinstance(a, Component) or callable(a): 

237 components.append(a) 

238 else: 

239 try: 

240 components.extend(a) 

241 except TypeError as e: 

242 raise ValueError(f"Invalid component: {a}") from e 

243 

244 import amisc 

245 amisc_version = kwargs.pop('amisc_version', amisc.__version__) 

246 super().__init__(components=components, amisc_version=amisc_version, **kwargs) 

247 self.root_dir = root_dir 

248 

249 def __repr__(self): 

250 s = f'---- {self.name} ----\n' 

251 s += f'amisc version: {self.amisc_version}\n' 

252 s += f'Refinement level: {self.refine_level}\n' 

253 s += f'Components: {", ".join([comp.name for comp in self.components])}\n' 

254 s += f'Inputs: {", ".join([var.name for var in self.inputs()])}\n' 

255 s += f'Outputs: {", ".join([var.name for var in self.outputs()])}' 

256 return s 

257 

258 def __str__(self): 

259 return self.__repr__() 

260 

261 @field_validator('components') 

262 @classmethod 

263 def _validate_components(cls, comps) -> list[Component]: 

264 if not isinstance(comps, list): 

265 comps = [comps] 

266 comps = [Component.deserialize(c) for c in comps] 

267 

268 # Merge all variables to avoid name conflicts 

269 merged_vars = VariableList.merge(*[comp.inputs for comp in comps], *[comp.outputs for comp in comps]) 

270 for comp in comps: 

271 comp.inputs.update({var.name: var for var in merged_vars.values() if var in comp.inputs}) 

272 comp.outputs.update({var.name: var for var in merged_vars.values() if var in comp.outputs}) 

273 

274 return comps 

275 

276 @field_validator('train_history') 

277 @classmethod 

278 def _validate_train_history(cls, history) -> TrainHistory: 

279 if isinstance(history, TrainHistory): 

280 return history 

281 else: 

282 return TrainHistory.deserialize(history) 

283 

284 def graph(self) -> nx.DiGraph: 

285 """Build a directed graph of the system components based on their input-output relationships.""" 

286 graph = nx.DiGraph() 

287 model_deps = {} 

288 for comp in self.components: 

289 graph.add_node(comp.name) 

290 for output in comp.outputs: 

291 model_deps[output] = comp.name 

292 for comp in self.components: 

293 for in_var in comp.inputs: 

294 if in_var in model_deps: 

295 graph.add_edge(model_deps[in_var], comp.name) 

296 

297 return graph 

298 

299 def _save_on_error(func): 

300 """Gracefully exit and save the `System` object on any errors.""" 

301 @functools.wraps(func) 

302 def wrap(self, *args, **kwargs): 

303 try: 

304 return func(self, *args, **kwargs) 

305 except: 

306 if self.root_dir is not None: 

307 self.save_to_file(f'{self.name}_error.yml') 

308 self.logger.critical(f'An error occurred during execution of "{func.__name__}". Saving ' 

309 f'System object to {self.name}_error.yml', exc_info=True) 

310 self.logger.info(f'Final system surrogate on exit: \n {self}') 

311 raise 

312 return wrap 

313 _save_on_error = staticmethod(_save_on_error) 

314 

315 def insert_components(self, components: list | Callable | Component): 

316 """Insert new components into the system.""" 

317 components = components if isinstance(components, list) else [components] 

318 self.components = self.components + components 

319 

320 def swap_component(self, old_component: str | Component, new_component: Callable | Component): 

321 """Replace an old component with a new component.""" 

322 old_name = old_component if isinstance(old_component, str) else old_component.name 

323 comps = [comp if comp.name != old_name else new_component for comp in self.components] 

324 self.components = comps 

325 

326 def remove_component(self, component: str | Component): 

327 """Remove a component from the system.""" 

328 comp_name = component if isinstance(component, str) else component.name 

329 self.components = [comp for comp in self.components if comp.name != comp_name] 

330 

331 def inputs(self) -> VariableList: 

332 """Collect all inputs from each component in the `System` and combine them into a 

333 single [`VariableList`][amisc.variable.VariableList] object, excluding variables that are also outputs of 

334 any component. 

335 

336 :returns: A [`VariableList`][amisc.variable.VariableList] containing all inputs from the components. 

337 """ 

338 all_inputs = ChainMap(*[comp.inputs for comp in self.components]) 

339 return VariableList({k: all_inputs[k] for k in all_inputs.keys() - self.outputs().keys()}) 

340 

341 def outputs(self) -> VariableList: 

342 """Collect all outputs from each component in the `System` and combine them into a 

343 single [`VariableList`][amisc.variable.VariableList] object. 

344 

345 :returns: A [`VariableList`][amisc.variable.VariableList] containing all outputs from the components. 

346 """ 

347 return VariableList({k: v for k, v in ChainMap(*[comp.outputs for comp in self.components]).items()}) 

348 

349 def coupling_variables(self) -> VariableList: 

350 """Collect all coupling variables from each component in the `System` and combine them into a 

351 single [`VariableList`][amisc.variable.VariableList] object. 

352 

353 :returns: A [`VariableList`][amisc.variable.VariableList] containing all coupling variables from the components. 

354 """ 

355 all_outputs = self.outputs() 

356 return VariableList({k: all_outputs[k] for k in (all_outputs.keys() & 

357 ChainMap(*[comp.inputs for comp in self.components]).keys())}) 

358 

359 def variables(self): 

360 """Iterator over all variables in the system (inputs and outputs).""" 

361 yield from ChainMap(self.inputs(), self.outputs()).values() 

362 

363 @property 

364 def refine_level(self) -> int: 

365 """The total number of training iterations.""" 

366 return len(self.train_history) 

367 

368 @property 

369 def logger(self) -> logging.Logger: 

370 return self._logger 

371 

372 @logger.setter 

373 def logger(self, logger: logging.Logger): 

374 self._logger = logger 

375 for comp in self.components: 

376 comp.logger = logger 

377 

378 @staticmethod 

379 def timestamp() -> str: 

380 """Return a UTC timestamp string in the isoformat `YYYY-MM-DDTHH.MM.SS`.""" 

381 return datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.') 

382 

383 @property 

384 def root_dir(self): 

385 """Return the root directory of the surrogate (if available), otherwise `None`.""" 

386 return Path(self._root_dir) if self._root_dir is not None else None 

387 

388 @root_dir.setter 

389 def root_dir(self, root_dir: str | Path): 

390 """Set the root directory for all build products. If `root_dir` is `None`, then no products will be saved. 

391 Otherwise, log files, model outputs, surrogate files, etc. will be saved under this directory. 

392 

393 !!! Note "`amisc` root directory" 

394 If `root_dir` is not `None`, then a new directory will be created under `root_dir` with the name 

395 `amisc_{timestamp}`. This directory will be used to save all build products. If `root_dir` matches the 

396 `amisc_*` format, then it will be used directly. 

397 

398 :param root_dir: the root directory for all build products 

399 """ 

400 if root_dir is not None: 

401 parts = Path(root_dir).resolve().parts 

402 if parts[-1].startswith('amisc_'): 

403 self._root_dir = Path(root_dir).resolve().as_posix() 

404 if not self.root_dir.is_dir(): 

405 os.mkdir(self.root_dir) 

406 else: 

407 root_dir = Path(root_dir) / ('amisc_' + self.timestamp()) 

408 os.mkdir(root_dir) 

409 self._root_dir = Path(root_dir).resolve().as_posix() 

410 

411 log_file = None 

412 if not (pth := self.root_dir / 'surrogates').is_dir(): 

413 os.mkdir(pth) 

414 if not (pth := self.root_dir / 'components').is_dir(): 

415 os.mkdir(pth) 

416 for comp in self.components: 

417 if comp.model_kwarg_requested('output_path'): 

418 if not (comp_pth := pth / comp.name).is_dir(): 

419 os.mkdir(comp_pth) 

420 for f in os.listdir(self.root_dir): 

421 if f.endswith('.log'): 

422 log_file = (self.root_dir / f).resolve().as_posix() 

423 break 

424 if log_file is None: 

425 log_file = (self.root_dir / f'amisc_{self.timestamp()}.log').resolve().as_posix() 

426 self.set_logger(log_file=log_file) 

427 

428 else: 

429 self._root_dir = None 

430 self.set_logger(log_file=None) 

431 

432 def set_logger(self, log_file: str | Path | bool = None, stdout: bool = None, logger: logging.Logger = None, 

433 level: int = logging.INFO): 

434 """Set a new `logging.Logger` object. 

435 

436 :param log_file: log to this file if str or Path (defaults to whatever is currently set or empty); 

437 set `False` to remove file logging or set `True` to create a default log file in the root dir 

438 :param stdout: whether to connect the logger to console (defaults to whatever is currently set or `False`) 

439 :param logger: the logging object to use (this will override the `log_file` and `stdout` arguments if set); 

440 if `None`, then a new logger is created according to `log_file` and `stdout` 

441 :param level: the logging level to set the logger to (defaults to `logging.INFO`) 

442 """ 

443 # Decide whether to use stdout 

444 if stdout is None: 

445 stdout = False 

446 if self._logger is not None: 

447 for handler in self._logger.handlers: 

448 if isinstance(handler, logging.StreamHandler): 

449 stdout = True 

450 break 

451 

452 # Decide what log_file to use (if any) 

453 if log_file is True: 

454 log_file = pth / f'amisc_{self.timestamp()}.log' if (pth := self.root_dir) is not None else ( 

455 f'amisc_{self.timestamp()}.log') 

456 elif log_file is None: 

457 if self._logger is not None: 

458 for handler in self._logger.handlers: 

459 if isinstance(handler, logging.FileHandler): 

460 log_file = handler.baseFilename 

461 break 

462 elif log_file is False: 

463 log_file = None 

464 

465 self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level) 

466 

467 for comp in self.components: 

468 comp.set_logger(log_file=log_file, stdout=stdout, logger=logger, level=level) 

469 

470 def sample_inputs(self, size: tuple | int, 

471 component: str = 'System', 

472 normalize: bool = True, 

473 use_pdf: bool | str | list[str] = False, 

474 include: str | list[str] = None, 

475 exclude: str | list[str] = None, 

476 nominal: dict[str, float] = None) -> Dataset: 

477 """Return samples of the inputs according to provided options. Will return samples in the 

478 normalized/compressed space of the surrogate by default. See [`to_model_dataset`][amisc.utils.to_model_dataset] 

479 to convert the samples to be usable by the true model directly. 

480 

481 :param size: tuple or integer specifying shape or number of samples to obtain 

482 :param component: which component to sample inputs for (defaults to full system exogenous inputs) 

483 :param normalize: whether to normalize the samples (defaults to True) 

484 :param use_pdf: whether to sample from variable pdfs (defaults to False, which will instead sample from the 

485 variable domain bounds). If a string or list of strings is provided, then only those variables 

486 or variable categories will be sampled using their pdfs. 

487 :param include: a list of variable or variable categories to include in the sampling. Defaults to using all 

488 input variables. 

489 :param exclude: a list of variable or variable categories to exclude from the sampling. Empty by default. 

490 :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty. Specify nominal 

491 values as unnormalized (will be normalized if `normalize=True`) 

492 :returns: `dict` of `(*size,)` samples for each selected input variable 

493 """ 

494 size = (size, ) if isinstance(size, int) else size 

495 nominal = nominal or dict() 

496 inputs = self.inputs() if component == 'System' else self[component].inputs 

497 if include is None: 

498 include = [] 

499 if not isinstance(include, list): 

500 include = [include] 

501 if exclude is None: 

502 exclude = [] 

503 if not isinstance(exclude, list): 

504 exclude = [exclude] 

505 if isinstance(use_pdf, str): 

506 use_pdf = [use_pdf] 

507 

508 selected_inputs = [] 

509 for var in inputs: 

510 if len(include) == 0 or var.name in include or var.category in include: 

511 if var.name not in exclude and var.category not in exclude: 

512 selected_inputs.append(var) 

513 

514 samples = {} 

515 for var in selected_inputs: 

516 # Sample from latent variable domains for field quantities 

517 if var.compression is not None: 

518 latent = var.sample_domain(size) 

519 for i in range(latent.shape[-1]): 

520 samples[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i] 

521 

522 # Sample scalars normally 

523 else: 

524 lb, ub = var.get_domain() 

525 pdf = (var.name in use_pdf or var.category in use_pdf) if isinstance(use_pdf, list) else use_pdf 

526 nom = nominal.get(var.name, None) 

527 

528 x_sample = var.sample(size, nominal=nom) if pdf else var.sample_domain(size) 

529 good_idx = (x_sample < ub) & (x_sample > lb) 

530 num_reject = np.sum(~good_idx) 

531 

532 while num_reject > 0: 

533 new_sample = var.sample((num_reject,), nominal=nom) if pdf else var.sample_domain((num_reject,)) 

534 x_sample[~good_idx] = new_sample 

535 good_idx = (x_sample < ub) & (x_sample > lb) 

536 num_reject = np.sum(~good_idx) 

537 

538 samples[var.name] = var.normalize(x_sample) if normalize else x_sample 

539 

540 return samples 

541 

542 def simulate_fit(self): 

543 """Loop back through training history and simulate each iteration. Will yield the internal data structures 

544 of each `Component` surrogate after each iteration of training (without needing to call `fit()` or any 

545 of the underlying models). This might be useful, for example, for computing the surrogate predictions on 

546 a new test set or viewing cumulative training costs. 

547 

548 !!! Example 

549 Say you have a new test set: `(new_xtest, new_ytest)`, and you want to compute the accuracy of the 

550 surrogate fit at each iteration of the training history: 

551 

552 ```python 

553 for train_iter, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test in system.simulate_fit(): 

554 # Do something with the surrogate data structures 

555 new_ysurr = system.predict(new_xtest, index_set=active_sets, misc_coeff=misc_coeff_train) 

556 train_error = relative_error(new_ysurr, new_ytest) 

557 ``` 

558 

559 :return: a generator of the active index sets, candidate index sets, and MISC coefficients 

560 of each component model at each iteration of the training history 

561 """ 

562 # "Simulated" data structures for each component 

563 active_sets = {comp.name: IndexSet() for comp in self.components} # active index sets for each component 

564 candidate_sets = {comp.name: IndexSet() for comp in self.components} # candidate sets for each component 

565 misc_coeff_train = {comp.name: MiscTree() for comp in self.components} # MISC coeff for active sets 

566 misc_coeff_test = {comp.name: MiscTree() for comp in self.components} # MISC coeff for active + candidate sets 

567 

568 for train_result in self.train_history: 

569 # The selected refinement component and indices 

570 comp_star = train_result['component'] 

571 alpha_star = train_result['alpha'] 

572 beta_star = train_result['beta'] 

573 comp = self[comp_star] 

574 

575 # Get forward neighbors for the selected index 

576 neighbors = comp._neighbors(alpha_star, beta_star, active_set=active_sets[comp_star], forward=True) 

577 

578 # "Activate" the index in the simulated data structure 

579 s = set() 

580 s.add((alpha_star, beta_star)) 

581 comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star], 

582 misc_coeff=misc_coeff_train[comp_star]) 

583 

584 if (alpha_star, beta_star) in candidate_sets[comp_star]: 

585 candidate_sets[comp_star].remove((alpha_star, beta_star)) 

586 else: 

587 # Only for initial index which didn't come from the candidate set 

588 comp.update_misc_coeff(IndexSet(s), index_set=active_sets[comp_star].union(candidate_sets[comp_star]), 

589 misc_coeff=misc_coeff_test[comp_star]) 

590 active_sets[comp_star].update(s) 

591 

592 comp.update_misc_coeff(neighbors, index_set=active_sets[comp_star].union(candidate_sets[comp_star]), 

593 misc_coeff=misc_coeff_test[comp_star]) # neighbors will only ever pass here once 

594 candidate_sets[comp_star].update(neighbors) 

595 

596 # Caller can now do whatever they want as if the system surrogate were at this training iteration 

597 # See the "index_set" and "misc_coeff" overrides for `System.predict()` for example 

598 yield train_result, active_sets, candidate_sets, misc_coeff_train, misc_coeff_test 

599 

600 def add_output(self): 

601 """Add an output variable retroactively to a component surrogate. User should provide a callable that 

602 takes a save path and extracts the model output data for given training point/location. 

603 """ 

604 # TODO 

605 # Loop back through the surrogate training history 

606 # Simulate activate_index and extract the model output from file rather than calling the model 

607 # Update all interpolator states 

608 raise NotImplementedError 

609 

610 @_save_on_error 

611 def fit(self, targets: list = None, 

612 num_refine: int = 100, 

613 max_iter: int = 20, 

614 max_tol: float = 1e-3, 

615 runtime_hr: float = 1., 

616 estimate_bounds: bool = False, 

617 update_bounds: bool = True, 

618 test_set: tuple | str | Path = None, 

619 start_test_check: int = None, 

620 save_interval: int = 0, 

621 plot_interval: int = 1, 

622 cache_interval: int = 0, 

623 executor: Executor = None, 

624 weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'): 

625 """Train the system surrogate adaptively by iterative refinement until an end condition is met. 

626 

627 :param targets: list of system output variables to focus refinement on, use all outputs if not specified 

628 :param num_refine: number of input samples to compute error indicators on 

629 :param max_iter: the maximum number of refinement steps to take 

630 :param max_tol: the max allowable value in relative L2 error to achieve 

631 :param runtime_hr: the threshold wall clock time (hr) at which to stop further refinement (will go 

632 until all models finish the current iteration) 

633 :param estimate_bounds: whether to estimate bounds for the coupling variables; will only try to estimate from 

634 the `test_set` if provided (defaults to `True`). Otherwise, you should manually 

635 provide domains for all coupling variables. 

636 :param update_bounds: whether to continuously update coupling variable bounds during refinement 

637 :param test_set: `tuple` of `(xtest, ytest)` to show convergence of surrogate to the true model. The test set 

638 inputs and outputs are specified as `dicts` of `np.ndarrays` with keys corresponding to the 

639 variable names. Can also pass a path to a `.pkl` file that has the test set data as 

640 {'test_set': (xtest, ytest)}. 

641 :param start_test_check: the iteration to start checking the test set error (defaults to the number 

642 of components); surrogate evaluation isn't useful during initialization so you 

643 should at least allow one iteration per component before checking test set error 

644 :param save_interval: number of refinement steps between each progress save, none if 0; `System.root_dir` 

645 must be specified to save to file 

646 :param plot_interval: how often to plot the error indicator and test set error (defaults to every iteration); 

647 will only plot and save to file if a root directory is set 

648 :param cache_interval: how often to cache component data in order to speed up future training iterations (at 

649 the cost of additional memory usage); defaults to 0 (no caching) 

650 :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations (optional, but 

651 recommended for expensive models) 

652 :param weight_fcns: a `dict` of weight functions to apply to each input variable for training data selection; 

653 defaults to using the pdf of each variable. If None, then no weighting is applied. 

654 """ 

655 start_test_check = start_test_check or sum([1 for _ in self.components if _.has_surrogate]) 

656 targets = targets or self.outputs() 

657 xtest, ytest = self._get_test_set(test_set) 

658 max_iter = self.refine_level + max_iter 

659 

660 # Estimate bounds from test set if provided (override current bounds if they are set) 

661 if estimate_bounds: 

662 if ytest is not None: 

663 y_samples = to_surrogate_dataset(ytest, self.outputs(), del_fields=True)[0] # normalize/compress 

664 _combine_latent_arrays(y_samples) 

665 coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_samples} 

666 y_min, y_max = {}, {} 

667 for var in coupling_vars.values(): 

668 y_min[var] = np.nanmin(y_samples[var], axis=0) 

669 y_max[var] = np.nanmax(y_samples[var], axis=0) 

670 if var.compression is not None: 

671 new_domain = list(zip(y_min[var].tolist(), y_max[var].tolist())) 

672 var.update_domain(new_domain, override=True) 

673 else: 

674 new_domain = (float(y_min[var]), float(y_max[var])) 

675 var.update_domain(var.denormalize(new_domain), override=True) 

676 del y_samples 

677 else: 

678 self.logger.warning('Could not estimate bounds for coupling variables: no test set provided. ' 

679 'Make sure you manually provide (good) coupling variable domains.') 

680 

681 # Track convergence progress on the error indicator and test set (plot to file) 

682 if self.root_dir is not None: 

683 err_record = [res['added_error'] for res in self.train_history] 

684 err_fig, err_ax = plt.subplots(figsize=(6, 5), layout='tight') 

685 

686 if xtest is not None and ytest is not None: 

687 num_plot = min(len(targets), 3) 

688 test_record = np.full((self.refine_level, num_plot), np.nan) 

689 t_fig, t_ax = plt.subplots(1, num_plot, figsize=(3.5 * num_plot, 4), layout='tight', squeeze=False, 

690 sharey='row') 

691 for j, res in enumerate(self.train_history): 

692 for i, var in enumerate(targets[:num_plot]): 

693 if (perf := res.get('test_error')) is not None: 

694 test_record[j, i] = perf[var] 

695 

696 total_overhead = 0.0 

697 total_model_wall_time = 0.0 

698 t_start = time.time() 

699 while True: 

700 # Adaptive refinement step 

701 t_iter_start = time.time() 

702 train_result = self.refine(targets=targets, num_refine=num_refine, update_bounds=update_bounds, 

703 executor=executor, weight_fcns=weight_fcns) 

704 if train_result['component'] is None: 

705 self._print_title_str('Termination criteria reached: No candidates left to refine') 

706 break 

707 

708 # Keep track of algorithmic overhead (before and after call_model for this iteration) 

709 m_start, m_end = self[train_result['component']].get_model_timestamps() # Start and end of call_model 

710 if m_start is not None and m_end is not None: 

711 train_result['overhead_s'] = (m_start - t_iter_start) + (time.time() - m_end) 

712 train_result['model_s'] = m_end - m_start 

713 else: 

714 train_result['overhead_s'] = time.time() - t_iter_start 

715 train_result['model_s'] = 0.0 

716 total_overhead += train_result['overhead_s'] 

717 total_model_wall_time += train_result['model_s'] 

718 

719 curr_error = train_result['added_error'] 

720 

721 # Plot progress of error indicator 

722 if self.root_dir is not None: 

723 err_record.append(curr_error) 

724 

725 if plot_interval > 0 and self.refine_level % plot_interval == 0: 

726 err_ax.clear(); err_ax.set_yscale('log'); err_ax.grid() 

727 err_ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 

728 err_ax.plot(err_record, '-k') 

729 err_ax.set_xlabel('Iteration'); err_ax.set_ylabel('Relative error indicator') 

730 err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.pdf'), format='pdf', bbox_inches='tight') 

731 

732 # Save performance on a test set 

733 if xtest is not None and ytest is not None: 

734 # don't compute if components are uninitialized 

735 perf = self.test_set_performance(xtest, ytest) if self.refine_level + 1 >= start_test_check else ( 

736 {str(var): np.nan for var in ytest if COORDS_STR_ID not in var}) 

737 train_result['test_error'] = perf.copy() 

738 

739 if self.root_dir is not None: 

740 test_record = np.vstack((test_record, np.array([perf[var] for var in targets[:num_plot]]))) 

741 

742 if plot_interval > 0 and self.refine_level % plot_interval == 0: 

743 for i in range(num_plot): 

744 with warnings.catch_warnings(): 

745 warnings.simplefilter("ignore", UserWarning) 

746 t_ax[0, i].clear(); t_ax[0, i].set_yscale('log'); t_ax[0, i].grid() 

747 t_ax[0, i].xaxis.set_major_locator(MaxNLocator(integer=True)) 

748 t_ax[0, i].plot(test_record[:, i], '-k') 

749 t_ax[0, i].set_title(self.outputs()[targets[i]].get_tex(units=True)) 

750 t_ax[0, i].set_xlabel('Iteration') 

751 t_ax[0, i].set_ylabel('Test set relative error' if i==0 else '') 

752 t_fig.savefig(str(Path(self.root_dir) / 'test_set_error.pdf'),format='pdf',bbox_inches='tight') 

753 

754 self.train_history.append(train_result) 

755 

756 if self.root_dir is not None and save_interval > 0 and self.refine_level % save_interval == 0: 

757 iter_name = f'{self.name}_iter{self.refine_level}' 

758 if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir(): 

759 os.mkdir(pth) 

760 self.save_to_file(f'{iter_name}.yml', save_dir=pth) # Save to an iteration-specific directory 

761 

762 if cache_interval > 0 and self.refine_level % cache_interval == 0: 

763 for comp in self.components: 

764 comp.cache() 

765 

766 # Check all end conditions 

767 if self.refine_level >= max_iter: 

768 self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}') 

769 break 

770 if curr_error < max_tol: 

771 self._print_title_str(f'Termination criteria reached: relative error {curr_error} < tol {max_tol}') 

772 break 

773 if ((time.time() - t_start) / 3600.0) >= runtime_hr: 

774 t_end = time.time() 

775 actual = datetime.timedelta(seconds=t_end - t_start) 

776 target = datetime.timedelta(seconds=runtime_hr * 3600) 

777 train_surplus = ((t_end - t_start) - runtime_hr * 3600) / 3600 

778 self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}') 

779 self.logger.info(f'Surplus wall time: {train_surplus:.3f}/{runtime_hr:.3f} hours ' 

780 f'(+{100 * train_surplus / runtime_hr:.2f}%)') 

781 break 

782 

783 self.logger.info(f'Model evaluation algorithm efficiency: ' 

784 f'{100 * total_model_wall_time / (total_model_wall_time + total_overhead):.2f}%') 

785 

786 if self.root_dir is not None: 

787 iter_name = f'{self.name}_iter{self.refine_level}' 

788 if not (pth := self.root_dir / 'surrogates' / iter_name).is_dir(): 

789 os.mkdir(pth) 

790 self.save_to_file(f'{iter_name}.yml', save_dir=pth) 

791 

792 if xtest is not None and ytest is not None: 

793 self._save_test_set((xtest, ytest)) 

794 

795 self.logger.info(f'Final system surrogate: \n {self}') 

796 

797 def test_set_performance(self, xtest: Dataset, ytest: Dataset, index_set='test') -> Dataset: 

798 """Compute the relative L2 error on a test set for the given target output variables. 

799 

800 :param xtest: `dict` of test set input samples (unnormalized) 

801 :param ytest: `dict` of test set output samples (unnormalized) 

802 :param index_set: index set to use for prediction (defaults to 'train') 

803 :returns: `dict` of relative L2 errors for each target output variable 

804 """ 

805 targets = [var for var in ytest.keys() if COORDS_STR_ID not in var and var in self.outputs()] 

806 coords = {var: ytest[var] for var in ytest if COORDS_STR_ID in var} 

807 xtest = to_surrogate_dataset(xtest, self.inputs(), del_fields=True)[0] 

808 ysurr = self.predict(xtest, index_set=index_set, targets=targets) 

809 ysurr = to_model_dataset(ysurr, self.outputs(), del_latent=True, **coords)[0] 

810 perf = {} 

811 for var in targets: 

812 # Handle relative error for object arrays (field qtys) 

813 ytest_obj = np.issubdtype(ytest[var].dtype, np.object_) 

814 ysurr_obj = np.issubdtype(ysurr[var].dtype, np.object_) 

815 if ytest_obj or ysurr_obj: 

816 _iterable = np.ndindex(ytest[var].shape) if ytest_obj else np.ndindex(ysurr[var].shape) 

817 num, den = [], [] 

818 for index in _iterable: 

819 pred, targ = ysurr[var][index], ytest[var][index] 

820 num.append((pred - targ)**2) 

821 den.append(targ ** 2) 

822 perf[var] = float(np.sqrt(sum([np.sum(n) for n in num]) / sum([np.sum(d) for d in den]))) 

823 else: 

824 perf[var] = float(relative_error(ysurr[var], ytest[var])) 

825 

826 return perf 

827 

828 def refine(self, targets: list = None, num_refine: int = 100, update_bounds: bool = True, executor: Executor = None, 

829 weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf') -> TrainIteration: 

830 """Perform a single adaptive refinement step on the system surrogate. 

831 

832 :param targets: list of system output variables to focus refinement on, use all outputs if not specified 

833 :param num_refine: number of input samples to compute error indicators on 

834 :param update_bounds: whether to continuously update coupling variable bounds during refinement 

835 :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations 

836 :param weight_fcns: weight functions for choosing new training data for each input variable; defaults to 

837 the PDFs of each variable. If None, then no weighting is applied. 

838 :returns: `dict` of the refinement results indicating the chosen component and candidate index 

839 """ 

840 self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}') 

841 targets = targets or self.outputs() 

842 

843 # Check for uninitialized components and refine those first 

844 for comp in self.components: 

845 if len(comp.active_set) == 0 and comp.has_surrogate: 

846 alpha_star = (0,) * len(comp.model_fidelity) 

847 beta_star = (0,) * len(comp.max_beta) 

848 self.logger.info(f"Initializing component {comp.name}: adding {(alpha_star, beta_star)} to active set") 

849 model_dir = (pth / 'components' / comp.name) if (pth := self.root_dir) is not None else None 

850 comp.activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor, 

851 weight_fcns=weight_fcns) 

852 num_evals = comp.get_cost(alpha_star, beta_star) 

853 cost_star = comp.model_costs.get(alpha_star, 1.) * num_evals # Cpu time (s) 

854 err_star = np.nan 

855 return {'component': comp.name, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals), 

856 'added_cost': float(cost_star), 'added_error': float(err_star)} 

857 

858 # Compute entire integrated-surrogate on a random test set for global system QoI error estimation 

859 x_samples = self.sample_inputs(num_refine) 

860 y_curr = self.predict(x_samples, index_set='train', targets=targets) 

861 _combine_latent_arrays(y_curr) 

862 coupling_vars = {k: v for k, v in self.coupling_variables().items() if k in y_curr} 

863 

864 y_min, y_max = None, None 

865 if update_bounds: 

866 y_min = {var: np.nanmin(y_curr[var], axis=0, keepdims=True) for var in coupling_vars} # (1, ydim) 

867 y_max = {var: np.nanmax(y_curr[var], axis=0, keepdims=True) for var in coupling_vars} # (1, ydim) 

868 

869 # Find the candidate surrogate with the largest error indicator 

870 error_max, error_indicator = -np.inf, -np.inf 

871 comp_star, alpha_star, beta_star, err_star, cost_star = None, None, None, -np.inf, 0 

872 for comp in self.components: 

873 if not comp.has_surrogate: # Skip analytic models that don't need a surrogate 

874 continue 

875 

876 self.logger.info(f"Estimating error for component '{comp.name}'...") 

877 

878 if len(comp.candidate_set) > 0: 

879 candidates = list(comp.candidate_set) 

880 if executor is None: 

881 ret = [self.predict(x_samples, targets=targets, index_set={comp.name: {(alpha, beta)}}, 

882 incremental={comp.name: True}) 

883 for alpha, beta in candidates] 

884 else: 

885 temp_buffer = self._remove_unpickleable() 

886 futures = [executor.submit(self.predict, x_samples, targets=targets, 

887 index_set={comp.name: {(alpha, beta)}}, incremental={comp.name: True}) 

888 for alpha, beta in candidates] 

889 wait(futures, timeout=None, return_when=ALL_COMPLETED) 

890 ret = [f.result() for f in futures] 

891 self._restore_unpickleable(temp_buffer) 

892 

893 for i, y_cand in enumerate(ret): 

894 alpha, beta = candidates[i] 

895 _combine_latent_arrays(y_cand) 

896 error = {} 

897 for var, arr in y_cand.items(): 

898 if var in targets: 

899 error[var] = relative_error(arr, y_curr[var], skip_nan=True) 

900 

901 if update_bounds and var in coupling_vars: 

902 y_min[var] = np.nanmin(np.concatenate((y_min[var], arr), axis=0), axis=0, keepdims=True) 

903 y_max[var] = np.nanmax(np.concatenate((y_max[var], arr), axis=0), axis=0, keepdims=True) 

904 

905 delta_error = np.nanmax([np.nanmax(error[var]) for var in error]) # Max error over all target QoIs 

906 num_evals = comp.get_cost(alpha, beta) 

907 delta_work = comp.model_costs.get(alpha, 1.) * num_evals # Cpu time (s) 

908 error_indicator = delta_error / delta_work 

909 

910 self.logger.info(f"Candidate multi-index: {(alpha, beta)}. Relative error: {delta_error}. " 

911 f"Error indicator: {error_indicator}.") 

912 

913 if error_indicator > error_max: 

914 error_max = error_indicator 

915 comp_star, alpha_star, beta_star, err_star, cost_star = ( 

916 comp.name, alpha, beta, delta_error, delta_work) 

917 else: 

918 self.logger.info(f"Component '{comp.name}' has no available candidates left!") 

919 

920 # Update all coupling variable ranges 

921 if update_bounds: 

922 for var in coupling_vars.values(): 

923 if np.all(~np.isnan(y_min[var])) and np.all(~np.isnan(y_max[var])): 

924 if var.compression is not None: 

925 new_domain = list(zip(np.squeeze(y_min[var], axis=0).tolist(), 

926 np.squeeze(y_max[var], axis=0).tolist())) 

927 var.update_domain(new_domain) 

928 else: 

929 new_domain = (y_min[var][0], y_max[var][0]) 

930 var.update_domain(var.denormalize(new_domain)) # bds will be in norm space from predict() call 

931 

932 # Add the chosen multi-index to the chosen component 

933 if comp_star is not None: 

934 self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{comp_star}'.") 

935 model_dir = (pth / 'components' / comp_star) if (pth := self.root_dir) is not None else None 

936 self[comp_star].activate_index(alpha_star, beta_star, model_dir=model_dir, executor=executor, 

937 weight_fcns=weight_fcns) 

938 num_evals = self[comp_star].get_cost(alpha_star, beta_star) 

939 else: 

940 self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}") 

941 num_evals = 0 

942 

943 # Return the results of the refinement step 

944 return {'component': comp_star, 'alpha': alpha_star, 'beta': beta_star, 'num_evals': int(num_evals), 

945 'added_cost': float(cost_star), 'added_error': float(err_star)} 

946 

947 def predict(self, x: dict | Dataset, 

948 max_fpi_iter: int = 100, 

949 anderson_mem: int = 10, 

950 fpi_tol: float = 1e-10, 

951 use_model: str | tuple | dict = None, 

952 model_dir: str | Path = None, 

953 verbose: bool = False, 

954 index_set: dict[str: IndexSet | Literal['train', 'test']] = 'test', 

955 misc_coeff: dict[str: MiscTree] = None, 

956 normalized_inputs: bool = True, 

957 incremental: dict[str, bool] = False, 

958 targets: list[str] = None, 

959 executor: Executor = None, 

960 var_shape: dict[str, tuple] = None) -> Dataset: 

961 """Evaluate the system surrogate at inputs `x`. Return `y = system(x)`. 

962 

963 !!! Warning "Computing the true model with feedback loops" 

964 You can use this function to predict outputs for your MD system using the full-order models rather than the 

965 surrogate, by specifying `use_model`. This is convenient because the `System` manages all the 

966 coupled information flow between models automatically. However, it is *highly* recommended to not use 

967 the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using 

968 anything more computationally demanding than the surrogate. 

969 

970 :param x: `dict` of input samples for each variable in the system 

971 :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine 

972 :param anderson_mem: hyperparameter for tuning the convergence of FPI with anderson acceleration 

973 :param fpi_tol: tolerance limit for convergence of fixed-point iteration 

974 :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate, 

975 specify a `dict` of the above to assign different model fidelities for diff components 

976 :param model_dir: directory to save model outputs if `use_model` is specified 

977 :param verbose: whether to print out iteration progress during execution 

978 :param index_set: `dict(comp=[indices])` to override the active set for a component, defaults to using the 

979 `test` set for every component. Can also specify `train` for any component or a valid 

980 `IndexSet` object. If `incremental` is specified, will be overwritten with `train`. 

981 :param misc_coeff: `dict(comp=MiscTree)` to override the default coefficients for a component, passes through 

982 along with `index_set` and `incremental` to `comp.predict()`. 

983 :param normalized_inputs: true if the passed inputs are compressed/normalized for surrogate evaluation 

984 (default), such as inputs returned by `sample_inputs`. Set to `False` if you are 

985 passing inputs as the true models would expect them instead (i.e. not normalized). 

986 :param incremental: whether to add `index_set` to the current active set for each component (temporarily); 

987 this will set `index_set='train'` for all other components (since incremental will 

988 augment the "training" active sets, not the "testing" candidate sets) 

989 :param targets: list of output variables to return, defaults to returning all system outputs 

990 :param executor: a `concurrent.futures.Executor` object to parallelize model evaluations 

991 :param var_shape: (Optional) `dict` of shapes for field quantity inputs in `x` -- you would only specify this 

992 if passing field qtys directly to the models (i.e. not using `sample_inputs`) 

993 :returns: `dict` of output variables - the surrogate approximation of the system outputs (or the true model) 

994 """ 

995 # Format inputs and allocate space 

996 var_shape = var_shape or {} 

997 x, loop_shape = format_inputs(x, var_shape=var_shape) # {'x': (N, *var_shape)} 

998 y = {} 

999 all_inputs = ChainMap(x, y) # track all inputs (including coupling vars in y) 

1000 N = int(np.prod(loop_shape)) 

1001 t1 = 0 

1002 output_dir = None 

1003 norm_status = {var: normalized_inputs for var in x} # keep track of whether inputs are normalized or not 

1004 graph = self.graph() 

1005 

1006 # Keep track of what outputs are computed 

1007 is_computed = {} 

1008 for var in (targets or self.outputs()): 

1009 if (v := self.outputs().get(var, None)) is not None: 

1010 if v.compression is not None: 

1011 for field in v.compression.fields: 

1012 is_computed[field] = False 

1013 else: 

1014 is_computed[var] = False 

1015 

1016 def _set_default(struct: dict, default=None): 

1017 """Helper to set a default value for each component key in a `dict`. Ensures all components have a value.""" 

1018 if struct is not None: 

1019 if not isinstance(struct, dict): 

1020 struct = {node: struct for node in graph.nodes} # use same for each component 

1021 else: 

1022 struct = {node: default for node in graph.nodes} 

1023 return {node: struct.get(node, default) for node in graph.nodes} 

1024 

1025 # Ensure use_model, index_set, and incremental are specified for each component model 

1026 use_model = _set_default(use_model, None) 

1027 incremental = _set_default(incremental, False) # default to train if incremental anywhere 

1028 index_set = _set_default(index_set, 'train' if any([incremental[node] for node in graph.nodes]) else 'test') 

1029 misc_coeff = _set_default(misc_coeff, None) 

1030 

1031 samples = _Converged(N) # track convergence of samples 

1032 

1033 def _gather_comp_inputs(comp, coupling=None): 

1034 """Helper to gather inputs for a component, making sure they are normalized correctly. Any coupling 

1035 variables passed in will be used in preference over `all_inputs`. 

1036 """ 

1037 # Will access but not modify: all_inputs, use_model, norm_status 

1038 field_coords = {} 

1039 comp_input = {} 

1040 coupling = coupling or {} 

1041 

1042 # Take coupling variables as a priority 

1043 comp_input.update({var: np.copy(arr[samples.curr_idx, ...]) for var, arr in 

1044 coupling.items() if str(var).split(LATENT_STR_ID)[0] in comp.inputs}) 

1045 # Gather all other inputs 

1046 for var, arr in all_inputs.items(): 

1047 var_id = str(var).split(LATENT_STR_ID)[0] 

1048 if var_id in comp.inputs and var not in coupling: 

1049 comp_input[var] = np.copy(arr[samples.curr_idx, ...]) 

1050 

1051 # Gather field coordinates 

1052 for var in comp.inputs: 

1053 coords_str = f'{var}{COORDS_STR_ID}' 

1054 if (coords := all_inputs.get(coords_str)) is not None: 

1055 field_coords[coords_str] = coords[samples.curr_idx, ...] 

1056 elif (coords := comp.model_kwargs.get(coords_str)) is not None: 

1057 field_coords[coords_str] = coords 

1058 

1059 # Gather extra fields (will never be in coupling since field couplings should always be latent coeff) 

1060 for var in comp.inputs: 

1061 if var not in comp_input and var.compression is not None: 

1062 for field in var.compression.fields: 

1063 if field in all_inputs: 

1064 comp_input[field] = np.copy(all_inputs[field][samples.curr_idx, ...]) 

1065 

1066 call_model = use_model.get(comp.name, None) is not None 

1067 

1068 # Make sure we format all inputs for model evaluation (i.e. denormalize) 

1069 if call_model: 

1070 norm_inputs = {var: arr for var, arr in comp_input.items() if norm_status[var]} 

1071 if len(norm_inputs) > 0: 

1072 denorm_inputs, fc = to_model_dataset(norm_inputs, comp.inputs, del_latent=True, **field_coords) 

1073 for var in norm_inputs: 

1074 del comp_input[var] 

1075 field_coords.update(fc) 

1076 comp_input.update(denorm_inputs) 

1077 

1078 # Otherwise, make sure we format inputs for surrogate evaluation (i.e. normalize) 

1079 else: 

1080 denorm_inputs = {var: arr for var, arr in comp_input.items() if not norm_status[var]} 

1081 if len(denorm_inputs) > 0: 

1082 norm_inputs, _ = to_surrogate_dataset(denorm_inputs, comp.inputs, del_fields=True, **field_coords) 

1083 

1084 for var in denorm_inputs: 

1085 del comp_input[var] 

1086 comp_input.update(norm_inputs) 

1087 

1088 return comp_input, field_coords, call_model 

1089 

1090 # Convert system into DAG by grouping strongly-connected-components 

1091 dag = nx.condensation(graph) 

1092 

1093 # Compute component models in topological order 

1094 for supernode in nx.topological_sort(dag): 

1095 if np.all(list(is_computed.values())): 

1096 break # Exit early if all selected return qois are computed 

1097 

1098 scc = [n for n in dag.nodes[supernode]['members']] 

1099 samples.reset_convergence() 

1100 

1101 # Compute single component feedforward output (no FPI needed) 

1102 if len(scc) == 1: 

1103 if verbose: 

1104 self.logger.info(f"Running component '{scc[0]}'...") 

1105 t1 = time.time() 

1106 

1107 # Gather inputs 

1108 comp = self[scc[0]] 

1109 comp_input, field_coords, call_model = _gather_comp_inputs(comp) 

1110 

1111 # Compute outputs 

1112 if model_dir is not None: 

1113 output_dir = Path(model_dir) / scc[0] 

1114 if not output_dir.exists(): 

1115 os.mkdir(output_dir) 

1116 comp_output = comp.predict(comp_input, use_model=use_model.get(scc[0]), model_dir=output_dir, 

1117 index_set=index_set.get(scc[0]), incremental=incremental.get(scc[0]), 

1118 misc_coeff=misc_coeff.get(scc[0]), executor=executor, **field_coords) 

1119 

1120 for var, arr in comp_output.items(): 

1121 is_numeric = np.issubdtype(arr.dtype, np.number) 

1122 if is_numeric: # for scalars or vectorized field quantities 

1123 output_shape = arr.shape[1:] 

1124 if y.get(var) is None: 

1125 y.setdefault(var, np.full((N, *output_shape), np.nan)) 

1126 y[var][samples.curr_idx, ...] = arr 

1127 

1128 else: # for fields returned as object arrays 

1129 if y.get(var) is None: 

1130 y.setdefault(var, np.full((N,), None, dtype=object)) 

1131 y[var][samples.curr_idx] = arr 

1132 

1133 # Update valid indices and status for component outputs 

1134 if str(var).split(LATENT_STR_ID)[0] in comp.outputs: 

1135 new_valid = ~np.any(np.isnan(y[var]), axis=tuple(range(1, y[var].ndim))) if is_numeric else ( 

1136 [~np.any(np.isnan(y[var][i])) for i in range(N)] 

1137 ) 

1138 samples.valid_idx = np.logical_and(samples.valid_idx, new_valid) 

1139 

1140 is_computed[str(var).split(LATENT_STR_ID)[0]] = True 

1141 norm_status[var] = not call_model 

1142 

1143 if verbose: 

1144 self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s") 

1145 

1146 # Handle FPI for SCCs with more than one component 

1147 else: 

1148 # Set the initial guess for all coupling vars (middle of domain) 

1149 scc_inputs = ChainMap(*[self[comp].inputs for comp in scc]) 

1150 scc_outputs = ChainMap(*[self[comp].outputs for comp in scc]) 

1151 coupling_vars = [scc_inputs.get(var) for var in (scc_inputs.keys() - x.keys()) if var in scc_outputs] 

1152 coupling_prev = {} 

1153 for var in coupling_vars: 

1154 domain = var.get_domain() 

1155 if isinstance(domain, list): # Latent coefficients are the coupling variables 

1156 for i, d in enumerate(domain): 

1157 lb, ub = d 

1158 coupling_prev[f'{var.name}{LATENT_STR_ID}{i}'] = np.broadcast_to((lb + ub) / 2, (N,)).copy() 

1159 norm_status[f'{var.name}{LATENT_STR_ID}{i}'] = True 

1160 else: 

1161 lb, ub = var.normalize(domain) 

1162 shape = (N,) + (1,) * len(var_shape.get(var, ())) 

1163 coupling_prev[var] = np.broadcast_to((lb + ub) / 2, shape).copy() 

1164 norm_status[var] = True 

1165 

1166 residual_hist = deque(maxlen=anderson_mem) 

1167 coupling_hist = deque(maxlen=anderson_mem) 

1168 

1169 def _end_conditions_met(): 

1170 """Helper to compute residual, update history, and check end conditions.""" 

1171 residual = {} 

1172 converged_idx = np.full(N, True) 

1173 for var in coupling_prev: 

1174 residual[var] = y[var] - coupling_prev[var] 

1175 var_conv = np.all(np.abs(residual[var]) <= fpi_tol, axis=tuple(range(1, residual[var].ndim))) 

1176 converged_idx = np.logical_and(converged_idx, var_conv) 

1177 samples.valid_idx = np.logical_and(samples.valid_idx, ~np.isnan(coupling_prev[var])) 

1178 samples.converged_idx = np.logical_or(samples.converged_idx, converged_idx) 

1179 

1180 for var in coupling_prev: 

1181 coupling_prev[var][samples.curr_idx, ...] = y[var][samples.curr_idx, ...] 

1182 residual_hist.append(copy.deepcopy(residual)) 

1183 coupling_hist.append(copy.deepcopy(coupling_prev)) 

1184 

1185 if int(np.sum(samples.curr_idx)) == 0: 

1186 if verbose: 

1187 self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with tol ' 

1188 f'{fpi_tol}. Final time: {time.time() - t1} s') 

1189 return True 

1190 

1191 max_error = np.max([np.max(np.abs(res[samples.curr_idx, ...])) for res in residual.values()]) 

1192 if verbose: 

1193 self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s') 

1194 

1195 if k >= max_fpi_iter: 

1196 self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: ' 

1197 f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.') 

1198 for var in coupling_prev: 

1199 y[var][~samples.converged_idx, ...] = np.nan 

1200 samples.valid_idx = np.logical_and(samples.valid_idx, samples.converged_idx) 

1201 return True 

1202 else: 

1203 return False 

1204 

1205 # Main FPI loop 

1206 if verbose: 

1207 self.logger.info(f"Initializing FPI for SCC {scc} ...") 

1208 t1 = time.time() 

1209 k = 0 

1210 while True: 

1211 for node in scc: 

1212 # Gather inputs from exogenous and coupling sources 

1213 comp = self[node] 

1214 comp_input, kwds, call_model = _gather_comp_inputs(comp, coupling=coupling_prev) 

1215 

1216 # Compute outputs (just don't do this FPI with expensive real models, please..) 

1217 comp_output = comp.predict(comp_input, use_model=use_model.get(node), model_dir=None, 

1218 index_set=index_set.get(node), incremental=incremental.get(node), 

1219 misc_coeff=misc_coeff.get(node), executor=executor, **kwds) 

1220 for var, arr in comp_output.items(): 

1221 if np.issubdtype(arr.dtype, np.number): # scalars and vectorized field quantities 

1222 output_shape = arr.shape[1:] 

1223 if y.get(var) is not None: 

1224 if output_shape != y.get(var).shape[1:]: 

1225 y[var] = _merge_shapes((N, *output_shape), y[var]) 

1226 else: 

1227 y.setdefault(var, np.full((N, *output_shape), np.nan)) 

1228 y[var][samples.curr_idx, ...] = arr 

1229 else: # fields returned as object arrays 

1230 if y.get(var) is None: 

1231 y.setdefault(var, np.full((N,), None, dtype=object)) 

1232 y[var][samples.curr_idx] = arr 

1233 

1234 if str(var).split(LATENT_STR_ID)[0] in comp.outputs: 

1235 norm_status[var] = not call_model 

1236 is_computed[str(var).split(LATENT_STR_ID)[0]] = True 

1237 

1238 # Compute residual and check end conditions 

1239 if _end_conditions_met(): 

1240 break 

1241 

1242 # Skip anderson acceleration on first iteration 

1243 if k == 0: 

1244 k += 1 

1245 continue 

1246 

1247 # Iterate with anderson acceleration (only iterate on samples that are not yet converged) 

1248 N_curr = int(np.sum(samples.curr_idx)) 

1249 mk = len(residual_hist) # Max of anderson mem 

1250 var_shapes = [] 

1251 xdims = [] 

1252 for var in coupling_prev: 

1253 shape = coupling_prev[var].shape[1:] 

1254 var_shapes.append(shape) 

1255 xdims.append(int(np.prod(shape))) 

1256 N_couple = int(np.sum(xdims)) 

1257 res_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of residual history 

1258 coupling_snap = np.empty((N_curr, N_couple, mk)) # Shortened snapshot of coupling history 

1259 for i, (coupling_iter, residual_iter) in enumerate(zip(coupling_hist, residual_hist)): 

1260 start_idx = 0 

1261 for j, var in enumerate(coupling_prev): 

1262 end_idx = start_idx + xdims[j] 

1263 coupling_snap[:, start_idx:end_idx, i] = coupling_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501 

1264 res_snap[:, start_idx:end_idx, i] = residual_iter[var][samples.curr_idx, ...].reshape((N_curr, -1)) # noqa: E501 

1265 start_idx = end_idx 

1266 C = np.ones((N_curr, 1, mk)) 

1267 b = np.zeros((N_curr, N_couple, 1)) 

1268 d = np.ones((N_curr, 1, 1)) 

1269 alpha = np.expand_dims(constrained_lls(res_snap, b, C, d), axis=-3) # (..., 1, mk, 1) 

1270 coupling_new = np.squeeze(coupling_snap[:, :, np.newaxis, :] @ alpha, axis=(-1, -2)) 

1271 start_idx = 0 

1272 for j, var in enumerate(coupling_prev): 

1273 end_idx = start_idx + xdims[j] 

1274 coupling_prev[var][samples.curr_idx, ...] = coupling_new[:, start_idx:end_idx].reshape((N_curr, *var_shapes[j])) # noqa: E501 

1275 start_idx = end_idx 

1276 k += 1 

1277 

1278 # Return all component outputs; samples that didn't converge during FPI are left as np.nan 

1279 return format_outputs(y, loop_shape) 

1280 

1281 def __call__(self, *args, **kwargs): 

1282 """Convenience wrapper to allow calling as `ret = System(x)`.""" 

1283 return self.predict(*args, **kwargs) 

1284 

1285 def __eq__(self, other): 

1286 if not isinstance(other, System): 

1287 return False 

1288 return (self.components == other.components and 

1289 self.name == other.name and 

1290 self.train_history == other.train_history) 

1291 

1292 def __getitem__(self, component: str) -> Component: 

1293 """Convenience method to get a `Component` object from the `System`. 

1294 

1295 :param component: the name of the component to get 

1296 :returns: the `Component` object 

1297 """ 

1298 return self.get_component(component) 

1299 

1300 def get_component(self, comp_name: str) -> Component: 

1301 """Return the `Component` object for this component. 

1302 

1303 :param comp_name: name of the component to return 

1304 :raises KeyError: if the component does not exist 

1305 :returns: the `Component` object 

1306 """ 

1307 if comp_name.lower() == 'system': 

1308 return self 

1309 else: 

1310 for comp in self.components: 

1311 if comp.name == comp_name: 

1312 return comp 

1313 raise KeyError(f"Component '{comp_name}' not found in system.") 

1314 

1315 def _print_title_str(self, title_str: str): 

1316 """Log an important message.""" 

1317 self.logger.info('-' * int(len(title_str)/2) + title_str + '-' * int(len(title_str)/2)) 

1318 

1319 def _remove_unpickleable(self) -> dict: 

1320 """Remove and return unpickleable attributes before pickling (just the logger).""" 

1321 stdout = False 

1322 log_file = None 

1323 if self._logger is not None: 

1324 for handler in self._logger.handlers: 

1325 if isinstance(handler, logging.StreamHandler): 

1326 stdout = True 

1327 break 

1328 for handler in self._logger.handlers: 

1329 if isinstance(handler, logging.FileHandler): 

1330 log_file = handler.baseFilename 

1331 break 

1332 

1333 buffer = {'log_stdout': stdout, 'log_file': log_file} 

1334 self.logger = None 

1335 return buffer 

1336 

1337 def _restore_unpickleable(self, buffer: dict): 

1338 """Restore the unpickleable attributes from `buffer` after unpickling.""" 

1339 self.set_logger(log_file=buffer.get('log_file', None), stdout=buffer.get('log_stdout', None)) 

1340 

1341 def _get_test_set(self, test_set: str | Path | tuple = None) -> tuple: 

1342 """Try to load a test set from the root directory if it exists.""" 

1343 if isinstance(test_set, tuple): 

1344 return test_set # (xtest, ytest) 

1345 else: 

1346 ret = (None, None) 

1347 if test_set is not None: 

1348 test_set = Path(test_set) 

1349 elif self.root_dir is not None: 

1350 test_set = self.root_dir / 'test_set.pkl' 

1351 

1352 if test_set is not None: 

1353 if test_set.exists(): 

1354 with open(test_set, 'rb') as fd: 

1355 data = pickle.load(fd) 

1356 ret = data['test_set'] 

1357 

1358 return ret 

1359 

1360 def _save_test_set(self, test_set: tuple = None): 

1361 """Save the test set to the root directory if possible.""" 

1362 if self.root_dir is not None and test_set is not None: 

1363 test_file = self.root_dir / 'test_set.pkl' 

1364 if not test_file.exists(): 

1365 with open(test_file, 'wb') as fd: 

1366 pickle.dump({'test_set': test_set}, fd) 

1367 

1368 def save_to_file(self, filename: str, save_dir: str | Path = None, dumper=None): 

1369 """Save surrogate to file. Defaults to `root/surrogates/filename.yml` with the default yaml encoder. 

1370 

1371 :param filename: the name of the save file 

1372 :param save_dir: the directory to save the file to (defaults to `root/surrogates` or `cwd()`) 

1373 :param dumper: the encoder to use (defaults to the `amisc` yaml encoder) 

1374 """ 

1375 from amisc import YamlLoader 

1376 encoder = dumper or YamlLoader 

1377 save_dir = save_dir or self.root_dir or Path.cwd() 

1378 if Path(save_dir) == self.root_dir: 

1379 save_dir = self.root_dir / 'surrogates' 

1380 encoder.dump(self, Path(save_dir) / filename) 

1381 

1382 @staticmethod 

1383 def load_from_file(filename: str | Path, root_dir: str | Path = None, loader=None): 

1384 """Load surrogate from file. Defaults to yaml loading. Tries to infer `amisc` directory structure. 

1385 

1386 :param filename: the name of the load file 

1387 :param root_dir: set this as the surrogate's root directory (will try to load from `amisc_` fmt by default) 

1388 :param loader: the encoder to use (defaults to the `amisc` yaml encoder) 

1389 """ 

1390 from amisc import YamlLoader 

1391 encoder = loader or YamlLoader 

1392 system = encoder.load(filename) 

1393 root_dir = root_dir or system.root_dir 

1394 

1395 # Try to infer amisc_root/surrogates/iter/filename structure 

1396 if root_dir is None: 

1397 parts = Path(filename).resolve().parts 

1398 if len(parts) > 1 and parts[-2].startswith('amisc_'): 

1399 root_dir = Path(filename).resolve().parent 

1400 elif len(parts) > 2 and parts[-3].startswith('amisc_'): 

1401 root_dir = Path(filename).resolve().parent.parent 

1402 elif len(parts) > 3 and parts[-4].startswith('amisc_'): 

1403 root_dir = Path(filename).resolve().parent.parent.parent 

1404 

1405 system.root_dir = root_dir 

1406 return system 

1407 

1408 def clear(self): 

1409 """Clear all surrogate model data and reset the system.""" 

1410 for comp in self.components: 

1411 comp.clear() 

1412 self.train_history.clear() 

1413 

1414 def plot_slice(self, inputs: list[str] = None, 

1415 outputs: list[str] = None, 

1416 num_steps: int = 20, 

1417 show_surr: bool = True, 

1418 show_model: str | tuple | list = None, 

1419 save_dir: str | Path = None, 

1420 executor: Executor = None, 

1421 nominal: dict[str: float] = None, 

1422 random_walk: bool = False, 

1423 from_file: str | Path = None, 

1424 subplot_size_in: float = 3.): 

1425 """Helper function to plot 1d slices of the surrogate and/or model outputs over the inputs. A single 

1426 "slice" works by smoothly stepping from the lower bound of an input to its upper bound, while holding all other 

1427 inputs constant at their nominal values (or smoothly varying them if `random_walk=True`). 

1428 This function is useful for visualizing the behavior of the system surrogate and/or model(s) over a 

1429 single input variable at a time. 

1430 

1431 :param inputs: list of input variables to take 1d slices of (defaults to first 3 in `System.inputs`) 

1432 :param outputs: list of model output variables to plot 1d slices of (defaults to first 3 in `System.outputs`) 

1433 :param num_steps: the number of points to take in the 1d slice for each input variable; this amounts to a total 

1434 of `num_steps*len(inputs)` model/surrogate evaluations 

1435 :param show_surr: whether to show the surrogate prediction 

1436 :param show_model: also compute and plot model predictions, `list` of ['best', 'worst', tuple(alpha), etc.] 

1437 :param save_dir: base directory to save model outputs and plots (if specified) 

1438 :param executor: a `concurrent.futures.Executor` object to parallelize model or surrogate evaluations 

1439 :param nominal: `dict` of `var->nominal` to use as constant values for all non-sliced variables (use 

1440 unnormalized values only; use `var_LATENT0` to specify nominal latent values) 

1441 :param random_walk: whether to slice in a random d-dimensional direction instead of holding all non-slice 

1442 variables const at `nominal` 

1443 :param from_file: path to a `.pkl` file to load a saved slice from disk 

1444 :param subplot_size_in: side length size of each square subplot in inches 

1445 :returns: `fig, ax` with `len(inputs)` by `len(outputs)` subplots 

1446 """ 

1447 # Manage loading important quantities from file (if provided) 

1448 input_slices, output_slices_model, output_slices_surr = None, None, None 

1449 if from_file is not None: 

1450 with open(Path(from_file), 'rb') as fd: 

1451 slice_data = pickle.load(fd) 

1452 inputs = slice_data['inputs'] # Must use same input slices as save file 

1453 show_model = slice_data['show_model'] # Must use same model data as save file 

1454 outputs = slice_data.get('outputs') if outputs is None else outputs 

1455 input_slices = slice_data['input_slices'] 

1456 save_dir = None # Don't run or save any models if loading from file 

1457 

1458 # Set default values (take up to the first 3 inputs by default) 

1459 all_inputs = self.inputs() 

1460 all_outputs = self.outputs() 

1461 rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4)) 

1462 if save_dir is not None: 

1463 os.mkdir(Path(save_dir) / f'slice_{rand_id}') 

1464 if nominal is None: 

1465 nominal = dict() 

1466 inputs = all_inputs[:3] if inputs is None else inputs 

1467 outputs = all_outputs[:3] if outputs is None else outputs 

1468 

1469 if show_model is not None and not isinstance(show_model, list): 

1470 show_model = [show_model] 

1471 

1472 # Handle field quantities (directly use latent variables or only the first one) 

1473 for i, var in enumerate(list(inputs)): 

1474 if LATENT_STR_ID not in str(var) and all_inputs[var].compression is not None: 

1475 inputs[i] = f'{var}{LATENT_STR_ID}0' 

1476 for i, var in enumerate(list(outputs)): 

1477 if LATENT_STR_ID not in str(var) and all_outputs[var].compression is not None: 

1478 outputs[i] = f'{var}{LATENT_STR_ID}0' 

1479 

1480 bds = all_inputs.get_domains() 

1481 xlabels = [all_inputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else 

1482 all_inputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) + 

1483 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in inputs] 

1484 

1485 ylabels = [all_outputs[var].get_tex(units=False) if LATENT_STR_ID not in str(var) else 

1486 all_outputs[str(var).split(LATENT_STR_ID)[0]].get_tex(units=False) + 

1487 f' (latent {str(var).split(LATENT_STR_ID)[1]})' for var in outputs] 

1488 

1489 # Construct slices of model inputs (if not provided) 

1490 if input_slices is None: 

1491 input_slices = {} # Each input variable with shape (num_steps, num_slice) 

1492 for i in range(len(inputs)): 

1493 if random_walk: 

1494 # Make a random straight-line walk across d-cube 

1495 r0 = self.sample_inputs((1,), use_pdf=False) 

1496 rf = self.sample_inputs((1,), use_pdf=False) 

1497 

1498 for var, bd in bds.items(): 

1499 if var == inputs[i]: 

1500 r0[var] = np.atleast_1d(bd[0]) # Start slice at this lower bound 

1501 rf[var] = np.atleast_1d(bd[1]) # Slice up to this upper bound 

1502 

1503 step_size = (rf[var] - r0[var]) / (num_steps - 1) 

1504 arr = r0[var] + step_size * np.arange(num_steps) 

1505 

1506 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else ( 

1507 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1)) 

1508 else: 

1509 # Otherwise, only slice one variable 

1510 for var, bd in bds.items(): 

1511 nom = nominal.get(var, np.mean(bd)) if LATENT_STR_ID in str(var) else ( 

1512 all_inputs[var].normalize(nominal.get(var, all_inputs[var].get_nominal()))) 

1513 arr = np.linspace(bd[0], bd[1], num_steps) if var == inputs[i] else np.full(num_steps, nom) 

1514 

1515 input_slices[var] = arr[..., np.newaxis] if input_slices.get(var) is None else ( 

1516 np.concatenate((input_slices[var], arr[..., np.newaxis]), axis=-1)) 

1517 

1518 # Walk through each model that is requested by show_model 

1519 if show_model is not None: 

1520 if from_file is not None: 

1521 output_slices_model = slice_data['output_slices_model'] 

1522 else: 

1523 output_slices_model = list() 

1524 for model in show_model: 

1525 output_dir = None 

1526 if save_dir is not None: 

1527 output_dir = (Path(save_dir) / f'slice_{rand_id}' / 

1528 str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", '')) 

1529 os.mkdir(output_dir) 

1530 output_slices_model.append(self.predict(input_slices, use_model=model, model_dir=output_dir, 

1531 executor=executor)) 

1532 if show_surr: 

1533 output_slices_surr = self.predict(input_slices, executor=executor) \ 

1534 if from_file is None else slice_data['output_slices_surr'] 

1535 

1536 # Make len(outputs) by len(inputs) grid of subplots 

1537 fig, axs = plt.subplots(len(outputs), len(inputs), sharex='col', sharey='row', squeeze=False) 

1538 for i, output_var in enumerate(outputs): 

1539 for j, input_var in enumerate(inputs): 

1540 ax = axs[i, j] 

1541 x = input_slices[input_var][:, j] 

1542 

1543 if show_model is not None: 

1544 c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else ( 

1545 plt.get_cmap('jet')(np.linspace(0, 1, len(show_model)))) 

1546 for k in range(len(show_model)): 

1547 model_str = (str(show_model[k]).replace('{', '').replace('}', '') 

1548 .replace(':', '=').replace("'", '')) 

1549 model_ret = to_surrogate_dataset(output_slices_model[k], all_outputs)[0] 

1550 y_model = model_ret[output_var][:, j] 

1551 label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model', 

1552 'worst': 'Low-fidelity'}.get(model_str, model_str) 

1553 ax.plot(x, y_model, ls='-', c=c[k, :], label=label) 

1554 

1555 if show_surr: 

1556 y_surr = output_slices_surr[output_var][:, j] 

1557 ax.plot(x, y_surr, '--r', label='Surrogate') 

1558 

1559 ax.set_xlabel(xlabels[j] if i == len(outputs) - 1 else '') 

1560 ax.set_ylabel(ylabels[i] if j == 0 else '') 

1561 if i == 0 and j == len(inputs) - 1: 

1562 ax.legend() 

1563 fig.set_size_inches(subplot_size_in * len(inputs), subplot_size_in * len(outputs)) 

1564 fig.tight_layout() 

1565 

1566 # Save results (unless we were already loading from a save file) 

1567 if from_file is None and save_dir is not None: 

1568 fname = f'in={",".join([str(v) for v in inputs])}_out={",".join([str(v) for v in outputs])}' 

1569 fname = f'slice_rand{rand_id}_' + fname if random_walk else f'slice_nom{rand_id}_' + fname 

1570 fdir = Path(save_dir) / f'slice_{rand_id}' 

1571 fig.savefig(fdir / f'{fname}.pdf', bbox_inches='tight', format='pdf') 

1572 save_dict = {'inputs': inputs, 'outputs': outputs, 'show_model': show_model, 'show_surr': show_surr, 

1573 'nominal': nominal, 'random_walk': random_walk, 'input_slices': input_slices, 

1574 'output_slices_model': output_slices_model, 'output_slices_surr': output_slices_surr} 

1575 with open(fdir / f'{fname}.pkl', 'wb') as fd: 

1576 pickle.dump(save_dict, fd) 

1577 

1578 return fig, axs 

1579 

1580 def get_allocation(self): 

1581 """Get a breakdown of cost allocation during training. 

1582 

1583 :returns: `cost_alloc, model_cost, overhead_cost, model_evals` - the cost allocation per model/fidelity, 

1584 the model evaluation cost per iteration (in s of CPU time), the algorithmic overhead cost per 

1585 iteration, and the total number of model evaluations at each training iteration 

1586 """ 

1587 cost_alloc = dict() # Cost allocation (cpu time in s) per node and model fidelity 

1588 model_cost = [] # Cost of model evaluations (CPU time in s) per iteration 

1589 overhead_cost = [] # Algorithm overhead costs (CPU time in s) per iteration 

1590 model_evals = [] # Number of model evaluations at each training iteration 

1591 

1592 prev_cands = {comp.name: IndexSet() for comp in self.components} # empty candidate sets 

1593 

1594 # Add cumulative training costs 

1595 for train_res, active_sets, cand_sets, misc_coeff_train, misc_coeff_test in self.simulate_fit(): 

1596 comp = train_res['component'] 

1597 alpha = train_res['alpha'] 

1598 beta = train_res['beta'] 

1599 overhead = train_res['overhead_s'] 

1600 

1601 cost_alloc.setdefault(comp, dict()) 

1602 

1603 new_cands = cand_sets[comp].union({(alpha, beta)}) - prev_cands[comp] # newly computed candidates 

1604 

1605 iter_cost = 0. 

1606 iter_eval = 0 

1607 for alpha_new, beta_new in new_cands: 

1608 cost_alloc[comp].setdefault(alpha_new, 0.) 

1609 

1610 added_eval = self[comp].get_cost(alpha_new, beta_new) 

1611 single_cost = self[comp].model_costs.get(alpha_new, 1.) 

1612 

1613 iter_cost += added_eval * single_cost 

1614 iter_eval += added_eval 

1615 

1616 cost_alloc[comp][alpha_new] += added_eval * single_cost 

1617 

1618 overhead_cost.append(overhead) 

1619 model_cost.append(iter_cost) 

1620 model_evals.append(iter_eval) 

1621 prev_cands[comp] = cand_sets[comp].union({(alpha, beta)}) 

1622 

1623 return cost_alloc, np.atleast_1d(model_cost), np.atleast_1d(overhead_cost), np.atleast_1d(model_evals) 

1624 

1625 def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02): 

1626 """Plot bar charts showing cost allocation during training. 

1627 

1628 !!! Warning "Beta feature" 

1629 This has pretty good default settings, but it might look terrible for your use. Mostly provided here as 

1630 a template for making cost allocation bar charts. Please feel free to copy and edit in your own code. 

1631 

1632 :param cmap: the colormap string identifier for `plt` 

1633 :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text 

1634 :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow; 

1635 below this amount, the bar is too skinny and won't print any text 

1636 :returns: `fig, ax`, Figure and Axes objects 

1637 """ 

1638 # Get total cost 

1639 cost_alloc, model_cost, _, _ = self.get_allocation() 

1640 total_cost = np.cumsum(model_cost)[-1] 

1641 

1642 # Remove nodes with cost=0 from alloc dicts (i.e. analytical models) 

1643 remove_nodes = [] 

1644 for node, alpha_dict in cost_alloc.items(): 

1645 if len(alpha_dict) == 0: 

1646 remove_nodes.append(node) 

1647 for node in remove_nodes: 

1648 del cost_alloc[node] 

1649 

1650 # Bar chart showing cost allocation breakdown for MF system at final iteration 

1651 fig, ax = plt.subplots(figsize=(6, 5), layout='tight') 

1652 width = 0.7 

1653 x = np.arange(len(cost_alloc)) 

1654 xlabels = list(cost_alloc.keys()) # One bar for each component 

1655 cmap = plt.get_cmap(cmap) 

1656 

1657 for j, (node, alpha_dict) in enumerate(cost_alloc.items()): 

1658 bottom = 0 

1659 c_intervals = np.linspace(0, 1, len(alpha_dict)) 

1660 bars = [(alpha, cost, cost / total_cost) for alpha, cost in alpha_dict.items()] 

1661 bars = sorted(bars, key=lambda ele: ele[2], reverse=True) 

1662 for i, (alpha, cost, frac) in enumerate(bars): 

1663 p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1, 

1664 edgecolor=[0, 0, 0], bottom=bottom) 

1665 bottom += frac 

1666 num_evals = round(cost / self[node].model_costs.get(alpha, 1.)) 

1667 if frac > text_bar_width: 

1668 ax.bar_label(p, labels=[f'{alpha}, {num_evals}'], label_type='center') 

1669 elif frac > arrow_bar_width: 

1670 xy = (x[j] + width / 2, bottom - frac / 2) # Label smaller bars with a text off to the side 

1671 ax.annotate(f'{alpha}, {num_evals}', xy, xytext=(xy[0] + 0.2, xy[1]), 

1672 arrowprops={'arrowstyle': '->', 'linewidth': 1}) 

1673 else: 

1674 pass # Don't label really small bars 

1675 ax.set_xlabel('') 

1676 ax.set_ylabel('Fraction of total cost') 

1677 ax.set_xticks(x, xlabels) 

1678 ax.set_xlim(left=-1, right=x[-1] + 1) 

1679 

1680 if self.root_dir is not None: 

1681 fig.savefig(Path(self.root_dir) / 'mf_allocation.pdf', bbox_inches='tight', format='pdf') 

1682 

1683 return fig, ax 

1684 

1685 def serialize(self, keep_components=False, serialize_args=None, serialize_kwargs=None) -> dict: 

1686 """Convert to a `dict` with only standard Python types for fields. 

1687 

1688 :param keep_components: whether to serialize the components as well (defaults to False) 

1689 :param serialize_args: `dict` of arguments to pass to each component's serialize method 

1690 :param serialize_kwargs: `dict` of keyword arguments to pass to each component's serialize method 

1691 :returns: a `dict` representation of the `System` object 

1692 """ 

1693 serialize_args = serialize_args or dict() 

1694 serialize_kwargs = serialize_kwargs or dict() 

1695 d = {} 

1696 for key, value in self.__dict__.items(): 

1697 if value is not None and not key.startswith('_'): 

1698 if key == 'components' and not keep_components: 

1699 d[key] = [comp.serialize(keep_yaml_objects=False, 

1700 serialize_args=serialize_args.get(comp.name), 

1701 serialize_kwargs=serialize_kwargs.get(comp.name)) 

1702 for comp in value] 

1703 elif key == 'train_history': 

1704 if len(value) > 0: 

1705 d[key] = value.serialize() 

1706 else: 

1707 if not isinstance(value, _builtin): 

1708 self.logger.warning(f"Attribute '{key}' of type '{type(value)}' may not be a builtin " 

1709 f"Python type. This may cause issues when saving/loading from file.") 

1710 d[key] = value 

1711 

1712 for key, value in self.model_extra.items(): 

1713 if isinstance(value, _builtin): 

1714 d[key] = value 

1715 

1716 return d 

1717 

1718 @classmethod 

1719 def deserialize(cls, serialized_data: dict) -> System: 

1720 """Construct a `System` object from serialized data.""" 

1721 return cls(**serialized_data) 

1722 

1723 @staticmethod 

1724 def _yaml_representer(dumper: yaml.Dumper, system: System): 

1725 """Serialize a `System` object to a YAML representation.""" 

1726 return dumper.represent_mapping(System.yaml_tag, system.serialize(keep_components=True)) 

1727 

1728 @staticmethod 

1729 def _yaml_constructor(loader: yaml.Loader, node): 

1730 """Convert the `!SystemSurrogate` tag in yaml to a [`System`][amisc.system.System] object.""" 

1731 if isinstance(node, yaml.SequenceNode): 

1732 return [ele if isinstance(ele, System) else System.deserialize(ele) 

1733 for ele in loader.construct_sequence(node, deep=True)] 

1734 elif isinstance(node, yaml.MappingNode): 

1735 return System.deserialize(loader.construct_mapping(node, deep=True)) 

1736 else: 

1737 raise NotImplementedError(f'The "{System.yaml_tag}" yaml tag can only be used on a yaml sequence or ' 

1738 f'mapping, not a "{type(node)}".')