Coverage for src/amisc/component.py: 87%

908 statements

coverage.py v7.6.1, created at 2025-11-18 13:40 +0000

1"""A `Component` is an `amisc` wrapper around a single discipline model. It manages surrogate construction and 

2a hierarchy of modeling fidelities. 

3 

4!!! Info "Multi-indices in the MISC approximation" 

5 A multi-index is a tuple of natural numbers, each specifying a level of fidelity. You will frequently see two 

6 multi-indices: `alpha` and `beta`. The `alpha` (or $\\alpha$) indices specify physical model fidelity and get 

7 passed to the model as an additional argument (e.g. things like discretization level, time step size, etc.). 

8 The `beta` (or $\\beta$) indices specify surrogate refinement level, so typically an indication of the amount of 

9 training data used or the complexity of the surrogate model. We divide $\\beta$ into `data_fidelity` and 

10 `surrogate_fidelity` for specifying training data and surrogate model complexity, respectively. 
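
!!! Example "A sketch of a multi-index pair"
    As a purely hypothetical example, `alpha = (1,)` might select the second mesh-refinement level of a
    physics model, while `beta = (2, 0)` might request the third training-data refinement at the base
    surrogate complexity. Together, `(alpha, beta) = ((1,), (2, 0))` identifies one fidelity "level" in the
    MISC approximation.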

Includes:

- `ModelKwargs` — a dataclass for storing model keyword arguments
- `StringKwargs` — a dataclass for storing model keyword arguments as a string
- `IndexSet` — a dataclass that maintains a list of multi-indices
- `MiscTree` — a dataclass that maintains MISC data in a `dict` tree, indexed by `alpha` and `beta`
- `Component` — a class that manages a single discipline model and its surrogate hierarchy
"""
from __future__ import annotations

import ast
import copy
import inspect
import itertools
import logging
import random
import string
import time
import traceback
import typing
import warnings
from collections import UserDict, deque
from concurrent.futures import ALL_COMPLETED, Executor, wait
from pathlib import Path
from typing import Any, Callable, ClassVar, Iterable, Literal, Optional

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
from typing_extensions import TypedDict

from amisc.interpolator import Interpolator, InterpolatorState, Lagrange
from amisc.serialize import PickleSerializable, Serializable, StringSerializable, YamlSerializable
from amisc.training import SparseGrid, TrainingData
from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex
from amisc.utils import (
    _get_yaml_path,
    _inspect_assignment,
    _inspect_function,
    format_inputs,
    format_outputs,
    get_logger,
    search_for_file,
    to_model_dataset,
    to_surrogate_dataset,
)
from amisc.variable import Variable, VariableList

__all__ = ["ModelKwargs", "StringKwargs", "IndexSet", "MiscTree", "Component"]
_VariableLike = list[Variable | dict | str] | str | Variable | dict | VariableList  # Generic type for Variables


class ModelKwargs(UserDict, Serializable):
    """Default dataclass for storing model keyword arguments in a `dict`. If you have kwargs that require
    more complicated serialization/specification than a plain `dict`, then you can subclass from here.
    """

    def serialize(self):
        return self.data

    @classmethod
    def deserialize(cls, serialized_data):
        return ModelKwargs(**serialized_data)

    @classmethod
    def from_dict(cls, config: dict) -> ModelKwargs:
        """Create a `ModelKwargs` object from a `dict` configuration."""
        method = config.pop('method', 'default_kwargs').lower()
        match method:
            case 'default_kwargs':
                return ModelKwargs(**config)
            case 'string_kwargs':
                return StringKwargs(**config)
            case other:
                config['method'] = other
                return ModelKwargs(**config)  # Pass the method through
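
    # As a usage sketch (hypothetical config): `ModelKwargs.from_dict({'method': 'string_kwargs', 'dt': 0.1})`
    # dispatches to `StringKwargs(dt=0.1)`, while an unrecognized `method` value is passed through unchanged
    # as a regular model kwarg.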


class StringKwargs(StringSerializable, ModelKwargs):
    """Dataclass for storing model keyword arguments as a string."""
    def __repr__(self):
        return str(self.data)

    def __str__(self):
        def format_value(value):
            if isinstance(value, str):
                return f'"{value}"'
            else:
                return str(value)

        kw_str = ", ".join([f"{key}={format_value(value)}" for key, value in self.items()])
        return f"ModelKwargs({kw_str})"
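
    # For example (hypothetical kwargs), `str(StringKwargs(dt=0.1, solver='rk45'))` renders as the string
    # 'ModelKwargs(dt=0.1, solver="rk45")', the string form that `StringSerializable` is designed to round-trip.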


class IndexSet(set, Serializable):
    """Dataclass that maintains a list of multi-indices. Overrides basic `set` functionality to ensure
    elements are formatted correctly as `(alpha, beta)`; that is, as a tuple of `alpha` and
    `beta`, which are themselves instances of a [`MultiIndex`][amisc.typing.MultiIndex] tuple.

    !!! Example "An example index set"
        $\\mathcal{I} = [(\\alpha, \\beta)_1 , (\\alpha, \\beta)_2, (\\alpha, \\beta)_3 , ...]$ would be specified
        as `I = [((0, 0), (0, 0, 0)) , ((0, 1), (0, 1, 0)), ...]`.
    """
    def __init__(self, s=()):
        s = [self._validate_element(ele) for ele in s]
        super().__init__(s)

    def __str__(self):
        return str(list(self))

    def __repr__(self):
        return self.__str__()

    def add(self, __element):
        super().add(self._validate_element(__element))

    def update(self, __elements):
        super().update([self._validate_element(ele) for ele in __elements])

    @classmethod
    def _validate_element(cls, element):
        """Validate that the element is a tuple of two multi-indices."""
        alpha, beta = ast.literal_eval(element) if isinstance(element, str) else tuple(element)
        return MultiIndex(alpha), MultiIndex(beta)

    @classmethod
    def _wrap_methods(cls, names):
        """Make sure set operations return an `IndexSet` object."""
        def wrap_method_closure(name):
            def inner(self, *args):
                result = getattr(super(cls, self), name)(*args)
                if isinstance(result, set):
                    result = cls(result)
                return result
            inner.fn_name = name
            setattr(cls, name, inner)

        for name in names:
            wrap_method_closure(name)

    def serialize(self) -> list[str]:
        """Return a list of each multi-index in the set serialized to a string."""
        return [str(ele) for ele in self]

    @classmethod
    def deserialize(cls, serialized_data: list[str]) -> IndexSet:
        """Deserialize a list of tuples to an `IndexSet`."""
        return cls(serialized_data)


IndexSet._wrap_methods(['__ror__', 'difference_update', '__isub__', 'symmetric_difference', '__rsub__', '__and__',
                        '__rand__', 'intersection', 'difference', '__iand__', 'union', '__ixor__',
                        'symmetric_difference_update', '__or__', 'copy', '__rxor__', 'intersection_update', '__xor__',
                        '__ior__', '__sub__'
                        ])
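
# Note: the wrapping above matters because plain `set` operations would otherwise drop the subclass; e.g.
# `IndexSet([((0,), (0,))]) | IndexSet([((1,), (0,))])` returns an `IndexSet` (not a bare `set`) after wrapping.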


class MiscTree(UserDict, Serializable):
    """Dataclass that maintains MISC data in a `dict` tree, indexed by `alpha` and `beta`. Overrides
    basic `dict` functionality to ensure elements are formatted correctly as `(alpha, beta) -> data`.
    Used to store MISC coefficients, model costs, and interpolator states.

    The underlying data structure is: `dict[MultiIndex, dict[MultiIndex, float | InterpolatorState]]`.
    """
    SERIALIZER_KEY = 'state_serializer'

    def __init__(self, data: dict = None, **kwargs):
        data_dict = data or {}
        if isinstance(data_dict, MiscTree):
            data_dict = data_dict.data
        data_dict.update(kwargs)
        super().__init__(self._validate_data(data_dict))

    def serialize(self, *args, keep_yaml_objects=False, **kwargs) -> dict:
        """Serialize `alpha, beta` indices to string and return a `dict` of internal data.

        :param args: extra serialization arguments for internal `InterpolatorState`
        :param keep_yaml_objects: whether to keep `YamlSerializable` instances in the serialization
        :param kwargs: extra serialization keyword arguments for internal `InterpolatorState`
        """
        ret_dict = {}
        if state_serializer := self.state_serializer(self.data):
            ret_dict[self.SERIALIZER_KEY] = state_serializer.obj if keep_yaml_objects else state_serializer.serialize()
        for alpha, beta, data in self:
            ret_dict.setdefault(str(alpha), dict())
            serialized_data = data.serialize(*args, **kwargs) if isinstance(data, InterpolatorState) else float(data)
            ret_dict[str(alpha)][str(beta)] = serialized_data
        return ret_dict

    @classmethod
    def deserialize(cls, serialized_data: dict) -> MiscTree:
        """Deserialize a `dict` to a `MiscTree`.

        :param serialized_data: the data to deserialize to a `MiscTree` object
        """
        return cls(serialized_data)

    @classmethod
    def state_serializer(cls, data: dict) -> YamlSerializable | None:
        """Infer and return the interpolator state serializer from the `MiscTree` data (if possible). If no
        `InterpolatorState` instance could be found, return `None`.
        """
        serializer = data.get(cls.SERIALIZER_KEY, None)  # if `data` is serialized
        if serializer is None:  # Otherwise search for an InterpolatorState
            for alpha, beta_dict in data.items():
                if alpha == cls.SERIALIZER_KEY:
                    continue
                for beta, value in beta_dict.items():
                    if isinstance(value, InterpolatorState):
                        serializer = type(value)
                        break
                if serializer is not None:
                    break
        return cls._validate_state_serializer(serializer)

    @classmethod
    def _validate_state_serializer(cls, state_serializer: Optional[str | type[Serializable] | YamlSerializable]
                                   ) -> YamlSerializable | None:
        if state_serializer is None:
            return None
        elif isinstance(state_serializer, YamlSerializable):
            return state_serializer
        elif isinstance(state_serializer, str):
            return YamlSerializable.deserialize(state_serializer)  # Load the serializer type from string
        else:
            return YamlSerializable(obj=state_serializer)

    @classmethod
    def _validate_data(cls, serialized_data: dict) -> dict:
        state_serializer = cls.state_serializer(serialized_data)
        ret_dict = {}
        for alpha, beta_dict in serialized_data.items():
            if alpha == cls.SERIALIZER_KEY:
                continue
            alpha_tup = MultiIndex(alpha)
            ret_dict.setdefault(alpha_tup, dict())
            for beta, data in beta_dict.items():
                beta_tup = MultiIndex(beta)
                if isinstance(data, InterpolatorState):
                    pass
                elif state_serializer is not None:
                    data = state_serializer.obj.deserialize(data)
                else:
                    data = float(data)
                assert isinstance(data, InterpolatorState | float)
                ret_dict[alpha_tup][beta_tup] = data
        return ret_dict

    @staticmethod
    def _is_alpha_beta_access(key):
        """Check that the key is of the format `(alpha, beta)`."""
        return (isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], str | tuple)
                and isinstance(key[1], str | tuple))

    def get(self, key, default=None) -> float | InterpolatorState:
        try:
            return self.__getitem__(key)
        except Exception:
            return default

    def update(self, data_dict: dict = None, **kwargs):
        """Force `dict.update()` through the validator."""
        data_dict = data_dict or dict()
        data_dict.update(kwargs)
        super().update(self._validate_data(data_dict))

    def __setitem__(self, key: tuple | MultiIndex, value: float | InterpolatorState):
        """Allows `misc_tree[alpha, beta] = value` usage."""
        if self._is_alpha_beta_access(key):
            alpha, beta = MultiIndex(key[0]), MultiIndex(key[1])
            self.data.setdefault(alpha, dict())
            self.data[alpha][beta] = value
        else:
            super().__setitem__(MultiIndex(key), value)

    def __getitem__(self, key: tuple | MultiIndex) -> float | InterpolatorState:
        """Allows `value = misc_tree[alpha, beta]` usage."""
        if self._is_alpha_beta_access(key):
            alpha, beta = MultiIndex(key[0]), MultiIndex(key[1])
            return self.data[alpha][beta]
        else:
            return super().__getitem__(MultiIndex(key))
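
    # Usage sketch (hypothetical indices): `tree = MiscTree(); tree[(0,), (1, 0)] = 0.5` stores a float under
    # `alpha=(0,)` and `beta=(1, 0)`, and `tree[(0,), (1, 0)]` retrieves it via the `(alpha, beta)` access above.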

    def clear(self):
        """Clear the `MiscTree` data."""
        for key in list(self.data.keys()):
            del self.data[key]

    def __eq__(self, other):
        if isinstance(other, MiscTree):
            try:
                for alpha, beta, data in self:
                    if other[alpha, beta] != data:
                        return False
                return True
            except KeyError:
                return False
        else:
            return False

    def __iter__(self) -> Iterable[tuple[tuple, tuple, float | InterpolatorState]]:
        for alpha, beta_dict in self.data.items():
            if alpha == self.SERIALIZER_KEY:
                continue
            for beta, data in beta_dict.items():
                yield alpha, beta, data


class ComponentSerializers(TypedDict, total=False):
    """Type hint for the `Component` class data serializers.

    :ivar model_kwargs: the model kwarg object class
    :ivar interpolator: the interpolator object class
    :ivar training_data: the training data object class
    """
    model_kwargs: str | type[Serializable] | YamlSerializable
    interpolator: str | type[Serializable] | YamlSerializable
    training_data: str | type[Serializable] | YamlSerializable


class Component(BaseModel, Serializable):
    """A `Component` wrapper around a single discipline model. It manages MISC surrogate construction and a hierarchy
    of modeling fidelities.

    A `Component` can be constructed by specifying a model, input and output variables, and additional configurations
    such as the maximum fidelity levels, the interpolator type, and the training data type. If `model_fidelity`,
    `data_fidelity`, and `surrogate_fidelity` are all left empty, then the `Component` will not use a surrogate model,
    instead calling the underlying model directly. The `Component` can be serialized to a YAML file and deserialized
    back into a Python object.

    !!! Example "A simple `Component`"
        ```python
        from amisc import Component, Variable

        x = Variable(domain=(0, 1))
        y = Variable()
        model = lambda x: {'y': x['x']**2}
        comp = Component(model=model, inputs=[x], outputs=[y])
        ```

    Each fidelity index in $\\alpha$ increases in refinement from $0$ up to `model_fidelity`. Each fidelity index
    in $\\beta$ increases from $0$ up to `(data_fidelity, surrogate_fidelity)`. From the `Component's` perspective,
    the concatenation of $(\\alpha, \\beta)$ fully specifies a single fidelity "level". The `Component`
    forms an approximation of the model by summing up over many of these concatenated sets of $(\\alpha, \\beta)$.

    :ivar name: the name of the `Component`
    :ivar model: the model or function that is to be approximated, callable as `y = f(x)`
    :ivar inputs: the input variables to the model
    :ivar outputs: the output variables from the model
    :ivar model_kwargs: extra keyword arguments to pass to the model
    :ivar model_fidelity: the maximum level of refinement for each fidelity index in $\\alpha$ for model fidelity
    :ivar data_fidelity: the maximum level of refinement for each fidelity index in $\\beta$ for training data
    :ivar surrogate_fidelity: the max level of refinement for each fidelity index in $\\beta$ for the surrogate
    :ivar interpolator: the interpolator to use as the underlying surrogate model
    :ivar vectorized: whether the model supports vectorized input/output (i.e. datasets with arbitrary shape `(...,)`)
    :ivar call_unpacked: whether the model expects unpacked input arguments (i.e. `func(x1, x2, ...)`)
    :ivar ret_unpacked: whether the model returns unpacked output arguments (i.e. `func() -> (y1, y2, ...)`)

    :ivar active_set: the current active set of multi-indices in the MISC approximation
    :ivar candidate_set: all neighboring multi-indices that are candidates for inclusion in `active_set`
    :ivar misc_states: the interpolator states for each multi-index in the MISC approximation
    :ivar misc_costs: the computational cost associated with each multi-index in the MISC approximation
    :ivar misc_coeff_train: the combination technique coefficients for the active set multi-indices
    :ivar misc_coeff_test: the combination technique coefficients for the active and candidate set multi-indices
    :ivar model_costs: the tracked average single fidelity model costs for each $\\alpha$
    :ivar model_evals: the tracked number of evaluations for each $\\alpha$
    :ivar training_data: the training data storage structure for the surrogate model

    :ivar serializers: the custom serializers for the `[model_kwargs, interpolator, training_data]`
                       `Component` attributes -- these should be the _types_ of the serializer objects, which will
                       be inferred from the data passed in if not explicitly set
    :ivar _logger: the logger for the `Component`
    """
    yaml_tag: ClassVar[str] = u'!Component'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True,
                              protected_namespaces=(), extra='allow')
    # Configuration
    serializers: Optional[ComponentSerializers] = None
    name: Optional[str] = None
    model: str | Callable[[dict | Dataset, ...], dict | Dataset]
    model_kwargs: str | dict | ModelKwargs = {}
    inputs: _VariableLike
    outputs: _VariableLike
    model_fidelity: str | tuple = MultiIndex()
    data_fidelity: str | tuple = MultiIndex()
    surrogate_fidelity: str | tuple = MultiIndex()
    interpolator: Any | Interpolator = Lagrange()
    vectorized: bool = False
    call_unpacked: Optional[bool] = None  # If the model expects inputs/outputs like `func(x1, x2, ...) -> (y1, y2, ...)`
    ret_unpacked: Optional[bool] = None

    # Data storage/states for a MISC component
    active_set: list | set | IndexSet = IndexSet()  # set of active (alpha, beta) multi-indices
    candidate_set: list | set | IndexSet = IndexSet()  # set of candidate (alpha, beta) multi-indices
    misc_states: dict | MiscTree = MiscTree()  # (alpha, beta) -> Interpolator state
    misc_costs: dict | MiscTree = MiscTree()  # (alpha, beta) -> Added computational cost for this multi-index
    misc_coeff_train: dict | MiscTree = MiscTree()  # (alpha, beta) -> c_[alpha, beta] (active set only)
    misc_coeff_test: dict | MiscTree = MiscTree()  # (alpha, beta) -> c_[alpha, beta] (including candidate set)
    model_costs: dict = dict()  # Average single fidelity model costs (for each alpha)
    model_evals: dict = dict()  # Number of evaluations for each alpha
    training_data: Any | TrainingData = SparseGrid()  # Stores surrogate training data

    # Internal
    _logger: Optional[logging.Logger] = None
    _model_start_time: float = -1.0  # Temporarily store the most recent model start timestamp from call_model
    _model_end_time: float = -1.0  # Temporarily store the most recent model end timestamp from call_model
    _cache: dict = dict()  # Temporary cache for faster access to training data and similar

    def __init__(self, /, model, *args, inputs=None, outputs=None, name=None, **kwargs):
        if name is None:
            name = _inspect_assignment('Component')  # try to assign the name from inspection
        name = name or model.__name__ or "Component_" + "".join(random.choices(string.digits, k=3))

        # Determine how the model expects to be called and gather inputs/outputs
        _ = self._validate_model_signature(model, args, inputs, outputs, kwargs.get('call_unpacked', None),
                                           kwargs.get('ret_unpacked', None))
        model, inputs, outputs, call_unpacked, ret_unpacked = _
        kwargs['call_unpacked'] = call_unpacked
        kwargs['ret_unpacked'] = ret_unpacked

        # Gather all model kwargs (anything else passed in for kwargs is assumed to be a model kwarg)
        model_kwargs = kwargs.get('model_kwargs', {})
        for key in kwargs.keys() - self.model_fields.keys():
            model_kwargs[key] = kwargs.pop(key)
        kwargs['model_kwargs'] = model_kwargs

        # Gather data serializers from type checks (if not passed in as a kwarg)
        serializers = kwargs.get('serializers', {})  # directly passing serializers will override type checks
        for key in ComponentSerializers.__annotations__.keys():
            field = kwargs.get(key, None)
            if isinstance(field, dict):
                field_super = next(filter(lambda x: issubclass(x, Serializable),
                                          typing.get_args(self.model_fields[key].annotation)), None)
                field = field_super.from_dict(field) if field_super is not None else field
                kwargs[key] = field
            if not serializers.get(key, None):
                serializers[key] = type(field) if isinstance(field, Serializable) else (
                    type(self.model_fields[key].default))
        kwargs['serializers'] = serializers

        super().__init__(model=model, inputs=inputs, outputs=outputs, name=name, **kwargs)  # Runs pydantic validation

        # Set internal properties
        assert self.is_downward_closed(self.active_set.union(self.candidate_set))
        self.set_logger()

    @classmethod
    def _validate_model_signature(cls, model, args=(), inputs=None, outputs=None,
                                  call_unpacked=None, ret_unpacked=None):
        """Parse model signature and decide how the model expects to be called based on what input/output information
        is provided or inspected from the model signature.
        """
        if inputs is not None:
            inputs = cls._validate_variables(inputs)
        if outputs is not None:
            outputs = cls._validate_variables(outputs)
        model = cls._validate_model(model)

        # Default to `dict` (i.e. packed) model call/return signatures
        if call_unpacked is None:
            call_unpacked = False
        if ret_unpacked is None:
            ret_unpacked = False
        inputs_inspect, outputs_inspect = _inspect_function(model)
        call_unpacked = call_unpacked or (len(inputs_inspect) > 1)  # Assume multiple inputs require unpacking
        ret_unpacked = ret_unpacked or (len(outputs_inspect) > 1)  # Assume multiple outputs require unpacking

        # Extract inputs/outputs from args
        arg_inputs = ()
        arg_outputs = ()
        if len(args) > 0:
            if call_unpacked:
                if isinstance(args[0], dict | str | Variable):
                    arg_inputs = args[:len(inputs_inspect)]
                    arg_outputs = args[len(inputs_inspect):]
                else:
                    arg_inputs = args[0]
                    arg_outputs = args[1:]
            else:
                arg_inputs = args[0]  # Assume first arg is a single or list of inputs
                arg_outputs = args[1:]  # Assume rest are outputs

        # Resolve inputs
        inputs = inputs or []
        inputs = VariableList.merge(inputs, arg_inputs)
        if len(inputs) == 0:
            inputs = inputs_inspect
            call_unpacked = True
            if len(inputs) == 0:
                raise ValueError("Could not infer input variables from model signature. Either your model does not "
                                 "accept input arguments or an error occurred during inspection.\nPlease provide the "
                                 "inputs directly as `Component(inputs=[...])` or fix the model signature.")
        if call_unpacked:
            if not all([var == inputs_inspect[i] for i, var in enumerate(inputs)]):
                warnings.warn(f"Mismatch between provided inputs: {inputs.values()} and inputs inferred from "
                              f"model signature: {inputs_inspect}. This may cause unexpected results.")
        else:
            if len(inputs_inspect) > 1:
                warnings.warn(f"Model signature expects multiple input arguments: {inputs_inspect}. "
                              f"Please set `call_unpacked=True` to use this model signature for multiple "
                              f"inputs.\nOtherwise, move all inputs into a single `dict` argument and all "
                              f"extra arguments into the `model_kwargs` field.")

            # Can't assume unpacked for single input/output, so warn user if they may be trying to do so
            if len(inputs) == 1 and len(inputs_inspect) == 1 and str(inputs[0]) == str(inputs_inspect[0]):
                warnings.warn(f"Single input argument: {inputs[0]} provided to model with input signature: "
                              f"{inputs_inspect}.\nIf you intended to use a single input argument, set "
                              f"`call_unpacked=True` to use this model signature.\nOtherwise, the first input will "
                              f"be passed to your model as a `dict`.\nIf you are expecting a `dict` input already, "
                              f"change the name of the input to not exactly "
                              f"match {inputs_inspect} in order to silence this warning.")
        # Resolve outputs
        outputs = outputs or []
        outputs = VariableList.merge(outputs, *arg_outputs)
        if len(outputs) == 0:
            outputs = outputs_inspect
            ret_unpacked = True
            if len(outputs) == 0:
                raise ValueError("Could not infer output variables from model inspection. Either your model does not "
                                 "return outputs or an error occurred during inspection.\nPlease provide the "
                                 "outputs directly as `Component(outputs=[...])` or fix the model return values.")
        if ret_unpacked:
            if not all([var == outputs_inspect[i] for i, var in enumerate(outputs)]):
                warnings.warn(f"Mismatch between provided outputs: {outputs.values()} and outputs inferred "
                              f"from model: {outputs_inspect}. This may cause unexpected results.")
        else:
            if len(outputs_inspect) > 1:
                warnings.warn(f"Model expects multiple return values: {outputs_inspect}. Please set "
                              f"`ret_unpacked=True` to use this model signature for multiple outputs.\n"
                              f"Otherwise, move all outputs into a single `dict` return value.")

            if len(outputs) == 1 and len(outputs_inspect) == 1 and str(outputs[0]) == str(outputs_inspect[0]):
                warnings.warn(f"Single output: {outputs[0]} provided to model with single expected return: "
                              f"{outputs_inspect}.\nIf you intended to output a single return value, set "
                              f"`ret_unpacked=True` to use this model signature.\nOtherwise, the output should "
                              f"be returned from your model as a `dict`.\nIf you are returning a `dict` already, "
                              f"then change its name to not exactly match {outputs_inspect} in order to silence "
                              f"this warning.")
        return model, inputs, outputs, call_unpacked, ret_unpacked

    def __repr__(self):
        s = f'---- {self.name} ----\n'
        s += f'Inputs: {self.inputs}\n'
        s += f'Outputs: {self.outputs}\n'
        s += f'Model: {self.model}'
        return s

    def __str__(self):
        return self.__repr__()

    @field_validator('serializers')
    @classmethod
    def _validate_serializers(cls, serializers: ComponentSerializers) -> ComponentSerializers:
        """Make sure custom serializer object types are themselves serializable as `YamlSerializable`."""
        for key, serializer in serializers.items():
            if serializer is None:
                serializers[key] = None
            elif isinstance(serializer, YamlSerializable):
                serializers[key] = serializer
            elif isinstance(serializer, str):
                serializers[key] = YamlSerializable.deserialize(serializer)
            else:
                serializers[key] = YamlSerializable(obj=serializer)
        return serializers

    @field_validator('model')
    @classmethod
    def _validate_model(cls, model: str | Callable) -> Callable:
        """Expects model as a callable or a yaml !!python/name string representation."""
        if isinstance(model, str):
            return YamlSerializable.deserialize(model).obj
        else:
            return model

    @field_validator('inputs', 'outputs')
    @classmethod
    def _validate_variables(cls, variables: _VariableLike) -> VariableList:
        if isinstance(variables, VariableList):
            return variables
        else:
            return VariableList.deserialize(variables)

    @field_validator('model_fidelity', 'data_fidelity', 'surrogate_fidelity')
    @classmethod
    def _validate_indices(cls, multi_index) -> MultiIndex:
        return MultiIndex(multi_index)

    @field_validator('active_set', 'candidate_set')
    @classmethod
    def _validate_index_set(cls, index_set) -> IndexSet:
        return IndexSet.deserialize(index_set)

    @field_validator('misc_states', 'misc_costs', 'misc_coeff_train', 'misc_coeff_test')
    @classmethod
    def _validate_misc_tree(cls, misc_tree) -> MiscTree:
        return MiscTree.deserialize(misc_tree)

    @field_validator('model_costs')
    @classmethod
    def _validate_model_costs(cls, model_costs: dict) -> dict:
        return {MultiIndex(key): float(value) for key, value in model_costs.items()}

    @field_validator('model_evals')
    @classmethod
    def _validate_model_evals(cls, model_evals: dict) -> dict:
        return {MultiIndex(key): int(value) for key, value in model_evals.items()}

    @field_validator('model_kwargs', 'interpolator', 'training_data')
    @classmethod
    def _validate_arbitrary_serializable(cls, data: Any, info: ValidationInfo) -> Any:
        """Use the stored custom serialization classes to deserialize arbitrary objects."""
        serializer = info.data.get('serializers').get(info.field_name).obj
        if isinstance(data, Serializable):
            return data
        else:
            return serializer.deserialize(data)

    @property
    def xdim(self) -> int:
        return len(self.inputs)

    @property
    def ydim(self) -> int:
        return len(self.outputs)

    @property
    def max_alpha(self) -> MultiIndex:
        """The maximum model fidelity multi-index (alias for `model_fidelity`)."""
        return self.model_fidelity

    @property
    def max_beta(self) -> MultiIndex:
        """The maximum surrogate fidelity multi-index is a combination of training and interpolator indices."""
        return self.data_fidelity + self.surrogate_fidelity
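
    # e.g. (hypothetically) `data_fidelity=(2, 2)` and `surrogate_fidelity=(1,)` give `max_beta=(2, 2, 1)`,
    # since `max_beta` is the tuple concatenation of the two fidelity multi-indices.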

    @property
    def has_surrogate(self) -> bool:
        """The component has no surrogate model if there are no fidelity indices."""
        return (len(self.max_alpha) + len(self.max_beta)) > 0

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: logging.Logger):
        self._logger = logger

    def __eq__(self, other):
        if isinstance(other, Component):
            return (self.model.__code__.co_code == other.model.__code__.co_code and self.inputs == other.inputs
                    and self.outputs == other.outputs and self.name == other.name
                    and self.model_kwargs.data == other.model_kwargs.data
                    and self.model_fidelity == other.model_fidelity and self.max_beta == other.max_beta
                    and self.interpolator == other.interpolator
                    and self.active_set == other.active_set and self.candidate_set == other.candidate_set
                    and self.misc_states == other.misc_states and self.misc_costs == other.misc_costs
                    )
        else:
            return False

    def _neighbors(self, alpha: MultiIndex, beta: MultiIndex, active_set: IndexSet = None, forward: bool = True):
        """Get all possible forward or backward multi-index neighbors (distance of one unit vector away).

        :param alpha: the model fidelity index
        :param beta: the surrogate fidelity index
        :param active_set: the set of active multi-indices
        :param forward: whether to get forward or backward neighbors
        :returns: a set of multi-indices that are neighbors of the input multi-index pair `(alpha, beta)`
        """
        active_set = active_set or self.active_set
        ind = np.array(alpha + beta)
        max_ind = np.array(self.max_alpha + self.max_beta)
        new_candidates = IndexSet()
        for i in range(len(ind)):
            ind_new = ind.copy()
            ind_new[i] += 1 if forward else -1

            # Don't add if we surpass a refinement limit or lower bound
            if np.any(ind_new > max_ind) or np.any(ind_new < 0):
                continue

            # Add the new index if it maintains downward-closedness
            down_closed = True
            for j in range(len(ind)):
                ind_check = ind_new.copy()
                ind_check[j] -= 1
                if ind_check[j] >= 0:
                    tup_check = (MultiIndex(ind_check[:len(alpha)]), MultiIndex(ind_check[len(alpha):]))
                    if tup_check not in active_set and tup_check != (alpha, beta):
                        down_closed = False
                        break
            if down_closed:
                new_candidates.add((ind_new[:len(alpha)], ind_new[len(alpha):]))

        return new_candidates
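
    # Sketch with hypothetical limits `max_alpha=(1,)` and `max_beta=(1,)`: the forward neighbors of
    # `((0,), (0,))` are `((1,), (0,))` and `((0,), (1,))` -- each one unit step away and still downward-closed.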

    def _surrogate_outputs(self):
        """Helper function to get the names of the surrogate outputs (including latent variables)."""
        y_vars = []
        for var in self.outputs:
            if var.compression is not None:
                for i in range(var.compression.latent_size()):
                    y_vars.append(f'{var.name}{LATENT_STR_ID}{i}')
            else:
                y_vars.append(var.name)
        return y_vars

    def _match_index_set(self, index_set, misc_coeff):
        """Helper function to grab the correct data structures for the given index set and MISC coefficients."""
        if misc_coeff is None:
            match index_set:
                case 'train':
                    misc_coeff = self.misc_coeff_train
                case 'test':
                    misc_coeff = self.misc_coeff_test
                case other:
                    raise ValueError(f"Index set must be 'train' or 'test' if you do not provide `misc_coeff`. "
                                     f"{other} not recognized.")
        if isinstance(index_set, str):
            match index_set:
                case 'train':
                    index_set = self.active_set
                case 'test':
                    index_set = self.active_set.union(self.candidate_set)
                case other:
                    raise ValueError(f"Index set must be 'train' or 'test'. {other} not recognized.")

        return index_set, misc_coeff

    def cache(self, kind: list | Literal["training"] = "training"):
        """Cache data for quicker access. Only `"training"` is supported.

        :param kind: the type(s) of data to cache (only "training" is supported). This will cache the
                     surrogate training data with nans removed.
        """
        if not isinstance(kind, list):
            kind = [kind]

        if "training" in kind:
            self._cache.setdefault("training", {})
            y_vars = self._surrogate_outputs()
            for alpha, beta in self.active_set.union(self.candidate_set):
                self._cache["training"].setdefault(alpha, {})

                if beta not in self._cache["training"][alpha]:
                    self._cache["training"][alpha][beta] = self.training_data.get(alpha, beta[:len(self.data_fidelity)],
                                                                                  y_vars=y_vars, skip_nan=True)

    def clear_cache(self):
        """Clear cached data."""
        self._cache.clear()

    def get_training_data(self, alpha: Literal['best', 'worst'] | MultiIndex = 'best',
                          beta: Literal['best', 'worst'] | MultiIndex = 'best',
                          y_vars: list = None,
                          cached: bool = False) -> tuple[Dataset, Dataset]:
        """Get all training data for a given multi-index pair `(alpha, beta)`.

        :param alpha: the model fidelity index (defaults to the maximum available model fidelity)
        :param beta: the surrogate fidelity index (defaults to the maximum available surrogate fidelity)
        :param y_vars: the training data to return (defaults to all stored data)
        :param cached: if True, will get cached training data if available (this will ignore `y_vars` and
                       only grab whatever is in the cache, which is surrogate outputs only and no nans)
        :returns: `(xtrain, ytrain)` - the training data for the given multi-indices
        """
        # Find the best alpha
        if alpha == 'best':
            alpha_best = ()
            for a, _ in self.active_set.union(self.candidate_set):
                if sum(a) > sum(alpha_best):
                    alpha_best = a
            alpha = alpha_best
        elif alpha == 'worst':
            alpha = (0,) * len(self.max_alpha)

        # Find the best beta for the given alpha
        if beta == 'best':
            beta_best = ()
            for a, b in self.active_set.union(self.candidate_set):
                if a == alpha and sum(b) > sum(beta_best):
                    beta_best = b
            beta = beta_best
        elif beta == 'worst':
            beta = (0,) * len(self.max_beta)

        try:
            if cached and (data := self._cache.get("training", {}).get(alpha, {}).get(beta)) is not None:
                return data
            else:
                return self.training_data.get(alpha, beta[:len(self.data_fidelity)], y_vars=y_vars, skip_nan=True)
        except Exception as e:
            self.logger.error(f"Error getting training data for alpha={alpha}, beta={beta}.")
            raise e

    def call_model(self, inputs: dict | Dataset,
                   model_fidelity: Literal['best', 'worst'] | tuple | list = None,
                   output_path: str | Path = None,
                   executor: Executor = None,
                   track_costs: bool = False,
                   **kwds) -> Dataset:
        """Wrapper function for calling the underlying component model.

        This function formats the input data, calls the model, and processes the output data.
        It supports vectorized calls, parallel execution using an executor, or serial execution. These options are
        checked in that order, with the first available method used. Must set `Component.vectorized=True` if the
        model supports input arrays of the form `(N,)` or even arbitrary shape `(...,)`.

        !!! Warning "Parallel Execution"
            The underlying model must be defined in a global module scope if `pickle` is the serialization method for
            the provided `Executor`.

        !!! Note "Additional return values"
            The model can return additional items that are not part of `Component.outputs`. These items are returned
            as object arrays in the output `dict`. Two special return values are `model_cost` and `output_path`.
            Returning `model_cost` will store the computational cost of a single model evaluation (which is used by
            `amisc` adaptive surrogate training). Returning `output_path` will store the output file name if the model
            wrote any files to disk.

        !!! Note "Handling errors"
            If the underlying component model raises an exception, the error is stored in `output_dict['errors']` with
            the index of the input data that caused the error. The output data for that index is set to `np.nan`
            for each output variable.

        :param inputs: The input data for the model, formatted as a `dict` with a key for each input variable and
                       a corresponding value that is an array of the input data. If specified as a plain list, then the
                       order is assumed the same as `Component.inputs`.
        :param model_fidelity: Fidelity indices to tune the model fidelity (model must request this
                               in its keyword arguments).
        :param output_path: Directory to save model output files (model must request this in its keyword arguments).
        :param executor: Executor for parallel execution if the model is not vectorized (optional).
        :param track_costs: Whether to track the computational cost of each model evaluation.
        :param kwds: Additional keyword arguments to pass to the model (model must request these in its keyword args).
        :returns: The output data from the model, formatted as a `dict` with a key for each output variable and a
                  corresponding value that is an array of the output data.
        """

        # Format inputs to a common loop shape (fail if missing any)
        if len(inputs) == 0:
            return {}  # nothing to evaluate for empty inputs
        if isinstance(inputs, list | np.ndarray):
            inputs = np.atleast_1d(inputs)
            inputs = {var.name: inputs[..., i] for i, var in enumerate(self.inputs)}

        var_shape = {}
        for var in self.inputs:
            s = None
            if (arr := kwds.get(f'{var.name}{COORDS_STR_ID}')) is not None:
                if not np.issubdtype(arr.dtype, np.object_):  # if not object array, then it's a single coordinate set
                    s = arr.shape if len(arr.shape) == 1 else arr.shape[:-1]  # skip the coordinate dim (last axis)
            if var.compression is not None:
                for field in var.compression.fields:
                    var_shape[field] = s
            else:
                var_shape[var.name] = s
        inputs, loop_shape = format_inputs(inputs, var_shape=var_shape)

        N = int(np.prod(loop_shape))
        list_alpha = isinstance(model_fidelity, list | np.ndarray)
        alpha_requested = self.model_kwarg_requested('model_fidelity')
        for var in self.inputs:
            if var.compression is not None:
                for field in var.compression.fields:
                    if field not in inputs:
                        raise ValueError(f"Missing field '{field}' for input variable '{var}'.")
            elif var.name not in inputs:
                raise ValueError(f"Missing input variable '{var.name}'.")

        # Pass extra requested items to the model kwargs
        kwargs = copy.deepcopy(self.model_kwargs.data)
        if self.model_kwarg_requested('output_path'):
            kwargs['output_path'] = output_path
        if self.model_kwarg_requested('input_vars'):
            kwargs['input_vars'] = self.inputs
        if self.model_kwarg_requested('output_vars'):
            kwargs['output_vars'] = self.outputs
        if alpha_requested:
            if not list_alpha:
                model_fidelity = [model_fidelity] * N
            for i in range(N):
                if model_fidelity[i] == 'best':
                    model_fidelity[i] = self.max_alpha
                elif model_fidelity[i] == 'worst':
                    model_fidelity[i] = (0,) * len(self.model_fidelity)

        for k, v in kwds.items():
            if self.model_kwarg_requested(k):
                kwargs[k] = v

        # Compute model (vectorized, executor parallel, or serial)
        errors = {}
        if self.vectorized:
            if alpha_requested:
                kwargs['model_fidelity'] = np.atleast_1d(model_fidelity).reshape((N, -1))

            self._model_start_time = time.time()
            output_dict = self.model(*[inputs[var.name] for var in self.inputs], **kwargs) if self.call_unpacked \
                else self.model(inputs, **kwargs)
            self._model_end_time = time.time()

            if self.ret_unpacked:
                output_dict = (output_dict,) if not isinstance(output_dict, tuple) else output_dict
                output_dict = {out_var.name: output_dict[i] for i, out_var in enumerate(self.outputs)}
        else:
            self._model_start_time = time.time()
            if executor is None:  # Serial
                results = deque(maxlen=N)
                for i in range(N):
                    try:
                        if alpha_requested:
                            kwargs['model_fidelity'] = model_fidelity[i]
                        ret = self.model(*[{k: v[i] for k, v in inputs.items()}[var.name] for var in self.inputs],
                                         **kwargs) if self.call_unpacked else (
                            self.model({k: v[i] for k, v in inputs.items()}, **kwargs))
                        if self.ret_unpacked:
                            ret = (ret,) if not isinstance(ret, tuple) else ret
                            ret = {out_var.name: ret[i] for i, out_var in enumerate(self.outputs)}
                        results.append(ret)
                    except Exception:
                        results.append({'inputs': {k: v[i] for k, v in inputs.items()}, 'index': i,
                                        'model_kwargs': kwargs.copy(), 'error': traceback.format_exc()})
            else:  # Parallel
                results = deque(maxlen=N)
                futures = []
                for i in range(N):
                    if alpha_requested:
                        kwargs['model_fidelity'] = model_fidelity[i]
                    fs = executor.submit(self.model,
                                         *[{k: v[i] for k, v in inputs.items()}[var.name] for var in self.inputs],
                                         **kwargs) if self.call_unpacked else (
                        executor.submit(self.model, {k: v[i] for k, v in inputs.items()}, **kwargs))
                    futures.append(fs)
                wait(futures, timeout=None, return_when=ALL_COMPLETED)

                for i, fs in enumerate(futures):
                    try:
                        if alpha_requested:
                            kwargs['model_fidelity'] = model_fidelity[i]
                        ret = fs.result()
                        if self.ret_unpacked:
                            ret = (ret,) if not isinstance(ret, tuple) else ret
                            ret = {out_var.name: ret[i] for i, out_var in enumerate(self.outputs)}
                        results.append(ret)
                    except Exception:
                        results.append({'inputs': {k: v[i] for k, v in inputs.items()}, 'index': i,
                                        'model_kwargs': kwargs.copy(), 'error': traceback.format_exc()})
            self._model_end_time = time.time()

            # Collect parallel/serial results
            output_dict = {}
            for i in range(N):
                res = results.popleft()
                if 'error' in res:
                    errors[i] = res
                else:
                    for key, val in res.items():
                        # Save this component's variables
                        is_component_var = False
                        for var in self.outputs:
                            if var.compression is not None:  # field quantity return values (save as object arrays)
                                if key in var.compression.fields or key == f'{var}{COORDS_STR_ID}':
                                    if output_dict.get(key) is None:
                                        output_dict.setdefault(key, np.full((N,), None, dtype=object))
                                    output_dict[key][i] = np.atleast_1d(val)
                                    is_component_var = True
                                    break
                            elif key == var:
                                if output_dict.get(key) is None:
                                    _val = np.atleast_1d(val)
                                    _extra_shape = () if len(_val.shape) == 1 and _val.shape[0] == 1 else _val.shape
                                    output_dict.setdefault(key, np.full((N, *_extra_shape), np.nan))
                                output_dict[key][i, ...] = np.atleast_1d(val)
                                is_component_var = True
                                break

                        # Otherwise, save other objects
                        if not is_component_var:
                            # Save singleton numeric values as numeric arrays (model costs, etc.)
                            _val = np.atleast_1d(val)
                            if key == 'model_cost' or (np.issubdtype(_val.dtype, np.number)
                                                       and len(_val.shape) == 1 and _val.shape[0] == 1):
                                if output_dict.get(key) is None:
                                    output_dict.setdefault(key, np.full((N,), np.nan))
                                output_dict[key][i] = _val[0]
                            else:
                                # Otherwise save into a generic object array
                                if output_dict.get(key) is None:
                                    output_dict.setdefault(key, np.full((N,), None, dtype=object))
                                output_dict[key][i] = val

        # Save average model costs for each alpha fidelity
        if track_costs:
            if model_fidelity is not None and output_dict.get('model_cost') is not None:
                alpha_costs = {}
                for i, cost in enumerate(output_dict['model_cost']):
                    alpha_costs.setdefault(MultiIndex(model_fidelity[i]), [])
                    alpha_costs[MultiIndex(model_fidelity[i])].append(cost)
                for a, costs in alpha_costs.items():
                    self.model_evals.setdefault(a, 0)
                    self.model_costs.setdefault(a, 0.0)
                    num_evals_prev = self.model_evals.get(a)
                    num_evals_new = len(costs)
                    prev_avg = self.model_costs.get(a)
                    costs = np.nan_to_num(costs, nan=prev_avg)
                    new_avg = (np.sum(costs) + prev_avg * num_evals_prev) / (num_evals_prev + num_evals_new)
                    self.model_evals[a] += num_evals_new
                    self.model_costs[a] = float(new_avg)

        # Reshape loop dimensions to match the original input shape
        output_dict = format_outputs(output_dict, loop_shape)

        for var in self.outputs:
            if var.compression is not None:
                for field in var.compression.fields:
                    if field not in output_dict:
                        self.logger.warning(f"Model return missing field '{field}' for output variable '{var}'. "
                                            f"This may indicate an error during model evaluation. Returning NaNs...")
                        output_dict.setdefault(field, np.full((N,), np.nan))
            elif var.name not in output_dict:
                self.logger.warning(f"Model return missing output variable '{var.name}'. This may indicate "
                                    f"an error during model evaluation. Returning NaNs...")
                output_dict[var.name] = np.full((N,), np.nan)

        # Return the output dictionary and any errors
        if errors:
            output_dict['errors'] = errors
        return output_dict

    def predict(self, inputs: dict | Dataset,
                use_model: Literal['best', 'worst'] | tuple = None,
                model_dir: str | Path = None,
                index_set: Literal['train', 'test'] | IndexSet = 'test',
                misc_coeff: MiscTree = None,
                incremental: bool = False,
                executor: Executor = None,
                **kwds) -> Dataset:
        """Evaluate the MISC surrogate approximation at new inputs `x`.

        !!! Note "Using the underlying model"
            By default this will predict the MISC surrogate approximation; all inputs are assumed to be in a compressed
            and normalized form. If the component does not have a surrogate (i.e. it is analytical), then the inputs
            will be converted to model form and the underlying model will be called in place. If you instead want to
            override the surrogate, passing `use_model` will call the underlying model directly. In that case, the
            inputs should be passed in already in model form (i.e. full fields, denormalized).

        :param inputs: `dict` of input arrays for each variable input
        :param use_model: 'best'=high-fidelity, 'worst'=low-fidelity, tuple=a specific `alpha`, None=surrogate (default)
        :param model_dir: directory to save output files if `use_model` is specified, ignored otherwise
        :param index_set: the active index set, defaults to `self.active_set` if `'train'` or both
                          `self.active_set + self.candidate_set` if `'test'`
        :param misc_coeff: the data structure holding the MISC coefficients to use, which defaults to the
                           training or testing coefficients depending on the `index_set` parameter.
        :param incremental: a special flag to use if the provided `index_set` is an incremental update to the active
                            index set. A temporary copy of the internal `misc_coeff` data structure will be updated
                            and used to incorporate the new indices.
        :param executor: executor for parallel execution if the model is not vectorized (optional), will use the
                         executor for looping over MISC coefficients if evaluating the surrogate rather than the model
        :param kwds: additional keyword arguments to pass to the model (if using the underlying model)
        :returns: the surrogate approximation of the model (or the model return itself if `use_model`)
        """

        # Use raw model inputs/outputs
        if use_model is not None:
            outputs = self.call_model(inputs, model_fidelity=use_model, output_path=model_dir, executor=executor,
                                      **kwds)
            return {str(var): outputs[var] for var in outputs}

        # Convert inputs/outputs to/from model if no surrogate (i.e. analytical models)
        if not self.has_surrogate:
            field_coords = {f'{var}{COORDS_STR_ID}':
                            self.model_kwargs.get(f'{var}{COORDS_STR_ID}', kwds.get(f'{var}{COORDS_STR_ID}', None))
                            for var in self.inputs}
            inputs, field_coords = to_model_dataset(inputs, self.inputs, del_latent=True, **field_coords)
            field_coords.update(kwds)
            outputs = self.call_model(inputs, model_fidelity=use_model or 'best', output_path=model_dir,
                                      executor=executor, **field_coords)
            outputs, _ = to_surrogate_dataset(outputs, self.outputs, del_fields=True, **field_coords)
            return {str(var): outputs[var] for var in outputs}

        # Choose the correct index set and misc_coeff data structures
        if incremental:
            misc_coeff = copy.deepcopy(self.misc_coeff_train)
            self.update_misc_coeff(index_set, self.active_set, misc_coeff)
            index_set = self.active_set.union(index_set)
        else:
            index_set, misc_coeff = self._match_index_set(index_set, misc_coeff)

        # Format inputs for surrogate prediction (all scalars at this point, including latent coeffs)
        inputs, loop_shape = format_inputs(inputs)  # {'x': (N,)}
        outputs = {}

        # Handle prediction with empty active set (return nan)
        if len(index_set) == 0:
            self.logger.warning(f"Component '{self.name}' has an empty active set. "
                                f"Has the surrogate been trained yet? Returning NaNs...")
            for var in self.outputs:
                outputs[var.name] = np.full(loop_shape, np.nan)
            return outputs

        y_vars = self._surrogate_outputs()  # Only request this component's specified outputs (ignore all extras)

        # Combination technique MISC surrogate prediction
        results = []
        coeffs = []
        for alpha, beta in index_set:
            comb_coeff = misc_coeff[alpha, beta]
            if np.abs(comb_coeff) > 0:
                coeffs.append(comb_coeff)
                args = (self.misc_states.get((alpha, beta)),
                        self.get_training_data(alpha, beta, y_vars=y_vars, cached=True))

                results.append(self.interpolator.predict(inputs, *args) if executor is None else
                               executor.submit(self.interpolator.predict, inputs, *args))

        if executor is not None:
            wait(results, timeout=None, return_when=ALL_COMPLETED)
            results = [future.result() for future in results]

        for coeff, interp_pred in zip(coeffs, results):
            for var, arr in interp_pred.items():
                if outputs.get(var) is None:
                    outputs[str(var)] = coeff * arr
                else:
                    outputs[str(var)] += coeff * arr

        return format_outputs(outputs, loop_shape)

    def update_misc_coeff(self, new_indices: IndexSet, index_set: Literal['test', 'train'] | IndexSet = 'train',
                          misc_coeff: MiscTree = None):
        """Update MISC coefficients incrementally resulting from the addition of new indices to an index set.

        !!! Warning "Incremental updates"
            This function is used to update the MISC coefficients stored in `misc_coeff` after adding new indices
            to the given `index_set`. If a custom `index_set` or `misc_coeff` are provided, the user is responsible
            for ensuring the data structures are consistent. Since this is an incremental update, this means all
            existing coefficients for every index in `index_set` should be precomputed and stored in `misc_coeff`.

        :param new_indices: a set of $(\\alpha, \\beta)$ tuples that are being added to the `index_set`
        :param index_set: the active index set, defaults to `self.active_set` if `'train'` or both
                          `self.active_set + self.candidate_set` if `'test'`
        :param misc_coeff: the data structure holding the MISC coefficients to update, which defaults to the
                           training or testing coefficients depending on the `index_set` parameter. This data structure
                           is modified in place.
        """
        index_set, misc_coeff = self._match_index_set(index_set, misc_coeff)

        for new_alpha, new_beta in new_indices:
            new_ind = np.array(new_alpha + new_beta)

            # Update all existing/new coefficients if they are a distance of [0, 1] "below" the new index
            # Note that new indices can only be [0, 1] away from themselves -- not any other new indices
            for old_alpha, old_beta in itertools.chain(index_set, [(new_alpha, new_beta)]):
                old_ind = np.array(old_alpha + old_beta)
                diff = new_ind - old_ind
                if np.all(np.isin(diff, [0, 1])):
                    if misc_coeff.get((old_alpha, old_beta)) is None:
                        misc_coeff[old_alpha, old_beta] = 0
                    misc_coeff[old_alpha, old_beta] += (-1) ** int(np.sum(np.abs(diff)))
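
    # Sketch of the update rule: for each existing index `e` with `d = new - e` and every entry of `d` in {0, 1},
    # the combination-technique coefficient of `e` gains `(-1) ** sum(d)`. E.g. (1D alpha and beta, hypothetical):
    # adding `new = ((1,), (1,))` to {((0,), (0,)), ((0,), (1,)), ((1,), (0,))} adds +1 to the new index and to
    # ((0,), (0,)), and -1 to each of ((0,), (1,)) and ((1,), (0,)).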

    def activate_index(self, alpha: MultiIndex, beta: MultiIndex, model_dir: str | Path = None,
                       executor: Executor = None, weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'):
        """Add a multi-index to the active set and all neighbors to the candidate set.

        !!! Warning
            The user of this function is responsible for ensuring that the index set maintains downward-closedness.
            That is, only activate indices that are neighbors of the current active set.

        :param alpha: A multi-index specifying model fidelity
        :param beta: A multi-index specifying surrogate fidelity
        :param model_dir: Directory to save model output files
        :param executor: Executor for parallel execution of model on training data if the model is not vectorized
        :param weight_fcns: Dictionary of weight functions for each input variable (defaults to the variable PDFs);
                            each function should be callable as `fcn(x: np.ndarray) -> np.ndarray`, where the input
                            is an array of normalized input data and the output is an array of weights. If None, then
                            no weighting is applied.
        """

1185 if (alpha, beta) in self.active_set: 

1186 self.logger.warning(f'Multi-index {(alpha, beta)} is already in the active index set. Ignoring...') 

1187 return 

1188 if (alpha, beta) not in self.candidate_set and (sum(alpha) + sum(beta)) > 0: 

1189 # Can only activate the initial index (0, 0, ... 0) without it being in the candidate set 

1190 self.logger.warning(f'Multi-index {(alpha, beta)} is not a neighbor of the active index set, so it ' 

1191 f'cannot be activated. Please only add multi-indices from the candidate set. ' 

1192 f'Ignoring...') 

1193 return 

1194 

1195 # Collect all neighbor candidate indices; sort by largest model cost first 

1196 neighbors = self._neighbors(alpha, beta, forward=True) 

1197 indices = list(itertools.chain([(alpha, beta)] if (alpha, beta) not in self.candidate_set else [], neighbors)) 

1198 indices.sort(key=lambda ele: self.model_costs.get(ele[0], sum(ele[0])), reverse=True) 

1199 

1200 # Refine and collect all new model inputs (i.e. training points) requested by the new candidates 

1201 alpha_list = [] # keep track of model fidelities 

1202 design_list = [] # keep track of training data coordinates/locations/indices 

1203 model_inputs = {} # concatenate all model inputs 

1204 field_coords = {f'{var}{COORDS_STR_ID}': self.model_kwargs.get(f'{var}{COORDS_STR_ID}', None) 

1205 for var in self.inputs} 

1206 domains = self.inputs.get_domains() 

1207 

1208 if weight_fcns == 'pdf': 

1209 weight_fcns = self.inputs.get_pdfs() 

1210 

1211 for a, b in indices: 

1212 if ((a, b[:len(self.data_fidelity)] + (0,) * len(self.surrogate_fidelity)) in 

1213 self.active_set.union(self.candidate_set)): 

1214 # Don't refine training data if only updating surrogate fidelity indices 

1215 # Training data is the same for all surrogate fidelity indices, given constant data fidelity 

1216 design_list.append([]) 

1217 continue 

1218 

1219 design_coords, design_pts = self.training_data.refine(a, b[:len(self.data_fidelity)], 

1220 domains, weight_fcns) 

1221 design_pts, fc = to_model_dataset(design_pts, self.inputs, del_latent=True, **field_coords) 

1222 

1223 # Remove duplicate (alpha, coords) pairs -- so you don't evaluate the model twice for the same input 

1224 i = 0 

1225 del_idx = [] 

1226 for other_design in design_list: 

1227 for other_coord in other_design: 

1228 for j, curr_coord in enumerate(design_coords): 

1229 if curr_coord == other_coord and a == alpha_list[i] and j not in del_idx: 

1230 del_idx.append(j) 

1231 i += 1 

1232 design_coords = [design_coords[j] for j in range(len(design_coords)) if j not in del_idx] 

1233 design_pts = {var: np.delete(arr, del_idx, axis=0) for var, arr in design_pts.items()} 

1234 

1235 alpha_list.extend([tuple(a)] * len(design_coords)) 

1236 design_list.append(design_coords) 

1237 field_coords.update(fc) 

1238 for var in design_pts: 

1239 model_inputs[var] = design_pts[var] if model_inputs.get(var) is None else ( 

1240 np.concatenate((model_inputs[var], design_pts[var]), axis=0)) 

1241 

1242 # Evaluate model at designed training points 

1243 if len(alpha_list) > 0: 

1244 self.logger.info(f"Running {len(alpha_list)} total model evaluations for component "

1245 f"'{self.name}' at new candidate indices: {indices}...")

1246 model_outputs = self.call_model(model_inputs, model_fidelity=alpha_list, output_path=model_dir, 

1247 executor=executor, track_costs=True, **field_coords) 

1248 self.logger.info(f"Model evaluations complete for component '{self.name}'.") 

1249 errors = model_outputs.pop('errors', {}) 

1250 else: 

1251 self._model_start_time = -1.0 

1252 self._model_end_time = -1.0 

1253 

1254 # Unpack model outputs and update states 

1255 start_idx = 0 

1256 for i, (a, b) in enumerate(indices): 

1257 num_train_pts = len(design_list[i]) 

1258 end_idx = start_idx + num_train_pts # Ensure loop dim of 1 gets its own axis (might have been squeezed) 

1259 

1260 if num_train_pts > 0: 

1261 yi_dict = {var: arr[np.newaxis, ...] if len(alpha_list) == 1 and arr.shape[0] != 1 else 

1262 arr[start_idx:end_idx, ...] for var, arr in model_outputs.items()} 

1263 

1264 # Check for errors and store 

1265 err_coords = [] 

1266 err_list = [] 

1267 for idx in list(errors.keys()): 

1268 if idx < end_idx: 

1269 err_info = errors.pop(idx) 

1270 err_info['index'] = idx - start_idx 

1271 err_coords.append(design_list[i][idx - start_idx]) 

1272 err_list.append(err_info) 

1273 if len(err_list) > 0: 

1274 self.logger.warning(f"Model errors occurred while adding candidate ({a}, {b}) for component "

1275 f"'{self.name}'. Leaving NaN values in training data...")

1276 self.training_data.set_errors(a, b[:len(self.data_fidelity)], err_coords, err_list) 

1277 

1278 # Compress field quantities and normalize 

1279 yi_dict, y_vars = to_surrogate_dataset(yi_dict, self.outputs, del_fields=False, **field_coords) 

1280 

1281 # Store training data, computational cost, and new interpolator state 

1282 self.training_data.set(a, b[:len(self.data_fidelity)], design_list[i], yi_dict) 

1283 self.training_data.impute_missing_data(a, b[:len(self.data_fidelity)]) 

1284 

1285 else: 

1286 y_vars = self._surrogate_outputs() 

1287 

1288 self.misc_costs[a, b] = num_train_pts 

1289 self.misc_states[a, b] = self.interpolator.refine(b[len(self.data_fidelity):], 

1290 self.training_data.get(a, b[:len(self.data_fidelity)], 

1291 y_vars=y_vars, skip_nan=True), 

1292 self.misc_states.get((alpha, beta)), 

1293 domains) 

1294 start_idx = end_idx 

1295 

1296 # Move to the active index set 

1297 s = {(alpha, beta)}

1299 self.update_misc_coeff(IndexSet(s), index_set='train') 

1300 if (alpha, beta) in self.candidate_set: 

1301 self.candidate_set.remove((alpha, beta)) 

1302 else: 

1303 # Only for initial index which didn't come from the candidate set 

1304 self.update_misc_coeff(IndexSet(s), index_set='test') 

1305 self.active_set.update(s) 

1306 

1307 self.update_misc_coeff(neighbors, index_set='test') # neighbors will only ever pass through here once 

1308 self.candidate_set.update(neighbors) 

1309 

1310 def gradient(self, inputs: dict | Dataset, 

1311 index_set: Literal['train', 'test'] | IndexSet = 'test', 

1312 misc_coeff: MiscTree = None, 

1313 derivative: Literal['first', 'second'] = 'first', 

1314 executor: Executor = None) -> Dataset: 

1315 """Evaluate the Jacobian or Hessian of the MISC surrogate approximation at new `inputs`, i.e. 

1316 the first or second derivatives, respectively. 

1317 

1318 :param inputs: `dict` of input arrays for each variable input 

1319 :param index_set: the index set to use for the prediction; defaults to `self.active_set` if `'train'`

1320 or `self.active_set + self.candidate_set` if `'test'`

1321 :param misc_coeff: the data structure holding the MISC coefficients to use, which defaults to the 

1322 training or testing coefficients depending on the `index_set` parameter. 

1323 :param derivative: whether to compute the first or second derivative (i.e. Jacobian or Hessian) 

1324 :param executor: executor for looping over MISC coefficients (optional) 

1325 :returns: a `dict` of the Jacobian or Hessian of the surrogate approximation for each output variable 
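
!!! Example
    A minimal sketch for a hypothetical trained component `comp` with a scalar input `x`:

    ```python
    jac = comp.gradient({'x': np.linspace(0, 1, 50)})                         # Jacobian (first derivative)
    hess = comp.gradient({'x': np.linspace(0, 1, 50)}, derivative='second')   # Hessian (second derivative)
    ```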

1326 """ 

1327 if not self.has_surrogate: 

1328 self.logger.warning("No surrogate model available for gradient computation.") 

1329 return None 

1330 

1331 index_set, misc_coeff = self._match_index_set(index_set, misc_coeff) 

1332 inputs, loop_shape = format_inputs(inputs) # {'x': (N,)} 

1333 outputs = {} 

1334 

1335 if len(index_set) == 0: 

1336 for var in self.outputs: 

1337 outputs[var] = np.full(loop_shape, np.nan) 

1338 return outputs 

1339 y_vars = self._surrogate_outputs() 

1340 

1341 # Combination technique MISC gradient prediction 

1342 results = [] 

1343 coeffs = [] 

1344 for alpha, beta in index_set: 

1345 comb_coeff = misc_coeff[alpha, beta] 

1346 if np.abs(comb_coeff) > 0: 

1347 coeffs.append(comb_coeff) 

1348 func = self.interpolator.gradient if derivative == 'first' else self.interpolator.hessian 

1349 args = (self.misc_states.get((alpha, beta)), 

1350 self.get_training_data(alpha, beta, y_vars=y_vars, cached=True)) 

1351 

1352 results.append(func(inputs, *args) if executor is None else executor.submit(func, inputs, *args)) 

1353 

1354 if executor is not None: 

1355 wait(results, timeout=None, return_when=ALL_COMPLETED) 

1356 results = [future.result() for future in results] 

1357 

1358 for coeff, interp_pred in zip(coeffs, results): 

1359 for var, arr in interp_pred.items(): 

1360 if outputs.get(str(var)) is None:

1361 outputs[str(var)] = coeff * arr

1362 else:

1363 outputs[str(var)] += coeff * arr

1364 

1365 return format_outputs(outputs, loop_shape) 

1366 

1367 def hessian(self, *args, **kwargs): 

1368 """Alias for `Component.gradient(*args, derivative='second', **kwargs)`.""" 

1369 return self.gradient(*args, derivative='second', **kwargs) 

1370 

1371 def model_kwarg_requested(self, kwarg_name: str) -> bool: 

1372 """Return whether the underlying component model requested this `kwarg_name`. Special kwargs include: 

1373 

1374 - `output_path` — a save directory created by `amisc` will be passed to the model for saving model output files. 

1375 - `alpha` — a tuple or list of model fidelity indices will be passed to the model to adjust fidelity. 

1376 - `input_vars` — a list of `Variable` objects will be passed to the model for input variable information. 

1377 - `output_vars` — a list of `Variable` objects will be passed to the model for output variable information. 

1378 

1379 :param kwarg_name: the argument to check for in the underlying component model's function signature kwargs 

1380 :returns: whether the component model requests this `kwarg` argument 
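
!!! Example
    A minimal sketch with a hypothetical model function (construction details omitted):

    ```python
    def my_model(inputs, output_path=None):
        ...

    comp = Component(my_model, inputs=..., outputs=...)
    comp.model_kwarg_requested('output_path')  # True  -- keyword argument with a default
    comp.model_kwarg_requested('executor')     # False -- not in the signature
    ```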

1381 """ 

1382 signature = inspect.signature(self.model) 

1383 for param in signature.parameters.values(): 

1384 if param.name == kwarg_name and param.default is not param.empty:

1385 return True 

1386 return False 

1387 

1388 def set_logger(self, log_file: str | Path = None, stdout: bool = None, logger: logging.Logger = None, 

1389 level: int = logging.INFO): 

1390 """Set a new `logging.Logger` object. 

1391 

1392 :param log_file: log to file (if provided) 

1393 :param stdout: whether to connect the logger to console (defaults to whatever is currently set or False) 

1394 :param logger: the logging object to use (if None, then a new logger is created; this will override 

1395 the `log_file` and `stdout` arguments if set) 

1396 :param level: the logging level to set (default is `logging.INFO`) 
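
!!! Example
    ```python
    # log to a file and the console at DEBUG level (illustrative file name)
    comp.set_logger(log_file='component.log', stdout=True, level=logging.DEBUG)
    ```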

1397 """ 

1398 if stdout is None: 

1399 stdout = False 

1400 if self._logger is not None: 

1401 for handler in self._logger.handlers: 

1402 if isinstance(handler, logging.StreamHandler): 

1403 stdout = True 

1404 break 

1405 self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level) 

1406 

1407 def update_model(self, new_model: callable = None, model_kwargs: dict = None, **kwargs): 

1408 """Update the underlying component model or its kwargs.""" 
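# Illustrative usage (hypothetical `my_model` and kwargs):
#     comp.update_model(new_model=my_model, model_kwargs={'tol': 1e-6})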

1409 if new_model is not None: 

1410 self.model = new_model 

1411 new_kwargs = self.model_kwargs.data 

1412 new_kwargs.update(model_kwargs or {}) 

1413 new_kwargs.update(kwargs) 

1414 self.model_kwargs = new_kwargs 

1415 

1416 def get_cost(self, alpha: MultiIndex, beta: MultiIndex) -> int: 

1417 """Return the total cost (i.e. number of model evaluations) required to add $(\\alpha, \\beta)$ to the 

1418 MISC approximation. 

1419 

1420 :param alpha: A multi-index specifying model fidelity 

1421 :param beta: A multi-index specifying surrogate fidelity 

1422 :returns: the total number of model evaluations required for adding this multi-index to the MISC approximation 

1423 """ 

1424 try: 

1425 return self.misc_costs[alpha, beta] 

1426 except KeyError:

1427 return 0 

1428 

1429 def get_model_timestamps(self): 

1430 """Return a tuple with the (start, end) timestamps for the most recent call to `call_model`. This

1431 is useful for tracking the duration of model evaluations. Returns `(None, None)` if no model has been called.
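
!!! Example
    A minimal sketch for timing the most recent model evaluation of a hypothetical component `comp`:

    ```python
    start, end = comp.get_model_timestamps()
    duration = end - start if start is not None else None  # elapsed wall time, assuming `time.time()` timestamps
    ```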

1432 """ 

1433 if self._model_start_time < 0 or self._model_end_time < 0: 

1434 return None, None 

1435 else: 

1436 return self._model_start_time, self._model_end_time 

1437 

1438 @staticmethod 

1439 def is_downward_closed(indices: IndexSet) -> bool: 

1440 """Return whether a list of $(\\alpha, \\beta)$ multi-indices is downward-closed.

1441 

1442 MISC approximations require a downward-closed set in order to use the combination-technique formula for the 

1443 coefficients (as implemented by `Component.update_misc_coeff()`). 

1444 

1445 !!! Example 

1446 The list `[( (0,), (0,) ), ( (0,), (1,) ), ( (1,), (0,) )]` is downward-closed. You can visualize this as

1447 building a stack of cubes: in order to place a cube, all adjacent cubes must be present (does the logo 

1448 make sense now?). 
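
    As a quick sketch (constructing an `IndexSet` from a plain set, as done elsewhere in this class):

    ```python
    s = IndexSet({((0,), (0,)), ((0,), (1,)), ((1,), (0,))})
    Component.is_downward_closed(s)                           # True
    Component.is_downward_closed(IndexSet({((1,), (0,))}))    # False: ((0,), (0,)) is missing
    ```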

1449 

1450 :param indices: `IndexSet` of (`alpha`, `beta`) multi-indices 

1451 :returns: whether the set of indices is downward-closed 

1452 """ 

1453 # Iterate over every multi-index 

1454 for alpha, beta in indices: 

1455 # Every smaller multi-index must also be included in the indices list 

1456 sub_sets = [np.arange(tuple(alpha + beta)[i] + 1) for i in range(len(alpha) + len(beta))] 

1457 for ele in itertools.product(*sub_sets): 

1458 tup = (MultiIndex(ele[:len(alpha)]), MultiIndex(ele[len(alpha):])) 

1459 if tup not in indices: 

1460 return False 

1461 return True 

1462 

1463 def clear(self): 

1464 """Clear the component of all training data, index sets, and MISC states.""" 

1465 self.active_set.clear() 

1466 self.candidate_set.clear() 

1467 self.misc_states.clear() 

1468 self.misc_costs.clear() 

1469 self.misc_coeff_train.clear() 

1470 self.misc_coeff_test.clear() 

1471 self.model_costs.clear() 

1472 self.model_evals.clear() 

1473 self.training_data.clear() 

1474 self._model_start_time = -1.0 

1475 self._model_end_time = -1.0 

1476 self.clear_cache() 

1477 

1478 def serialize(self, keep_yaml_objects: bool = False, serialize_args: dict[str, tuple] = None, 

1479 serialize_kwargs: dict[str, dict] = None) -> dict:

1480 """Convert to a `dict` with only standard Python types as fields and values. 

1481 

1482 :param keep_yaml_objects: whether to keep `Variable` and other yaml-serializable objects as-is,

1483 rather than serializing them too (default is False)

1484 :param serialize_args: additional arguments to pass to the `serialize` method of each `Component` attribute; 

1485 specify as a `dict` of attribute names to tuple of arguments to pass 

1486 :param serialize_kwargs: additional keyword arguments to pass to the `serialize` method of each 

1487 `Component` attribute 

1488 :returns: a `dict` representation of the `Component` object 
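
!!! Example
    A minimal round-trip sketch for an already-constructed component `comp`:

    ```python
    d = comp.serialize()              # plain-Python dict representation
    comp2 = Component.deserialize(d)  # reconstruct an equivalent component
    ```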

1489 """ 

1490 serialize_args = serialize_args or dict() 

1491 serialize_kwargs = serialize_kwargs or dict() 

1492 d = {} 

1493 for key, value in self.__dict__.items(): 

1494 if value is not None and not key.startswith('_'): 

1495 if key == 'serializers': 

1496 # Update the serializers 

1497 serializers = self._validate_serializers({k: type(getattr(self, k)) for k in value.keys()}) 

1498 d[key] = {k: (v.obj if keep_yaml_objects else v.serialize()) for k, v in serializers.items()} 

1499 elif key in ['inputs', 'outputs'] and not keep_yaml_objects: 

1500 d[key] = value.serialize(**serialize_kwargs.get(key, {})) 

1501 elif key == 'model' and not keep_yaml_objects: 

1502 d[key] = YamlSerializable(obj=value).serialize() 

1503 elif key in ['data_fidelity', 'surrogate_fidelity', 'model_fidelity']: 

1504 if len(value) > 0: 

1505 d[key] = str(value) 

1506 elif key in ['active_set', 'candidate_set']: 

1507 if len(value) > 0: 

1508 d[key] = value.serialize() 

1509 elif key in ['misc_costs', 'misc_coeff_train', 'misc_coeff_test', 'misc_states']: 

1510 if len(value) > 0: 

1511 d[key] = value.serialize(keep_yaml_objects=keep_yaml_objects) 

1512 elif key in ['model_costs']: 

1513 if len(value) > 0: 

1514 d[key] = {str(k): float(v) for k, v in value.items()} 

1515 elif key in ['model_evals']: 

1516 if len(value) > 0: 

1517 d[key] = {str(k): int(v) for k, v in value.items()} 

1518 elif key in ComponentSerializers.__annotations__.keys(): 

1519 if key in ['training_data'] and not self.has_surrogate: 

1520 continue 

1521 else: 

1522 d[key] = value.serialize(*serialize_args.get(key, ()), **serialize_kwargs.get(key, {})) 

1523 else: 

1524 d[key] = value 

1525 return d 

1526 

1527 @classmethod 

1528 def deserialize(cls, serialized_data: dict, search_paths: list[str | Path] = None, 

1529 search_keys: list[str] = None) -> Component: 

1530 """Return a `Component` from `serialized_data`. Let pydantic handle field validation and conversion. If any component

1531 data has been saved to file and the save file doesn't exist, then the loader will search for the file 

1532 in the current working directory and any additional search paths provided. 

1533 

1534 :param serialized_data: the serialized data to construct the object from 

1535 :param search_paths: paths to try and find any save files (i.e. if they moved since they were serialized), 

1536 will always search in the current working directory by default 

1537 :param search_keys: keys to search for save files in each component (default is all keys in 

1538 [`ComponentSerializers`][amisc.component.ComponentSerializers], in addition to variable 

1539 inputs and outputs) 
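
!!! Example
    Deserializing directly from a raw callable (a hypothetical model function):

    ```python
    def my_model(inputs):
        ...

    comp = Component.deserialize(my_model)  # data fidelity defaults to (2,) per inspected input
    ```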

1540 """ 

1541 if isinstance(serialized_data, Component): 

1542 return serialized_data 

1543 elif callable(serialized_data): 

1544 # try to construct a component from a raw function (assume data fidelity is (2,) for each inspected input) 

1545 return cls(serialized_data, data_fidelity=(2,) * len(_inspect_function(serialized_data)[0])) 

1546 

1547 search_paths = search_paths or [] 

1548 search_keys = search_keys or [] 

1549 search_keys.extend(ComponentSerializers.__annotations__.keys()) 

1550 comp = serialized_data 

1551 

1552 for key in search_keys: 

1553 if (filename := comp.get(key, None)) is not None: 

1554 comp[key] = search_for_file(filename, search_paths=search_paths)  # returns the original string if not found

1555 

1556 for key in ['inputs', 'outputs']: 

1557 for var in comp.get(key, []): 

1558 if isinstance(var, dict): 

1559 if (compression := var.get('compression', None)) is not None: 

1560 var['compression'] = search_for_file(compression, search_paths=search_paths) 

1561 

1562 return cls(**comp) 

1563 

1564 @staticmethod 

1565 def _yaml_representer(dumper: yaml.Dumper, comp: Component) -> yaml.MappingNode: 

1566 """Convert a single `Component` object (`comp`) to a yaml MappingNode (i.e. a `dict`)."""

1567 save_path, save_file = _get_yaml_path(dumper) 

1568 serialize_kwargs = {} 

1569 for key, serializer in comp.serializers.items(): 

1570 if issubclass(serializer.obj, PickleSerializable): 

1571 filename = f'{save_file}_{comp.name}_{key}.pkl'

1572 serialize_kwargs[key] = {'save_path': save_path / filename}

1573 return dumper.represent_mapping(Component.yaml_tag, comp.serialize(serialize_kwargs=serialize_kwargs, 

1574 keep_yaml_objects=True)) 

1575 

1576 @staticmethod 

1577 def _yaml_constructor(loader: yaml.Loader, node): 

1578 """Convert the `!Component` tag in yaml to a `Component` object.""" 

1579 # Add a file search path in the same directory as the yaml file being loaded from 

1580 save_path, save_file = _get_yaml_path(loader) 

1581 if isinstance(node, yaml.SequenceNode): 

1582 return [ele if isinstance(ele, Component) else Component.deserialize(ele, search_paths=[save_path]) 

1583 for ele in loader.construct_sequence(node, deep=True)] 

1584 elif isinstance(node, yaml.MappingNode): 

1585 return Component.deserialize(loader.construct_mapping(node, deep=True), search_paths=[save_path]) 

1586 else: 

1587 raise NotImplementedError(f'The "{Component.yaml_tag}" yaml tag can only be used on a yaml sequence or ' 

1588 f'mapping, not a "{type(node)}".')