Coverage for src/amisc/system.py: 83%

813 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-29 21:38 +0000

1"""The `SystemSurrogate` is a framework for multidisciplinary models. It manages multiple single discipline component 

2models and the connections between them. It provides a top-level interface for constructing and evaluating surrogates. 

3 

4Features 

5-------- 

6- Manages multidisciplinary models in a graph data structure, supports feedforward and feedback connections 

7- Feedback connections are solved with a fixed-point iteration (FPI) nonlinear solver 

8- FPI uses Anderson acceleration and surrogate evaluations for speed-up 

9- Top-level interface for training and using surrogates of each component model 

10- Adaptive experimental design for choosing training data efficiently 

11- Convenient testing, plotting, and performance metrics provided to assess quality of surrogates 

12- Detailed logging and traceback information 

13- Supports parallel execution with OpenMP and MPI protocols 

14- Abstract and flexible interfacing with component models 

15 

16!!! Info "Model specification" 

17 Models are callable Python wrapper functions of the form `ret = model(x, *args, **kwargs)`, where `x` is an 

18 `np.ndarray` of model inputs (and `*args, **kwargs` allow passing any other required configurations for your model). 

19 The return value is a Python dictionary of the form `ret = {'y': y, 'files': files, 'cost': cost, etc.}`. In the 

20 return dictionary, you specify the raw model output `y` as an `np.ndarray` at a _minimum_. Optionally, you can 

21 specify paths to output files and the average model cost (in seconds of cpu time), and anything else you want. Your 

22 `model()` function can do anything it wants in order to go from `x` → `y`. Python has the flexibility to call 

23 virtually any external codes, or to implement the function natively with `numpy`. 

24 

25!!! Info "Component specification" 

26 A component adds some extra configuration around a callable `model`. These configurations are defined in a Python 

27 dictionary, which we give the custom type `ComponentSpec`. At a bare _minimum_, you must specify a callable 

28 `model` and its connections to other models within the multidisciplinary system. The limiting case is a single 

29 component model, for which the configuration is simply `component = ComponentSpec(model)`. 

30""" 

31# ruff: noqa: E702 

32import copy 

33import datetime 

34import functools 

35import os 

36import pickle 

37import random 

38import shutil 

39import string 

40import time 

41from collections import UserDict 

42from concurrent.futures import Executor 

43from datetime import timezone 

44from pathlib import Path 

45 

46import dill 

47import matplotlib.pyplot as plt 

48import networkx as nx 

49import numpy as np 

50from joblib import Parallel, delayed 

51from joblib.externals.loky import set_loky_pickler 

52from uqtils import ax_default 

53 

54from amisc import IndexSet, IndicesRV 

55from amisc.component import AnalyticalSurrogate, ComponentSurrogate, SparseGridSurrogate 

56from amisc.rv import BaseRV 

57from amisc.utils import get_logger 

58 

59 

class ComponentSpec(UserDict):
    """Provides a simple extension class of a Python dictionary, used to configure a component model.

    !!! Info "Specifying a list of random variables"
        The three fields: `exo_in`, `coupling_in`, and `coupling_out` fully determine how a component fits within a
        multidisciplinary system. For each, you must specify a list of variables in the same order as the model uses
        them. The model will use all exogenous inputs first, and then all coupling inputs. You can use a variable's
        global integer index into the system `exo_vars` or `coupling_vars`, or you can use the `str` id of the
        variable or the variable itself. This is summarized in the `amisc.IndicesRV` custom type.

    !!! Example
        Let's say you have a model:
        ```python
        def my_model(x, *args, **kwargs):
            print(x.shape)  # (3,), so a total of 3 inputs
            G = 6.674e-11
            m1 = x[0]       # System-level input
            m2 = x[1]       # System-level input
            r = x[2]        # Coupling input
            F = G*m1*m2 / r**2
            return {'y': F}
        ```
        Let's say this model is part of a larger system where `m1` and `m2` are specified by the system, and `r`
        comes from a different model that predicts the distance between two objects. You would set the
        configuration as:
        ```python
        component = ComponentSpec(my_model, exo_in=['m1', 'm2'], coupling_in=['r'], coupling_out=['F'])
        ```
    """
    # The only keys a ComponentSpec is allowed to store (enforced in __setitem__)
    Options = ['model', 'name', 'exo_in', 'coupling_in', 'coupling_out', 'truth_alpha', 'max_alpha', 'max_beta',
               'surrogate', 'model_args', 'model_kwargs', 'save_output']

    def __init__(self, model: callable, name: str = '', exo_in: IndicesRV = None,
                 coupling_in: IndicesRV | dict[str, IndicesRV] = None, coupling_out: IndicesRV = None,
                 truth_alpha: tuple | int = (), max_alpha: tuple | int = (), max_beta: tuple | int = (),
                 surrogate: str | ComponentSurrogate = 'lagrange', model_args: tuple = (), model_kwargs: dict = None,
                 save_output: bool = False):
        """Construct the configuration for this component model.

        !!! Warning
            Always specify the model at a _global_ scope, i.e. don't use `lambda` or nested functions. When saving to
            file, only a symbolic reference to the function signature will be saved, which must be globally defined
            when loading back from that save file.

        :param model: the component model, must be defined in a global scope (i.e. in a module or top-level of a script)
        :param name: the name used to identify this component model
        :param exo_in: specifies the global, system-level (i.e. exogenous/external) inputs to this model
        :param coupling_in: specifies the coupling inputs received from other models
        :param coupling_out: specifies all outputs of this model (which may couple later to downstream models)
        :param truth_alpha: the model fidelity indices to treat as a "ground truth" reference
        :param max_alpha: the maximum model fidelity indices to allow for refinement purposes
        :param max_beta: the maximum surrogate fidelity indices to allow for refinement purposes
        :param surrogate: one of ('lagrange', 'analytical'), or the `ComponentSurrogate` class to use directly
        :param model_args: optional arguments to pass to the component model
        :param model_kwargs: optional keyword arguments to pass to the component model
        :param save_output: whether this model will be saving outputs to file
        """
        # locals() captures every constructor argument (plus 'self', which is filtered out by the
        # Options whitelist below); this avoids repeating each keyword by hand
        d = locals()
        d2 = {key: value for key, value in d.items() if key in ComponentSpec.Options}
        super().__init__(d2)

    def __setitem__(self, key, value):
        # Reject any key that is not a recognized component configuration option
        if key in ComponentSpec.Options:
            super().__setitem__(key, value)
        else:
            raise ValueError(f'"{key}" is not applicable for a ComponentSpec. Try one of {ComponentSpec.Options}.')

    def __delitem__(self, key):
        # Deletion is disallowed: downstream code relies on all Options keys remaining accessible
        raise TypeError("Not allowed to delete items from a ComponentSpec.")

128 

129 

130class SystemSurrogate: 

131 """Multidisciplinary (MD) surrogate framework top-level class. 

132 

133 !!! Note "Accessing individual components" 

134 The `ComponentSurrogate` objects that compose `SystemSurrogate` are internally stored in the `self.graph.nodes` 

135 data structure. You can access them with `get_component(comp_name)`. 

136 

137 :ivar exo_vars: global list of exogenous/external inputs for the MD system 

138 :ivar coupling_vars: global list of coupling variables for the MD system (including all system-level outputs) 

139 :ivar refine_level: the total number of refinement steps that have been made 

140 :ivar build_metrics: contains data that summarizes surrogate training progress 

141 :ivar root_dir: root directory where all surrogate build products are saved to file 

142 :ivar log_file: log file where all logs are written to by default 

143 :ivar executor: manages parallel execution for the system 

144 :ivar graph: the internal graph data structure of the MD system 

145 

146 :vartype exo_vars: list[BaseRV] 

147 :vartype coupling_vars: list[BaseRV] 

148 :vartype refine_level: int 

149 :vartype build_metrics: dict 

150 :vartype root_dir: str 

151 :vartype log_file: str 

152 :vartype executor: Executor 

153 :vartype graph: nx.DiGraph 

154 """ 

155 

    def __init__(self, components: list[ComponentSpec] | ComponentSpec, exo_vars: list[BaseRV] | BaseRV,
                 coupling_vars: list[BaseRV] | BaseRV, est_bds: int = 0, save_dir: str | Path = None,
                 executor: Executor = None, stdout: bool = True, init_surr: bool = True, logger_name: str = None):
        """Construct the MD system surrogate.

        !!! Warning
            Component models should always use coupling variables in the order they appear in the system-level
            `coupling_vars`.

        :param components: list of components in the MD system (using the ComponentSpec class)
        :param exo_vars: list of system-level exogenous/external inputs
        :param coupling_vars: list of all coupling variables (including all system-level outputs)
        :param est_bds: number of samples to estimate coupling variable bounds, do nothing if 0
        :param save_dir: root directory for all build products (.log, .pkl, .json, etc.), won't save if None
        :param executor: an instance of a `concurrent.futures.Executor`, use to iterate new candidates in parallel
        :param stdout: whether to log to console
        :param init_surr: whether to initialize the surrogate immediately when constructing
        :param logger_name: the name of the logger to use, if None then uses class name by default
        """
        # Setup root save directory (timestamped so repeated runs don't collide); ':' is not a valid
        # filename character on Windows, hence the replace
        if save_dir is not None:
            timestamp = datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
            save_dir = Path(save_dir) / ('amisc_' + timestamp)
            os.mkdir(save_dir)
        self.root_dir = None        # set by set_root_directory() below
        self.log_file = None        # set by set_root_directory() below
        self.logger = None          # set by set_root_directory() below
        self.executor = executor
        self.graph = nx.DiGraph()
        self.set_root_directory(save_dir, stdout=stdout, logger_name=logger_name)

        # Store system info in a directed graph data structure
        self.exo_vars = copy.deepcopy(exo_vars) if isinstance(exo_vars, list) else [exo_vars]
        self.x_vars = self.exo_vars  # Create an alias to be consistent with components
        self.coupling_vars = copy.deepcopy(coupling_vars) if isinstance(coupling_vars, list) else [coupling_vars]
        self.refine_level = 0
        self.build_metrics = dict()  # Save refinement error metrics during training

        # Construct graph nodes; unnamed components get a default 'Component k' name
        components = [components] if not isinstance(components, list) else components
        for k, comp in enumerate(components):
            if comp['name'] == '':
                comp['name'] = f'Component {k}'
        Nk = len(components)
        nodes = {comp['name']: comp for comp in components}  # work-around since self.graph.nodes is not built yet
        for k in range(Nk):
            # Add the component as a str() node, with attributes specifying details of the surrogate
            comp_dict = components[k]
            indices, surr = self._build_component(comp_dict, nodes=nodes)
            self.graph.add_node(comp_dict['name'], surrogate=surr, is_computed=False, **indices)

        # Connect all neighbor nodes (an edge from each upstream coupling source into this node)
        for node, node_obj in self.graph.nodes.items():
            for neighbor in node_obj['local_in']:
                self.graph.add_edge(neighbor, node)

        self.set_logger(logger_name, stdout=stdout)  # Need to update component loggers

        # Estimate coupling variable bounds by sampling the models, only if requested
        if est_bds > 0:
            self._estimate_coupling_bds(est_bds)

        # Init system with most coarse fidelity indices in each component
        if init_surr:
            self.init_system()
        # NOTE(review): progress is saved even when init_surr=False — confirm this is intended
        self._save_progress('sys_init.pkl')

222 

223 def _build_component(self, component: ComponentSpec, nodes=None) -> tuple[dict, ComponentSurrogate]: 

224 """Build and return a `ComponentSurrogate` from a `dict` that describes the component model/connections. 

225 

226 :param component: specifies details of a component (see `ComponentSpec`) 

227 :param nodes: `dict` of `{node: node_attributes}`, defaults to `self.graph.nodes` 

228 :returns: `connections, surr`: a `dict` of all connection indices and the `ComponentSurrogate` object 

229 """ 

230 nodes = self.graph.nodes if nodes is None else nodes 

231 kwargs = component.get('model_kwargs', {}) 

232 kwargs = {} if kwargs is None else kwargs 

233 

234 # Set up defaults if this is a trivial one component system 

235 exo_in = component.get('exo_in', None) 

236 coupling_in = component.get('coupling_in', None) 

237 coupling_out = component.get('coupling_out', None) 

238 if len(nodes) == 1: 

239 exo_in = list(np.arange(0, len(self.exo_vars))) 

240 coupling_in = [] 

241 coupling_out = list(np.arange(0, len(self.coupling_vars))) 

242 else: 

243 exo_in = [] if exo_in is None else exo_in 

244 coupling_in = [] if coupling_in is None else coupling_in 

245 coupling_out = [] if coupling_out is None else coupling_out 

246 exo_in = [exo_in] if not isinstance(exo_in, list) else exo_in 

247 coupling_in = [coupling_in] if not isinstance(coupling_in, list | dict) else coupling_in 

248 component['coupling_out'] = [coupling_out] if not isinstance(coupling_out, list) else coupling_out 

249 

250 # Raise an error if all inputs or all outputs are empty 

251 if len(exo_in) + len(coupling_in) == 0: 

252 raise ValueError(f'Component {component["name"]} has no inputs! Please specify inputs in ' 

253 f'"exo_in" or "coupling_in".') 

254 if len(component['coupling_out']) == 0: 

255 raise ValueError(f'Component {component["name"]} has no outputs! Please specify outputs in ' 

256 f'"coupling_out".') 

257 

258 # Get exogenous input indices (might already be a list of ints, otherwise convert list of vars to indices) 

259 if len(exo_in) > 0: 

260 if isinstance(exo_in[0], str | BaseRV): 

261 exo_in = [self.exo_vars.index(var) for var in exo_in] 

262 

263 # Get global coupling output indices for all nodes (convert list of vars to list of indices if necessary) 

264 global_out = {} 

265 for node, node_obj in nodes.items(): 

266 node_use = node_obj if node != component.get('name') else component 

267 coupling_out = node_use.get('coupling_out', None) 

268 coupling_out = [] if coupling_out is None else coupling_out 

269 coupling_out = [coupling_out] if not isinstance(coupling_out, list) else coupling_out 

270 global_out[node] = [self.coupling_vars.index(var) for var in coupling_out] if isinstance( 

271 coupling_out[0], str | BaseRV) else coupling_out 

272 

273 # Refactor coupling inputs into both local and global index formats 

274 local_in = dict() # e.g. {'Cathode': [0, 1, 2], 'Thruster': [0,], etc...} 

275 global_in = list() # e.g. [0, 2, 4, 5, 6] 

276 if isinstance(coupling_in, dict): 

277 # If already a dict, get local connection indices from each neighbor 

278 for node, connections in coupling_in.items(): 

279 conn_list = [connections] if not isinstance(connections, list) else connections 

280 if isinstance(conn_list[0], str | BaseRV): 

281 global_ind = [self.coupling_vars.index(var) for var in conn_list] 

282 local_in[node] = sorted([global_out[node].index(i) for i in global_ind]) 

283 else: 

284 local_in[node] = sorted(conn_list) 

285 

286 # Convert to global coupling indices 

287 for node, local_idx in local_in.items(): 

288 global_in.extend([global_out[node][i] for i in local_idx]) 

289 global_in = sorted(global_in) 

290 else: 

291 # Otherwise, convert a list of global indices or vars into a dict of local indices 

292 if len(coupling_in) > 0: 

293 if isinstance(coupling_in[0], str | BaseRV): 

294 coupling_in = [self.coupling_vars.index(var) for var in coupling_in] 

295 global_in = sorted(coupling_in) 

296 for node, node_obj in nodes.items(): 

297 if node != component['name']: 

298 l = list() # noqa: E741 

299 for i in global_in: 

300 try: 

301 l.append(global_out[node].index(i)) 

302 except ValueError: 

303 pass 

304 if l: 

305 local_in[node] = sorted(l) 

306 

307 # Store all connection indices for this component 

308 connections = dict(exo_in=exo_in, local_in=local_in, global_in=global_in, 

309 global_out=global_out.get(component.get('name'))) 

310 

311 # Set up a component output save directory 

312 if component.get('save_output', False) and self.root_dir is not None: 

313 output_dir = str((Path(self.root_dir) / 'components' / component['name']).resolve()) 

314 if not Path(output_dir).is_dir(): 

315 os.mkdir(output_dir) 

316 kwargs['output_dir'] = output_dir 

317 else: 

318 if kwargs.get('output_dir', None) is not None: 

319 kwargs['output_dir'] = None 

320 

321 # Initialize a new component surrogate 

322 surr_class = component.get('surrogate', 'lagrange') 

323 if isinstance(surr_class, str): 

324 match surr_class: 

325 case 'lagrange': 

326 surr_class = SparseGridSurrogate 

327 case 'analytical': 

328 surr_class = AnalyticalSurrogate 

329 case other: 

330 raise NotImplementedError(f"Surrogate type '{other}' is not known at this time.") 

331 

332 # Check for an override of model fidelity indices (to enable just single-fidelity evaluation) 

333 if kwargs.get('hf_override', False): 

334 truth_alpha, max_alpha = (), () 

335 kwargs['hf_override'] = component['truth_alpha'] # Pass in the truth alpha indices as a kwarg to model 

336 else: 

337 truth_alpha, max_alpha = component['truth_alpha'], component['max_alpha'] 

338 max_beta = component.get('max_beta') 

339 truth_alpha = (truth_alpha,) if isinstance(truth_alpha, int) else truth_alpha 

340 max_alpha = (max_alpha,) if isinstance(max_alpha, int) else max_alpha 

341 max_beta = (max_beta,) if isinstance(max_beta, int) else max_beta 

342 

343 # Assumes input ordering is exogenous vars + sorted coupling vars 

344 x_vars = [self.exo_vars[i] for i in exo_in] + [self.coupling_vars[i] for i in global_in] 

345 surr = surr_class(x_vars, component['model'], truth_alpha=truth_alpha, max_alpha=max_alpha, 

346 max_beta=max_beta, executor=self.executor, log_file=self.log_file, 

347 model_args=component.get('model_args'), model_kwargs=kwargs) 

348 return connections, surr 

349 

    def swap_component(self, component: ComponentSpec, exo_add: BaseRV | list[BaseRV] = None,
                       exo_remove: IndicesRV = None, qoi_add: BaseRV | list[BaseRV] = None,
                       qoi_remove: IndicesRV = None):
        """Swap a new component into the system, updating all connections/inputs.

        !!! Warning "Beta feature, proceed with caution"
            If you are swapping a new component in, you cannot remove any inputs that are expected by other
            components, including the coupling variables output by the current model.

        :param component: specs of new component model (must replace an existing component with matching `name`)
        :param exo_add: variables to add to system exogenous inputs (will be appended to end)
        :param exo_remove: indices of system exogenous inputs to delete (can't be shared by other components)
        :param qoi_add: system output QoIs to add (will be appended to end of `coupling_vars`)
        :param qoi_remove: indices of system `coupling_vars` to delete (can't be shared by other components)
        """
        # Delete system exogenous inputs (convert str/BaseRV specifiers to integer indices first)
        if exo_remove is None:
            exo_remove = []
        exo_remove = [exo_remove] if not isinstance(exo_remove, list) else exo_remove
        exo_remove = [self.exo_vars.index(var) for var in exo_remove] if exo_remove and isinstance(
            exo_remove[0], str | BaseRV) else exo_remove

        # Deletions happen one at a time in ascending order so that index shifts stay consistent
        exo_remove = sorted(exo_remove)
        for j, exo_var_idx in enumerate(exo_remove):
            # Adjust exogenous indices for all components to account for deleted system inputs
            for node, node_obj in self.graph.nodes.items():
                if node != component['name']:
                    for i, idx in enumerate(node_obj['exo_in']):
                        if idx == exo_var_idx:
                            # Another component still consumes this input -- refuse to delete it
                            error_msg = f"Can't delete system exogenous input at idx {exo_var_idx}, since it is " \
                                        f"shared by component '{node}'."
                            self.logger.error(error_msg)
                            raise ValueError(error_msg)
                        if idx > exo_var_idx:
                            node_obj['exo_in'][i] -= 1  # shift down indices above the deleted one

            # Need to update the remaining delete indices by -1 to account for each sequential deletion
            del self.exo_vars[exo_var_idx]
            for i in range(j+1, len(exo_remove)):
                exo_remove[i] -= 1

        # Append any new exogenous inputs to the end
        if exo_add is not None:
            exo_add = [exo_add] if not isinstance(exo_add, list) else exo_add
            self.exo_vars.extend(exo_add)

        # Delete system qoi outputs (if not shared by other components)
        qoi_remove = sorted(self._get_qoi_ind(qoi_remove))
        for j, qoi_idx in enumerate(qoi_remove):
            # Adjust coupling indices for all components to account for deleted system outputs
            for node, node_obj in self.graph.nodes.items():
                if node != component['name']:
                    for i, idx in enumerate(node_obj['global_in']):
                        if idx == qoi_idx:
                            # Another component consumes this QoI as a coupling input -- refuse to delete it
                            error_msg = f"Can't delete system QoI at idx {qoi_idx}, since it is an input to " \
                                        f"component '{node}'."
                            self.logger.error(error_msg)
                            raise ValueError(error_msg)
                        if idx > qoi_idx:
                            node_obj['global_in'][i] -= 1  # shift down input indices above the deleted one

                    for i, idx in enumerate(node_obj['global_out']):
                        if idx > qoi_idx:
                            node_obj['global_out'][i] -= 1  # shift down output indices above the deleted one

            # Need to update the remaining delete indices by -1 to account for each sequential deletion
            del self.coupling_vars[qoi_idx]
            for i in range(j+1, len(qoi_remove)):
                qoi_remove[i] -= 1

        # Append any new system QoI outputs to the end
        if qoi_add is not None:
            qoi_add = [qoi_add] if not isinstance(qoi_add, list) else qoi_add
            self.coupling_vars.extend(qoi_add)

        # Build and initialize the new component surrogate
        indices, surr = self._build_component(component)
        surr.init_coarse()

        # Make changes to adj matrix if coupling inputs changed: add edges for brand-new neighbors,
        # then remove edges for neighbors the new component no longer listens to
        prev_neighbors = list(self.graph.nodes[component['name']]['local_in'].keys())
        new_neighbors = list(indices['local_in'].keys())
        for neighbor in new_neighbors:
            if neighbor not in prev_neighbors:
                self.graph.add_edge(neighbor, component['name'])
            else:
                prev_neighbors.remove(neighbor)  # still a neighbor; remaining entries get unlinked below
        for neighbor in prev_neighbors:
            self.graph.remove_edge(neighbor, component['name'])

        self.logger.info(f"Swapped component '{component['name']}'.")
        # Overwrite the node's attributes with the freshly built connection indices and surrogate
        nx.set_node_attributes(self.graph, {component['name']: {'exo_in': indices['exo_in'], 'local_in':
                                            indices['local_in'], 'global_in': indices['global_in'],
                                            'global_out': indices['global_out'],
                                            'surrogate': surr, 'is_computed': False}})

445 

446 def insert_component(self, component: ComponentSpec, exo_add: BaseRV | list[BaseRV] = None, 

447 qoi_add: BaseRV | list[BaseRV] = None): 

448 """Insert a new component into the system. 

449 

450 :param component: specs of new component model 

451 :param exo_add: variables to add to system exogenous inputs (will be appended to end of `exo_vars`) 

452 :param qoi_add: system output QoIs to add (will be appended to end of `coupling_vars`) 

453 """ 

454 if exo_add is not None: 

455 exo_add = [exo_add] if not isinstance(exo_add, list) else exo_add 

456 self.exo_vars.extend(exo_add) 

457 if qoi_add is not None: 

458 qoi_add = [qoi_add] if not isinstance(qoi_add, list) else qoi_add 

459 self.coupling_vars.extend(qoi_add) 

460 

461 indices, surr = self._build_component(component) 

462 surr.init_coarse() 

463 self.graph.add_node(component['name'], surrogate=surr, is_computed=False, **indices) 

464 

465 # Add graph edges 

466 neighbors = list(indices['local_in'].keys()) 

467 for neighbor in neighbors: 

468 self.graph.add_edge(neighbor, component['name']) 

469 self.logger.info(f"Inserted component '{component['name']}'.") 

470 

471 def _save_on_error(func): 

472 """Gracefully exit and save `SystemSurrogate` on any errors.""" 

473 @functools.wraps(func) 

474 def wrap(self, *args, **kwargs): 

475 try: 

476 return func(self, *args, **kwargs) 

477 except: 

478 self._save_progress('sys_error.pkl') 

479 self.logger.critical(f'An error occurred during execution of {func.__name__}. Saving ' 

480 f'SystemSurrogate object to sys_error.pkl', exc_info=True) 

481 self.logger.info(f'Final system surrogate on exit: \n {self}') 

482 raise 

483 return wrap 

484 _save_on_error = staticmethod(_save_on_error) 

485 

486 @_save_on_error 

487 def init_system(self): 

488 """Add the coarsest multi-index to each component surrogate.""" 

489 self._print_title_str('Initializing all component surrogates') 

490 for node, node_obj in self.graph.nodes.items(): 

491 node_obj['surrogate'].init_coarse() 

492 # for alpha, beta in list(node_obj['surrogate'].candidate_set): 

493 # # Add one refinement in each input dimension to initialize 

494 # node_obj['surrogate'].activate_index(alpha, beta) 

495 self.logger.info(f"Initialized component '{node}'.") 

496 

    @_save_on_error
    def fit(self, qoi_ind: IndicesRV = None, num_refine: int = 100, max_iter: int = 20, max_tol: float = 1e-3,
            max_runtime: float = 1, save_interval: int = 0, update_bounds: bool = True, test_set: dict = None,
            n_jobs: int = 1):
        """Train the system surrogate adaptively by iterative refinement until an end condition is met.

        :param qoi_ind: list of system QoI variables to focus refinement on, use all QoI if not specified
        :param num_refine: number of samples of exogenous inputs to compute error indicators on
        :param max_iter: the maximum number of refinement steps to take
        :param max_tol: the max allowable value in relative L2 error to achieve
        :param max_runtime: the maximum wall clock time (hr) to run refinement for (will go until all models finish)
        :param save_interval: number of refinement steps between each progress save, none if 0
        :param update_bounds: whether to continuously update coupling variable bounds during refinement
        :param test_set: `dict(xt=(Nt, x_dim), yt=(Nt, y_dim)` to show convergence of surrogate to the truth model
        :param n_jobs: number of cpu workers for computing error indicators (on master MPI task), 1=sequential
        """
        qoi_ind = self._get_qoi_ind(qoi_ind)
        Nqoi = len(qoi_ind)
        max_iter = self.refine_level + max_iter  # interpret max_iter as *additional* steps past current level
        curr_error = np.inf
        t_start = time.time()
        test_stats, xt, yt, t_fig, t_ax = None, None, None, None, None

        # Record of (error indicator, component, alpha, beta, num_evals, total added cost (s)) for each iteration
        train_record = self.build_metrics.get('train_record', [])
        if test_set is not None:
            xt, yt = test_set['xt'], test_set['yt']
        # NOTE(review): applied unconditionally, so a previously-stored test set in build_metrics takes
        # precedence over the test_set parameter -- confirm this restart behavior is intended
        xt, yt = self.build_metrics.get('xt', xt), self.build_metrics.get('yt', yt)  # Overrides test set param

        # Track convergence progress on a test set and on the max error indicator
        err_fig, err_ax = plt.subplots()
        if xt is not None and yt is not None:
            self.build_metrics['xt'] = xt
            self.build_metrics['yt'] = yt
            if self.build_metrics.get('test_stats') is not None:
                test_stats = self.build_metrics.get('test_stats')  # resume prior convergence history
            else:
                # Get initial perf metrics, (2, Nqoi)
                test_stats = np.expand_dims(self.get_test_metrics(xt, yt, qoi_ind=qoi_ind), axis=0)
            t_fig, t_ax = plt.subplots(1, Nqoi) if Nqoi > 1 else plt.subplots()

        # Set up a parallel pool of workers, sequential if n_jobs=1
        with Parallel(n_jobs=n_jobs, verbose=0) as ppool:
            while True:
                # Check all end conditions
                if self.refine_level >= max_iter:
                    self._print_title_str(f'Termination criteria reached: Max iteration {self.refine_level}/{max_iter}')
                    break
                if curr_error == -np.inf:
                    # refine() signals an exhausted candidate set with -inf
                    self._print_title_str('Termination criteria reached: No candidates left to refine')
                    break
                if curr_error < max_tol:
                    self._print_title_str(f'Termination criteria reached: L2 error {curr_error} < tol {max_tol}')
                    break
                if ((time.time() - t_start)/3600.0) >= max_runtime:
                    actual = datetime.timedelta(seconds=time.time()-t_start)
                    target = datetime.timedelta(seconds=max_runtime*3600)
                    self._print_title_str(f'Termination criteria reached: runtime {str(actual)} > {str(target)}')
                    break

                # Refine surrogate and save progress
                refine_res = self.refine(qoi_ind=qoi_ind, num_refine=num_refine, update_bounds=update_bounds,
                                         ppool=ppool)
                curr_error = refine_res[0]
                if save_interval > 0 and self.refine_level % save_interval == 0:
                    self._save_progress(f'sys_iter_{self.refine_level}.pkl')

                # Plot progress of error indicator
                train_record.append(refine_res)
                error_record = [res[0] for res in train_record]
                self.build_metrics['train_record'] = train_record
                err_ax.clear(); err_ax.grid(); err_ax.plot(error_record, '-k')
                ax_default(err_ax, 'Iteration', r'Relative $L_2$ error indicator', legend=False)
                err_ax.set_yscale('log')
                if self.root_dir is not None:
                    err_fig.savefig(str(Path(self.root_dir) / 'error_indicator.png'), dpi=300, format='png')

                # Plot progress on test set (one subplot per QoI, re-drawn each iteration)
                if xt is not None and yt is not None:
                    stats = self.get_test_metrics(xt, yt, qoi_ind=qoi_ind)
                    test_stats = np.concatenate((test_stats, stats[np.newaxis, ...]), axis=0)
                    for i in range(Nqoi):
                        ax = t_ax if Nqoi == 1 else t_ax[i]
                        ax.clear(); ax.grid(); ax.set_yscale('log')
                        ax.plot(test_stats[:, 1, i], '-k')
                        ax.set_title(self.coupling_vars[qoi_ind[i]].to_tex(units=True))
                        ax_default(ax, 'Iteration', r'Relative $L_2$ error', legend=False)
                    t_fig.set_size_inches(3.5*Nqoi, 3.5)
                    t_fig.tight_layout()
                    if self.root_dir is not None:
                        t_fig.savefig(str(Path(self.root_dir) / 'test_set.png'), dpi=300, format='png')
                    self.build_metrics['test_stats'] = test_stats

        self._save_progress('sys_final.pkl')
        self.logger.info(f'Final system surrogate: \n {self}')

592 

593 def get_allocation(self, idx: int = None): 

594 """Get a breakdown of cost allocation up to a certain iteration number during training (starting at 1). 

595 

596 :param idx: the iteration number to get allocation results for (defaults to last refinement step) 

597 :returns: `cost_alloc, offline_alloc, cost_cum` - the cost alloc per node/fidelity and cumulative training cost 

598 """ 

599 if idx is None: 

600 idx = self.refine_level 

601 if idx > self.refine_level: 

602 raise ValueError(f'Specified index: {idx} is greater than the max training level of {self.refine_level}') 

603 

604 cost_alloc = dict() # Cost allocation per node and model fidelity 

605 cost_cum = [0.0] # Cumulative cost allocation during training 

606 

607 # Add initialization costs for each node 

608 for node, node_obj in self.graph.nodes.items(): 

609 surr = node_obj['surrogate'] 

610 base_alpha = (0,) * len(surr.truth_alpha) 

611 base_beta = (0,) * (len(surr.max_refine) - len(surr.truth_alpha)) 

612 base_cost = surr.get_cost(base_alpha, base_beta) 

613 cost_alloc[node] = dict() 

614 if base_cost > 0: 

615 cost_alloc[node][str(base_alpha)] = np.array([1, float(base_cost)]) 

616 cost_cum[0] += float(base_cost) 

617 

618 # Add cumulative training costs 

619 for i in range(idx): 

620 err_indicator, node, alpha, beta, num_evals, cost = self.build_metrics['train_record'][i] 

621 if cost_alloc[node].get(str(alpha), None) is None: 

622 cost_alloc[node][str(alpha)] = np.zeros(2) # (num model evals, total cpu_time cost) 

623 cost_alloc[node][str(alpha)] += [round(num_evals), float(cost)] 

624 cost_cum.append(float(cost)) 

625 

626 # Get summary of total offline costs spent building search candidates (i.e. training overhead) 

627 offline_alloc = dict() 

628 for node, node_obj in self.graph.nodes.items(): 

629 surr = node_obj['surrogate'] 

630 offline_alloc[node] = dict() 

631 for alpha, beta in surr.candidate_set: 

632 if offline_alloc[node].get(str(alpha), None) is None: 

633 offline_alloc[node][str(alpha)] = np.zeros(2) # (num model evals, total cpu_time cost) 

634 added_cost = surr.get_cost(alpha, beta) 

635 base_cost = surr.get_sub_surrogate(alpha, beta).model_cost 

636 offline_alloc[node][str(alpha)] += [round(added_cost/base_cost), float(added_cost)] 

637 

638 return cost_alloc, offline_alloc, np.cumsum(cost_cum) 

639 

640 def get_test_metrics(self, xt: np.ndarray, yt: np.ndarray, qoi_ind: IndicesRV = None, 

641 training: bool = True) -> np.ndarray: 

642 """Get relative L2 error metric over a test set. 

643 

644 :param xt: `(Nt, x_dim)` random test set of inputs 

645 :param yt: `(Nt, y_dim)` random test set outputs 

646 :param qoi_ind: list of indices of QoIs to get metrics for 

647 :param training: whether to evaluate the surrogate in training or evaluation mode 

648 :returns: `stats` - `(2, Nqoi)` array &rarr; `[num_candidates, rel_L2_error]` for each QoI 

649 """ 

650 qoi_ind = self._get_qoi_ind(qoi_ind) 

651 ysurr = self(xt, training=training) 

652 ysurr = ysurr[:, qoi_ind] 

653 yt = yt[:, qoi_ind] 

654 with np.errstate(divide='ignore', invalid='ignore'): 

655 rel_l2_err = np.sqrt(np.mean((yt - ysurr) ** 2, axis=0)) / np.sqrt(np.mean(yt ** 2, axis=0)) 

656 rel_l2_err = np.nan_to_num(rel_l2_err, posinf=np.nan, neginf=np.nan, nan=np.nan) 

657 num_cands = 0 

658 for node, node_obj in self.graph.nodes.items(): 

659 num_cands += len(node_obj['surrogate'].index_set) + len(node_obj['surrogate'].candidate_set) 

660 

661 # Get test stats for each QoI 

662 stats = np.zeros((2, yt.shape[-1])) 

663 self.logger.debug(f'{"QoI idx": >10} {"Iteration": >10} {"len(I_k)": >10} {"Relative L2": >15}') 

664 for i in range(yt.shape[-1]): 

665 stats[:, i] = np.array([num_cands, rel_l2_err[i]]) 

666 self.logger.debug(f'{i: 10d} {self.refine_level: 10d} {num_cands: 10d} {rel_l2_err[i]: 15.5f}') 

667 

668 return stats 

669 

670 def _get_qoi_ind(self, qoi_ind: IndicesRV) -> list[int]: 

671 """Small helper to make sure QoI indices are a list of integers.""" 

672 if qoi_ind is None: 

673 qoi_ind = list(np.arange(0, len(self.coupling_vars))) 

674 qoi_ind = [qoi_ind] if not isinstance(qoi_ind, list) else qoi_ind 

675 qoi_ind = [self.coupling_vars.index(var) for var in qoi_ind] if qoi_ind and isinstance( 

676 qoi_ind[0], str | BaseRV) else qoi_ind 

677 

678 return qoi_ind 

679 

    def refine(self, qoi_ind: IndicesRV = None, num_refine: int = 100, update_bounds: bool = True,
               ppool: Parallel = None) -> tuple:
        """Find and refine the component surrogate with the largest error on system-level QoI.

        :param qoi_ind: indices of system QoI to focus surrogate refinement on, use all QoI if not specified
        :param num_refine: number of samples of exogenous inputs to compute error indicators on
        :param update_bounds: whether to continuously update coupling variable bounds
        :param ppool: a `Parallel` instance from `joblib` to compute error indicators in parallel, None=sequential
        :returns refine_res: a tuple of `(error_indicator, component, node_star, alpha_star, beta_star, N, cost)`
                             indicating the chosen candidate index and incurred cost
        """
        self._print_title_str(f'Refining system surrogate: iteration {self.refine_level + 1}')
        set_loky_pickler('dill')  # Dill can serialize 'self' for parallel workers
        temp_exc = self.executor  # It can't serialize an executor though, so must save this temporarily
        self.set_executor(None)
        qoi_ind = self._get_qoi_ind(qoi_ind)

        # Compute entire integrated-surrogate on a random test set for global system QoI error estimation
        x_exo = self.sample_inputs((num_refine,))
        y_curr = self(x_exo, training=True)
        y_min, y_max = None, None
        if update_bounds:
            y_min = np.min(y_curr, axis=0, keepdims=True)  # (1, ydim)
            y_max = np.max(y_curr, axis=0, keepdims=True)  # (1, ydim)

        # Find the candidate surrogate with the largest error indicator (error decrease per unit cpu cost)
        error_max, error_indicator = -np.inf, -np.inf
        node_star, alpha_star, beta_star, l2_star, cost_star = None, None, None, -np.inf, 0
        for node, node_obj in self.graph.nodes.items():
            self.logger.info(f"Estimating error for component '{node}'...")
            candidates = node_obj['surrogate'].candidate_set.copy()

            def compute_error(alpha, beta):
                # Helper function for computing error indicators for a given candidate (alpha, beta)
                # NOTE: closes over `node`/`node_obj`, but is only called within the same loop
                # iteration, so late-binding of the loop variables is not an issue here.
                index_set = node_obj['surrogate'].index_set.copy()
                index_set.append((alpha, beta))
                y_cand = self(x_exo, training=True, index_set={node: index_set})
                ymin = np.min(y_cand, axis=0, keepdims=True)
                ymax = np.max(y_cand, axis=0, keepdims=True)
                error = y_cand[:, qoi_ind] - y_curr[:, qoi_ind]
                rel_l2 = np.sqrt(np.nanmean(error ** 2, axis=0)) / np.sqrt(np.nanmean(y_cand[:, qoi_ind] ** 2, axis=0))
                rel_l2 = np.nan_to_num(rel_l2, nan=np.nan, posinf=np.nan, neginf=np.nan)
                delta_error = np.nanmax(rel_l2)  # Max relative L2 error over all system QoIs
                delta_work = max(1, node_obj['surrogate'].get_cost(alpha, beta))  # Cpu time (s); clamp to >= 1

                return ymin, ymax, delta_error, delta_work

            if len(candidates) > 0:
                ret = ppool(delayed(compute_error)(alpha, beta) for alpha, beta in candidates) if ppool is not None \
                    else [compute_error(alpha, beta) for alpha, beta in candidates]

                for i, (ymin, ymax, d_error, d_work) in enumerate(ret):
                    if update_bounds:
                        # Grow the observed output range with each candidate evaluation
                        y_min = np.min(np.concatenate((y_min, ymin), axis=0), axis=0, keepdims=True)
                        y_max = np.max(np.concatenate((y_max, ymax), axis=0), axis=0, keepdims=True)
                    alpha, beta = candidates[i]
                    error_indicator = d_error / d_work
                    self.logger.info(f"Candidate multi-index: {(alpha, beta)}. L2 error: {d_error}. Error indicator: "
                                     f"{error_indicator}.")

                    if error_indicator > error_max:
                        error_max = error_indicator
                        node_star, alpha_star, beta_star, l2_star, cost_star = node, alpha, beta, d_error, d_work
            else:
                self.logger.info(f"Component '{node}' has no available candidates left!")

        # Update all coupling variable ranges
        if update_bounds:
            for i in range(y_curr.shape[-1]):
                self._update_coupling_bds(i, (y_min[0, i], y_max[0, i]))

        # Add the chosen multi-index to the chosen component (restore the executor first)
        self.set_executor(temp_exc)
        if node_star is not None:
            self.logger.info(f"Candidate multi-index {(alpha_star, beta_star)} chosen for component '{node_star}'")
            self.graph.nodes[node_star]['surrogate'].activate_index(alpha_star, beta_star)
            self.refine_level += 1
            # Back out the number of model evaluations from total cost / per-evaluation cost
            num_evals = round(cost_star / self[node_star].get_sub_surrogate(alpha_star, beta_star).model_cost)
        else:
            self.logger.info(f"No candidates left for refinement, iteration: {self.refine_level}")
            num_evals = 0

        return l2_star, node_star, alpha_star, beta_star, num_evals, cost_star

763 

    def predict(self, x: np.ndarray | float, max_fpi_iter: int = 100, anderson_mem: int = 10, fpi_tol: float = 1e-10,
                use_model: str | tuple | dict = None, model_dir: str | Path = None, verbose: bool = False,
                training: bool = False, index_set: dict[str: IndexSet] = None, qoi_ind: IndicesRV = None,
                ppool=None) -> np.ndarray:
        """Evaluate the system surrogate at exogenous inputs `x`.

        !!! Warning
            You can use this function to predict outputs for your MD system using the full-order models rather than the
            surrogate, by specifying `use_model`. This is convenient because `SystemSurrogate` manages all the
            coupled information flow between models automatically. However, it is *highly* recommended to not use
            the full model if your system contains feedback loops. The FPI nonlinear solver would be infeasible using
            anything more computationally demanding than the surrogate.

        :param x: `(..., x_dim)` the points to get surrogate predictions for
        :param max_fpi_iter: the limit on convergence for the fixed-point iteration routine
        :param anderson_mem: hyperparameter for tuning the convergence of FPI with anderson acceleration
        :param fpi_tol: tolerance limit for convergence of fixed-point iteration
        :param use_model: 'best'=highest-fidelity, 'worst'=lowest-fidelity, tuple=specific fidelity, None=surrogate,
                          specify a `dict` of the above to assign different model fidelities for diff components
        :param model_dir: directory to save model outputs if `use_model` is specified
        :param verbose: whether to print out iteration progress during execution
        :param training: whether to call the system surrogate in training or evaluation mode, ignored if `use_model`
        :param index_set: `dict(node=[indices])` to override default index set for a node (only useful for parallel)
        :param qoi_ind: list of qoi indices to return, defaults to returning all system `coupling_vars`
        :param ppool: a joblib `Parallel` instance to pass to each component to loop over multi-indices in parallel
        :returns y: `(..., y_dim)` the surrogate approximation of the system QoIs
        """
        # Allocate space for all system outputs (just save all coupling vars)
        x = np.atleast_1d(x)
        ydim = len(self.coupling_vars)
        y = np.zeros(x.shape[:-1] + (ydim,), dtype=x.dtype)
        valid_idx = ~np.isnan(x[..., 0])  # Keep track of valid samples (set to False if FPI fails)
        t1 = 0
        output_dir = None
        qoi_ind = self._get_qoi_ind(qoi_ind)
        is_computed = np.full(ydim, False)  # Per-coupling-var flag, enables early exit below

        # Interpret which model fidelities to use for each component (if specified)
        if use_model is not None:
            if not isinstance(use_model, dict):
                use_model = {node: use_model for node in self.graph.nodes}  # use same for each component
        else:
            use_model = {node: None for node in self.graph.nodes}
        # Normalize so every node has an entry (missing nodes default to surrogate)
        use_model = {node: use_model.get(node, None) for node in self.graph.nodes}

        # Initialize all components
        for node, node_obj in self.graph.nodes.items():
            node_obj['is_computed'] = False

        # Convert system into DAG by grouping strongly-connected-components
        dag = nx.condensation(self.graph)

        # Compute component models in topological order
        for supernode in nx.topological_sort(dag):
            if np.all(is_computed[qoi_ind]):
                break  # Exit early if all qois of interest are computed

            scc = [n for n in dag.nodes[supernode]['members']]

            # Compute single component feedforward output (no FPI needed)
            if len(scc) == 1:
                if verbose:
                    self.logger.info(f"Running component '{scc[0]}'...")
                    t1 = time.time()

                # Gather inputs
                node_obj = self.graph.nodes[scc[0]]
                exo_inputs = x[..., node_obj['exo_in']]
                # for comp_name in node_obj['local_in']:
                #     assert self.graph.nodes[comp_name]['is_computed']
                coupling_inputs = y[..., node_obj['global_in']]
                comp_input = np.concatenate((exo_inputs, coupling_inputs), axis=-1)  # (..., xdim)

                # Compute outputs
                indices = index_set.get(scc[0], None) if index_set is not None else None
                if model_dir is not None:
                    output_dir = Path(model_dir) / scc[0]
                    os.mkdir(output_dir)
                comp_output = node_obj['surrogate'](comp_input[valid_idx, :], use_model=use_model.get(scc[0]),
                                                    model_dir=output_dir, training=training, index_set=indices,
                                                    ppool=ppool)
                for local_i, global_i in enumerate(node_obj['global_out']):
                    y[valid_idx, global_i] = comp_output[..., local_i]
                    is_computed[global_i] = True
                node_obj['is_computed'] = True

                if verbose:
                    self.logger.info(f"Component '{scc[0]}' completed. Runtime: {time.time() - t1} s")

            # Handle FPI for SCCs with more than one component
            else:
                # Set the initial guess for all coupling vars (middle of domain)
                coupling_bds = [rv.bounds() for rv in self.coupling_vars]
                x_couple = np.array([(bds[0] + bds[1]) / 2 for bds in coupling_bds]).astype(x.dtype)
                x_couple = np.broadcast_to(x_couple, x.shape[:-1] + x_couple.shape).copy()

                adj_nodes = []
                fpi_idx = set()
                for node in scc:
                    for comp_name, local_idx in self.graph.nodes[node]['local_in'].items():
                        # Track the global idx of all coupling vars that need FPI
                        if comp_name in scc:
                            fpi_idx.update([self.graph.nodes[comp_name]['global_out'][idx] for idx in local_idx])

                        # Override coupling vars from components outside the scc (should already be computed)
                        if comp_name not in scc and comp_name not in adj_nodes:
                            # assert self.graph.nodes[comp_name]['is_computed']
                            global_idx = self.graph.nodes[comp_name]['global_out']
                            x_couple[..., global_idx] = y[..., global_idx]
                            adj_nodes.append(comp_name)  # Only need to do this once for each adj component
                x_couple_next = x_couple.copy()
                fpi_idx = sorted(fpi_idx)  # Deterministic ordering of the FPI coupling variables

                # Main FPI loop
                if verbose:
                    self.logger.info(f"Initializing FPI for SCC {scc} ...")
                    t1 = time.time()
                k = 0
                residual_hist = None
                x_hist = None
                while True:
                    for node in scc:
                        # Gather inputs from exogenous and coupling sources
                        node_obj = self.graph.nodes[node]
                        exo_inputs = x[..., node_obj['exo_in']]
                        coupling_inputs = x_couple[..., node_obj['global_in']]
                        comp_input = np.concatenate((exo_inputs, coupling_inputs), axis=-1)  # (..., xdim)

                        # Compute component outputs (just don't do this FPI with the real models, please..)
                        indices = index_set.get(node, None) if index_set is not None else None
                        comp_output = node_obj['surrogate'](comp_input[valid_idx, :], use_model=use_model.get(node),
                                                            model_dir=None, training=training, index_set=indices,
                                                            ppool=ppool)
                        global_idx = node_obj['global_out']
                        for local_i, global_i in enumerate(global_idx):
                            x_couple_next[valid_idx, global_i] = comp_output[..., local_i]
                            # Can't splice valid_idx with global_idx for some reason, have to loop over global_idx here

                    # Compute residual and check end conditions
                    residual = np.expand_dims(x_couple_next[..., fpi_idx] - x_couple[..., fpi_idx], axis=-1)
                    max_error = np.max(np.abs(residual[valid_idx, :, :]))
                    if verbose:
                        self.logger.info(f'FPI iter: {k}. Max residual: {max_error}. Time: {time.time() - t1} s')
                    if max_error <= fpi_tol:
                        if verbose:
                            self.logger.info(f'FPI converged for SCC {scc} in {k} iterations with {max_error} < tol '
                                             f'{fpi_tol}. Final time: {time.time() - t1} s')
                        break
                    if k >= max_fpi_iter:
                        # Non-converged samples are marked invalid and returned as NaN
                        self.logger.warning(f'FPI did not converge in {max_fpi_iter} iterations for SCC {scc}: '
                                            f'{max_error} > tol {fpi_tol}. Some samples will be returned as NaN.')
                        converged_idx = np.max(np.abs(residual), axis=(-1, -2)) <= fpi_tol
                        for idx in fpi_idx:
                            y[~converged_idx, idx] = np.nan
                        valid_idx = np.logical_and(valid_idx, converged_idx)
                        break

                    # Keep track of residual and x_couple histories
                    if k == 0:
                        residual_hist = residual.copy()  # (..., xdim, 1)
                        x_hist = np.expand_dims(x_couple_next[..., fpi_idx], axis=-1)  # (..., xdim, 1)
                        x_couple[:] = x_couple_next[:]
                        k += 1
                        continue  # skip anderson accel on first iteration

                    # Iterate with anderson acceleration (only iterate on samples that are not yet converged)
                    converged_idx = np.max(np.abs(residual), axis=(-1, -2)) <= fpi_tol
                    curr_idx = np.logical_and(valid_idx, ~converged_idx)
                    residual_hist = np.concatenate((residual_hist, residual), axis=-1)
                    x_hist = np.concatenate((x_hist, np.expand_dims(x_couple_next[..., fpi_idx], axis=-1)), axis=-1)
                    mk = min(anderson_mem, k)  # Use at most `anderson_mem` previous iterates
                    Fk = residual_hist[curr_idx, :, k-mk:]  # (..., xdim, mk+1)
                    # NOTE(review): C/d appear to impose a sum-to-one constraint on the Anderson mixing
                    # weights via `_constrained_lls` — confirm against that helper's definition.
                    C = np.ones(Fk.shape[:-2] + (1, mk + 1))
                    b = np.zeros(Fk.shape[:-2] + (len(fpi_idx), 1))
                    d = np.ones(Fk.shape[:-2] + (1, 1))
                    alpha = np.expand_dims(self._constrained_lls(Fk, b, C, d), axis=-3)  # (..., 1, mk+1, 1)
                    x_new = np.squeeze(x_hist[curr_idx, :, np.newaxis, -(mk+1):] @ alpha, axis=(-1, -2))
                    for local_i, global_i in enumerate(fpi_idx):
                        x_couple[curr_idx, global_i] = x_new[..., local_i]
                    k += 1

                # Save outputs of each component in SCC after convergence of FPI
                for node in scc:
                    global_idx = self.graph.nodes[node]['global_out']
                    for global_i in global_idx:
                        y[valid_idx, global_i] = x_couple_next[valid_idx, global_i]
                        is_computed[global_i] = True
                    self.graph.nodes[node]['is_computed'] = True

        # Return all component outputs (..., Nqoi); samples that didn't converge during FPI are left as np.nan
        return y[..., qoi_ind]

955 

    def __call__(self, *args, **kwargs):
        """Convenience wrapper to allow calling as `ret = SystemSurrogate(x)`.

        All positional and keyword arguments are forwarded unchanged to `predict`.
        """
        return self.predict(*args, **kwargs)

959 

960 def _estimate_coupling_bds(self, num_est: int, anderson_mem: int = 10, fpi_tol: float = 1e-10, 

961 max_fpi_iter: int = 100): 

962 """Estimate and set the coupling variable bounds. 

963 

964 :param num_est: the number of samples of exogenous inputs to use 

965 :param anderson_mem: FPI hyperparameter (default is usually good) 

966 :param fpi_tol: floating point tolerance for FPI convergence 

967 :param max_fpi_iter: maximum number of FPI iterations 

968 """ 

969 self._print_title_str('Estimating coupling variable bounds') 

970 x = self.sample_inputs((num_est,)) 

971 y = self(x, use_model='best', verbose=True, anderson_mem=anderson_mem, fpi_tol=fpi_tol, 

972 max_fpi_iter=max_fpi_iter) 

973 for i in range(len(self.coupling_vars)): 

974 lb = np.nanmin(y[:, i]) 

975 ub = np.nanmax(y[:, i]) 

976 self._update_coupling_bds(i, (lb, ub), init=True) 

977 

978 def _update_coupling_bds(self, global_idx: int, bds: tuple, init: bool = False, buffer: float = 0.05): 

979 """Update coupling variable bounds. 

980 

981 :param global_idx: global index of coupling variable to update 

982 :param bds: new bounds to update the current bounds with 

983 :param init: whether to set new bounds or update existing (default) 

984 :param buffer: fraction of domain length to buffer upper/lower bounds 

985 """ 

986 offset = buffer * (bds[1] - bds[0]) 

987 offset_bds = (bds[0] - offset, bds[1] + offset) 

988 coupling_bds = [rv.bounds() for rv in self.coupling_vars] 

989 new_bds = offset_bds if init else (min(coupling_bds[global_idx][0], offset_bds[0]), 

990 max(coupling_bds[global_idx][1], offset_bds[1])) 

991 self.coupling_vars[global_idx].update_bounds(*new_bds) 

992 

993 # Iterate over all components and update internal coupling variable bounds 

994 for node_name, node_obj in self.graph.nodes.items(): 

995 if global_idx in node_obj['global_in']: 

996 # Get the local index for this coupling variable within each component's inputs 

997 local_idx = len(node_obj['exo_in']) + node_obj['global_in'].index(global_idx) 

998 node_obj['surrogate'].update_input_bds(local_idx, new_bds) 

999 

1000 def sample_inputs(self, size: tuple | int, comp: str = 'System', use_pdf: bool = False, 

1001 nominal: dict[str: float] = None, constants: set[str] = None) -> np.ndarray: 

1002 """Return samples of the inputs according to provided options. 

1003 

1004 :param size: tuple or integer specifying shape or number of samples to obtain 

1005 :param comp: which component to sample inputs for (defaults to full system exogenous inputs) 

1006 :param use_pdf: whether to sample from each variable's pdf, defaults to random samples over input domain instead 

1007 :param nominal: `dict(var_id=value)` of nominal values for params with relative uncertainty, also can use 

1008 to specify constant values for a variable listed in `constants` 

1009 :param constants: set of param types to hold constant while sampling (i.e. calibration, design, etc.), 

1010 can also put a `var_id` string in here to specify a single variable to hold constant 

1011 :returns x: `(*size, x_dim)` samples of the inputs for the given component/system 

1012 """ 

1013 size = (size, ) if isinstance(size, int) else size 

1014 if nominal is None: 

1015 nominal = dict() 

1016 if constants is None: 

1017 constants = set() 

1018 x_vars = self.exo_vars if comp == 'System' else self[comp].x_vars 

1019 x = np.empty((*size, len(x_vars))) 

1020 for i, var in enumerate(x_vars): 

1021 # Set a constant value for this variable 

1022 if var.param_type in constants or var in constants: 

1023 x[..., i] = nominal.get(var, var.nominal) # Defaults to variable's nominal value if not specified 

1024 

1025 # Sample from this variable's pdf or randomly within its domain bounds (reject if outside bounds) 

1026 else: 

1027 lb, ub = var.bounds() 

1028 x_sample = var.sample(size, nominal=nominal.get(var, None)) if use_pdf \ 

1029 else var.sample_domain(size) 

1030 good_idx = (x_sample < ub) & (x_sample > lb) 

1031 num_reject = np.sum(~good_idx) 

1032 

1033 while num_reject > 0: 

1034 new_sample = var.sample((num_reject,), nominal=nominal.get(var, None)) if use_pdf \ 

1035 else var.sample_domain((num_reject,)) 

1036 x_sample[~good_idx] = new_sample 

1037 good_idx = (x_sample < ub) & (x_sample > lb) 

1038 num_reject = np.sum(~good_idx) 

1039 

1040 x[..., i] = x_sample 

1041 

1042 return x 

1043 

    def plot_slice(self, slice_idx: IndicesRV = None, qoi_idx: IndicesRV = None, show_surr: bool = True,
                   show_model: list = None, model_dir: str | Path = None, N: int = 50, nominal: dict[str: float] = None,
                   random_walk: bool = False, from_file: str | Path = None):
        """Helper function to plot 1d slices of the surrogate and/or model(s) over the inputs.

        :param slice_idx: list of exogenous input variables or indices to take 1d slices of
        :param qoi_idx: list of model output variables or indices to plot 1d slices of
        :param show_surr: whether to show the surrogate prediction
        :param show_model: also plot model predictions, list() of ['best', 'worst', tuple(alpha), etc.]
        :param model_dir: base directory to save model outputs (if specified)
        :param N: the number of points to take in the 1d slice
        :param nominal: `dict` of `str(var)->nominal` to use as constant value for all non-sliced variables
        :param random_walk: whether to slice in a random d-dimensional direction or hold all params const while slicing
        :param from_file: path to a .pkl file to load a saved slice from disk
        :returns: `fig, ax` with `num_slice` by `num_qoi` subplots
        """
        # Manage loading important quantities from file (if provided)
        xs, ys_model, ys_surr = None, None, None
        if from_file is not None:
            with open(Path(from_file), 'rb') as fd:
                slice_data = pickle.load(fd)
                slice_idx = slice_data['slice_idx']  # Must use same input slices as save file
                show_model = slice_data['show_model']  # Must use same model data as save file
                qoi_idx = slice_data['qoi_idx'] if qoi_idx is None else qoi_idx
                xs = slice_data['xs']
                model_dir = None  # Don't run or save any models if loading from file

        # Set default values (take up to the first 3 slices by default)
        rand_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
        if model_dir is not None:
            os.mkdir(Path(model_dir) / f'sweep_{rand_id}')
        if nominal is None:
            nominal = dict()
        slice_idx = list(np.arange(0, min(3, len(self.exo_vars)))) if slice_idx is None else slice_idx
        qoi_idx = list(np.arange(0, min(3, len(self.coupling_vars)))) if qoi_idx is None else qoi_idx
        # Allow variables (or their string ids) instead of integer indices
        if isinstance(slice_idx[0], str | BaseRV):
            slice_idx = [self.exo_vars.index(var) for var in slice_idx]
        if isinstance(qoi_idx[0], str | BaseRV):
            qoi_idx = [self.coupling_vars.index(var) for var in qoi_idx]

        exo_bds = [var.bounds() for var in self.exo_vars]
        xlabels = [self.exo_vars[idx].to_tex(units=True) for idx in slice_idx]
        ylabels = [self.coupling_vars[idx].to_tex(units=True) for idx in qoi_idx]

        # Construct slice model inputs (if not provided); xs has shape (N, num_slice, x_dim)
        if xs is None:
            xs = np.zeros((N, len(slice_idx), len(self.exo_vars)))
            for i in range(len(slice_idx)):
                if random_walk:
                    # Make a random straight-line walk across d-cube
                    r0 = np.squeeze(self.sample_inputs((1,), use_pdf=False), axis=0)
                    r0[slice_idx[i]] = exo_bds[slice_idx[i]][0]  # Start slice at this lower bound
                    rf = np.squeeze(self.sample_inputs((1,), use_pdf=False), axis=0)
                    rf[slice_idx[i]] = exo_bds[slice_idx[i]][1]  # Slice up to this upper bound
                    xs[0, i, :] = r0
                    for k in range(1, N):
                        xs[k, i, :] = xs[k-1, i, :] + (rf-r0)/(N-1)
                else:
                    # Otherwise, only slice one variable (all others held at their nominal values)
                    for j in range(len(self.exo_vars)):
                        if j == slice_idx[i]:
                            xs[:, i, j] = np.linspace(exo_bds[slice_idx[i]][0], exo_bds[slice_idx[i]][1], N)
                        else:
                            xs[:, i, j] = nominal.get(self.exo_vars[j], self.exo_vars[j].nominal)

        # Walk through each model that is requested by show_model
        if show_model is not None:
            if from_file is not None:
                ys_model = slice_data['ys_model']
            else:
                ys_model = list()
                for model in show_model:
                    output_dir = None
                    if model_dir is not None:
                        # Sanitize the model spec into a filesystem-friendly directory name
                        output_dir = (Path(model_dir) / f'sweep_{rand_id}' /
                                      str(model).replace('{', '').replace('}', '').replace(':', '=').replace("'", ''))
                        os.mkdir(output_dir)
                    ys_model.append(self(xs, use_model=model, model_dir=output_dir))
        if show_surr:
            ys_surr = self(xs) if from_file is None else slice_data['ys_surr']

        # Make len(qoi) by len(inputs) grid of subplots
        fig, axs = plt.subplots(len(qoi_idx), len(slice_idx), sharex='col', sharey='row')
        for i in range(len(qoi_idx)):
            for j in range(len(slice_idx)):
                # plt.subplots squeezes singleton dimensions, so index axs accordingly
                if len(qoi_idx) == 1:
                    ax = axs if len(slice_idx) == 1 else axs[j]
                elif len(slice_idx) == 1:
                    ax = axs if len(qoi_idx) == 1 else axs[i]
                else:
                    ax = axs[i, j]
                x = xs[:, j, slice_idx[j]]
                if show_model is not None:
                    # Black/gray for <= 2 models, otherwise spread over the 'jet' colormap
                    c = np.array([[0, 0, 0, 1], [0.5, 0.5, 0.5, 1]]) if len(show_model) <= 2 else (
                        plt.get_cmap('jet')(np.linspace(0, 1, len(show_model))))
                    for k in range(len(show_model)):
                        model_str = (str(show_model[k]).replace('{', '').replace('}', '')
                                     .replace(':', '=').replace("'", ''))
                        model_ret = ys_model[k]
                        y_model = model_ret[:, j, qoi_idx[i]]
                        label = {'best': 'High-fidelity' if len(show_model) > 1 else 'Model',
                                 'worst': 'Low-fidelity'}.get(model_str, model_str)
                        ax.plot(x, y_model, ls='-', c=c[k, :], label=label)
                if show_surr:
                    y_surr = ys_surr[:, j, qoi_idx[i]]
                    ax.plot(x, y_surr, '--r', label='Surrogate')
                ylabel = ylabels[i] if j == 0 else ''
                xlabel = xlabels[j] if i == len(qoi_idx) - 1 else ''
                legend = (i == 0 and j == len(slice_idx) - 1)
                ax_default(ax, xlabel, ylabel, legend=legend)
        fig.set_size_inches(3 * len(slice_idx), 3 * len(qoi_idx))
        fig.tight_layout()

        # Save results (unless we were already loading from a save file)
        if from_file is None and self.root_dir is not None:
            fname = f's{",".join([str(i) for i in slice_idx])}_q{",".join([str(i) for i in qoi_idx])}'
            fname = f'sweep_rand{rand_id}_' + fname if random_walk else f'sweep_nom{rand_id}_' + fname
            fdir = Path(self.root_dir) if model_dir is None else Path(model_dir) / f'sweep_{rand_id}'
            fig.savefig(fdir / f'{fname}.png', dpi=300, format='png')
            save_dict = {'slice_idx': slice_idx, 'qoi_idx': qoi_idx, 'show_model': show_model, 'show_surr': show_surr,
                         'nominal': nominal, 'random_walk': random_walk, 'xs': xs, 'ys_model': ys_model,
                         'ys_surr': ys_surr}
            with open(fdir / f'{fname}.pkl', 'wb') as fd:
                pickle.dump(save_dict, fd)

        return fig, axs

1170 

    def plot_allocation(self, cmap: str = 'Blues', text_bar_width: float = 0.06, arrow_bar_width: float = 0.02):
        """Plot bar charts showing cost allocation during training.

        !!! Warning "Beta feature"
            This has pretty good default settings, but it might look terrible for your use. Mostly provided here as
            a template for making cost allocation bar charts. Please feel free to copy and edit in your own code.

        :param cmap: the colormap string identifier for `plt`
        :param text_bar_width: the minimum total cost fraction above which a bar will print centered model fidelity text
        :param arrow_bar_width: the minimum total cost fraction above which a bar will try to print text with an arrow;
                                below this amount, the bar is too skinny and won't print any text
        :returns: `fig, ax`, Figure and Axes objects
        """
        # Get total cost (including offline overhead)
        train_alloc, offline_alloc, cost_cum = self.get_allocation()
        total_cost = cost_cum[-1]
        for node, alpha_dict in offline_alloc.items():
            for alpha, cost in alpha_dict.items():
                total_cost += cost[1]  # cost is (num evals, total cpu time); accumulate the cpu time

        # Remove nodes with cost=0 from alloc dicts (i.e. analytical models)
        remove_nodes = []
        for node, alpha_dict in train_alloc.items():
            if len(alpha_dict) == 0:
                remove_nodes.append(node)
        for node in remove_nodes:
            del train_alloc[node]
            del offline_alloc[node]

        # Bar chart showing cost allocation breakdown for MF system at end
        fig, axs = plt.subplots(1, 2, sharey='row')
        width = 0.7
        x = np.arange(len(train_alloc))
        xlabels = list(train_alloc.keys())
        cmap = plt.get_cmap(cmap)
        for k in range(2):
            # Left axis (k=0): online training cost; right axis (k=1): offline overhead
            ax = axs[k]
            alloc = train_alloc if k == 0 else offline_alloc
            ax.set_title('Online training' if k == 0 else 'Overhead')
            for j, (node, alpha_dict) in enumerate(alloc.items()):
                # One stacked bar per node, one segment per model fidelity (largest fraction first)
                bottom = 0
                c_intervals = np.linspace(0, 1, len(alpha_dict))
                bars = [(alpha, cost, cost[1] / total_cost) for alpha, cost in alpha_dict.items()]
                bars = sorted(bars, key=lambda ele: ele[2], reverse=True)
                for i, (alpha, cost, frac) in enumerate(bars):
                    p = ax.bar(x[j], frac, width, color=cmap(c_intervals[i]), linewidth=1,
                               edgecolor=[0, 0, 0], bottom=bottom)
                    bottom += frac
                    if frac > text_bar_width:
                        ax.bar_label(p, labels=[f'{alpha}, {round(cost[0])}'], label_type='center')
                    elif frac > arrow_bar_width:
                        xy = (x[j] + width / 2, bottom - frac / 2)  # Label smaller bars with a text off to the side
                        ax.annotate(f'{alpha}, {round(cost[0])}', xy, xytext=(xy[0] + 0.2, xy[1]),
                                    arrowprops={'arrowstyle': '->', 'linewidth': 1})
                    else:
                        pass  # Don't label really small bars
            ax_default(ax, '', "Fraction of total cost" if k == 0 else '', legend=False)
            ax.set_xticks(x, xlabels)
            ax.set_xlim(left=-1, right=x[-1] + 1)
        fig.set_size_inches(8, 4)
        fig.tight_layout()

        if self.root_dir is not None:
            fig.savefig(Path(self.root_dir) / 'mf_allocation.png', dpi=300, format='png')

        return fig, axs

1237 

def get_component(self, comp_name: str) -> ComponentSurrogate:
    """Look up the surrogate object registered under a component name.

    :param comp_name: name of the component to return; the special name 'System' returns `self`
    :returns: the `ComponentSurrogate` object
    """
    if comp_name == 'System':
        return self
    return self.graph.nodes[comp_name]['surrogate']

1246 

1247 def _print_title_str(self, title_str: str): 

1248 """Log an important message.""" 

1249 self.logger.info('-' * int(len(title_str)/2) + title_str + '-' * int(len(title_str)/2)) 

1250 

1251 def _save_progress(self, filename: str): 

1252 """Internal helper to save surrogate training progress (only if `root_dir` exists) 

1253 

1254 :param filename: the name of the save file to `root/sys/filename.pkl` 

1255 """ 

1256 if self.root_dir is not None: 

1257 self.save_to_file(filename) 

1258 

def save_to_file(self, filename: str, save_dir: str | Path = None):
    """Serialize this `SystemSurrogate` to a `.pkl` file via `dill`.

    :param filename: filename of the `.pkl` file to save to
    :param save_dir: overrides existing surrogate root directory if provided, otherwise defaults to '.'
    """
    if save_dir is None:
        # Prefer root/sys when the root directory exists and is valid; fall back to cwd
        save_dir = '.'
        if self.root_dir is not None:
            candidate = str(Path(self.root_dir) / 'sys')
            if Path(candidate).is_dir():
                save_dir = candidate

    save_path = Path(save_dir) / filename
    # Executors hold live OS resources and can't be pickled; detach while dumping
    saved_executor = self.executor
    self.set_executor(None)
    with open(save_path, 'wb') as dill_file:
        dill.dump(self, dill_file)
    self.set_executor(saved_executor)
    self.logger.info(f'SystemSurrogate saved to {save_path.resolve()}')

1276 

1277 def _set_output_dir(self, set_dict: dict[str: str | Path]): 

1278 """Set the output directory for each component in `set_dict`. 

1279 

1280 :param set_dict: a `dict` of component names (`str`) to their new output directories 

1281 """ 

1282 for node, node_obj in self.graph.nodes.items(): 

1283 if node in set_dict: 

1284 node_obj['surrogate']._set_output_dir(set_dict.get(node)) 

1285 

1286 def set_root_directory(self, root_dir: str | Path = None, stdout: bool = True, logger_name: str = None): 

1287 """Set the root to a new directory, for example if you move to a new filesystem. 

1288 

1289 :param root_dir: new root directory, don't save build products if None 

1290 :param stdout: whether to connect the logger to console (default) 

1291 :param logger_name: the logger name to use, defaults to class name 

1292 """ 

1293 if root_dir is None: 

1294 self.root_dir = None 

1295 self.log_file = None 

1296 else: 

1297 self.root_dir = str(Path(root_dir).resolve()) 

1298 log_file = None 

1299 if not (Path(self.root_dir) / 'sys').is_dir(): 

1300 os.mkdir(Path(self.root_dir) / 'sys') 

1301 if not (Path(self.root_dir) / 'components').is_dir(): 

1302 os.mkdir(Path(self.root_dir) / 'components') 

1303 for f in os.listdir(self.root_dir): 

1304 if f.endswith('.log'): 

1305 log_file = str((Path(self.root_dir) / f).resolve()) 

1306 break 

1307 if log_file is None: 

1308 fname = (datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.') + 

1309 'UTC_sys.log') 

1310 log_file = str((Path(self.root_dir) / fname).resolve()) 

1311 self.log_file = log_file 

1312 

1313 self.set_logger(logger_name, stdout=stdout) 

1314 

1315 # Update model output directories 

1316 for node, node_obj in self.graph.nodes.items(): 

1317 surr = node_obj['surrogate'] 

1318 if self.root_dir is not None and surr.save_enabled(): 

1319 output_dir = str((Path(self.root_dir) / 'components' / node).resolve()) 

1320 if not Path(output_dir).is_dir(): 

1321 os.mkdir(output_dir) 

1322 surr._set_output_dir(output_dir) 

1323 

def __getitem__(self, component: str) -> ComponentSurrogate:
    """Allow bracket access, e.g. ``system['component']``, as a shorthand for `get_component`.

    :param component: the name of the component to get
    :returns: the `ComponentSurrogate` object
    """
    return self.get_component(component)

1331 

def __repr__(self):
    """Summarize the system: adjacency matrix, exogenous inputs, then each component surrogate."""
    adjacency = nx.to_numpy_array(self.graph, dtype=int)
    exo_labels = [str(var) for var in self.exo_vars]
    pieces = [f'----SystemSurrogate----\nAdjacency: \n{adjacency}\n'
              f'Exogenous inputs: {exo_labels}\n']
    for node, node_obj in self.graph.nodes.items():
        pieces.append(f'Component: {node}\n{node_obj["surrogate"]}')
    return ''.join(pieces)

1338 

1339 def __str__(self): 

1340 return self.__repr__() 

1341 

1342 def set_executor(self, executor: Executor | None): 

1343 """Set a new `concurrent.futures.Executor` object for parallel calls. 

1344 

1345 :param executor: the new `Executor` object 

1346 """ 

1347 self.executor = executor 

1348 for node, node_obj in self.graph.nodes.items(): 

1349 node_obj['surrogate'].executor = executor 

1350 

def set_logger(self, name: str = None, log_file: str | Path = None, stdout: bool = True):
    """Set a new `logging.Logger` object with the given unique `name`.

    :param name: the name of the new logger object, defaults to the class name
    :param log_file: log file (if provided), otherwise keep the current `self.log_file`
    :param stdout: whether to connect the logger to console (default)
    """
    if log_file is not None:
        self.log_file = log_file
    logger_name = name if name is not None else self.__class__.__name__
    self.logger = get_logger(logger_name, log_file=self.log_file, stdout=stdout)

    # Each component surrogate logs through a child logger sharing the same handlers/file
    for _, node_obj in self.graph.nodes.items():
        surr = node_obj['surrogate']
        surr.logger = self.logger.getChild('Component')
        surr.log_file = self.log_file

1369 

@staticmethod
def load_from_file(filename: str | Path, root_dir: str | Path = None, executor: Executor = None,
                   stdout: bool = True, logger_name: str = None):
    """Load a `SystemSurrogate` object from file.

    :param filename: the .pkl file to load
    :param root_dir: if provided, an `amisc_timestamp` directory will be created at `root_dir`. Ignored if the
                     `.pkl` file already resides in an `amisc`-like directory. If none, then the surrogate object
                     is only loaded into memory and is not given a file directory for any save artifacts.
    :param executor: a `concurrent.futures.Executor` object to set; clears it if None
    :param stdout: whether to log to console
    :param logger_name: the name of the logger to use, if None then uses class name by default
    :returns: the `SystemSurrogate` object
    """
    load_path = Path(filename)
    with open(load_path, 'rb') as dill_file:
        sys_surr = dill.load(dill_file)
    sys_surr.set_executor(executor)
    sys_surr.x_vars = sys_surr.exo_vars  # backwards compatible v0.2.0

    needs_copy = False
    if root_dir is None:
        # Detect an existing amisc_root/sys/filename.pkl layout and reuse that root directory
        parts = load_path.resolve().parts
        if len(parts) > 2 and parts[-3].startswith('amisc_'):
            root_dir = load_path.parent.parent
    else:
        # Build a fresh amisc_<timestamp> directory under the requested root
        if not Path(root_dir).is_dir():
            root_dir = '.'
        timestamp = datetime.datetime.now(tz=timezone.utc).isoformat().split('.')[0].replace(':', '.')
        root_dir = Path(root_dir) / ('amisc_' + timestamp)
        os.mkdir(root_dir)
        needs_copy = True

    sys_surr.set_root_directory(root_dir, stdout=stdout, logger_name=logger_name)

    if needs_copy:
        # Keep a copy of the loaded save file inside the new root's sys/ directory
        shutil.copyfile(load_path, root_dir / 'sys' / load_path.name)

    return sys_surr

1409 

1410 @staticmethod 

1411 def _constrained_lls(A: np.ndarray, b: np.ndarray, C: np.ndarray, d: np.ndarray) -> np.ndarray: 

1412 """Minimize $||Ax-b||_2$, subject to $Cx=d$, i.e. constrained linear least squares. 

1413 

1414 !!! Note 

1415 See http://www.seas.ucla.edu/~vandenbe/133A/lectures/cls.pdf for more detail. 

1416 

1417 :param A: `(..., M, N)`, vandermonde matrix 

1418 :param b: `(..., M, 1)`, data 

1419 :param C: `(..., P, N)`, constraint operator 

1420 :param d: `(..., P, 1)`, constraint condition 

1421 :returns: `(..., N, 1)`, the solution parameter vector `x` 

1422 """ 

1423 M = A.shape[-2] 

1424 dims = len(A.shape[:-2]) 

1425 T_axes = tuple(np.arange(0, dims)) + (-1, -2) 

1426 Q, R = np.linalg.qr(np.concatenate((A, C), axis=-2)) 

1427 Q1 = Q[..., :M, :] 

1428 Q2 = Q[..., M:, :] 

1429 Q1_T = np.transpose(Q1, axes=T_axes) 

1430 Q2_T = np.transpose(Q2, axes=T_axes) 

1431 Qtilde, Rtilde = np.linalg.qr(Q2_T) 

1432 Qtilde_T = np.transpose(Qtilde, axes=T_axes) 

1433 Rtilde_T_inv = np.linalg.pinv(np.transpose(Rtilde, axes=T_axes)) 

1434 w = np.linalg.pinv(Rtilde) @ (Qtilde_T @ Q1_T @ b - Rtilde_T_inv @ d) 

1435 

1436 return np.linalg.pinv(R) @ (Q1_T @ b - Q2_T @ w)