"""A `Component` is an `amisc` wrapper around a single discipline model. It manages surrogate construction and
a hierarchy of modeling fidelities.

!!! Info "Multi-indices in the MISC approximation"
    A multi-index is a tuple of natural numbers, each specifying a level of fidelity. You will frequently see two
    multi-indices: `alpha` and `beta`. The `alpha` (or $\\alpha$) indices specify physical model fidelity and get
    passed to the model as an additional argument (e.g. things like discretization level, time step size, etc.).
    The `beta` (or $\\beta$) indices specify surrogate refinement level, typically an indication of the amount of
    training data used or the complexity of the surrogate model. We divide $\\beta$ into `data_fidelity` and
    `surrogate_fidelity` for specifying training data and surrogate model complexity, respectively.
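
!!! Example "Reading a multi-index pair (an illustrative sketch)"
    ```python
    alpha = (1,)   # e.g. mesh refinement level 1, passed through to the model
    beta = (2, 0)  # e.g. `data_fidelity` level 2 and `surrogate_fidelity` level 0
    ```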

Includes:

- `ModelKwargs` — a dataclass for storing model keyword arguments
- `StringKwargs` — a dataclass for storing model keyword arguments as a string
- `IndexSet` — a dataclass that maintains a set of multi-indices
- `MiscTree` — a dataclass that maintains MISC data in a `dict` tree, indexed by `alpha` and `beta`
- `Component` — a class that manages a single discipline model and its surrogate hierarchy
"""
from __future__ import annotations

import ast
import copy
import inspect
import itertools
import logging
import random
import string
import time
import traceback
import typing
import warnings
from collections import UserDict, deque
from concurrent.futures import ALL_COMPLETED, Executor, wait
from pathlib import Path
from typing import Any, Callable, ClassVar, Iterable, Literal, Optional

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
from typing_extensions import TypedDict

from amisc.interpolator import Interpolator, InterpolatorState, Lagrange
from amisc.serialize import PickleSerializable, Serializable, StringSerializable, YamlSerializable
from amisc.training import SparseGrid, TrainingData
from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset, MultiIndex
from amisc.utils import (
    _get_yaml_path,
    _inspect_assignment,
    _inspect_function,
    format_inputs,
    format_outputs,
    get_logger,
    search_for_file,
    to_model_dataset,
    to_surrogate_dataset,
)
from amisc.variable import Variable, VariableList

__all__ = ["ModelKwargs", "StringKwargs", "IndexSet", "MiscTree", "Component"]
_VariableLike = list[Variable | dict | str] | str | Variable | dict | VariableList  # Generic type for Variables


class ModelKwargs(UserDict, Serializable):
    """Default dataclass for storing model keyword arguments in a `dict`. If you have kwargs that require
    more complicated serialization/specification than a plain `dict`, then you can subclass from here.
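
    !!! Example "Basic usage (a sketch)"
        ```python
        kwargs = ModelKwargs.from_dict({'dt': 0.1, 'tol': 1e-6})  # no 'method' key -> plain `ModelKwargs`
        kwargs['dt']  # -> 0.1; behaves like a normal `dict`
        ```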
    """

    def serialize(self):
        return self.data

    @classmethod
    def deserialize(cls, serialized_data):
        return ModelKwargs(**serialized_data)

    @classmethod
    def from_dict(cls, config: dict) -> ModelKwargs:
        """Create a `ModelKwargs` object from a `dict` configuration."""
        method = config.pop('method', 'default_kwargs').lower()
        match method:
            case 'default_kwargs':
                return ModelKwargs(**config)
            case 'string_kwargs':
                return StringKwargs(**config)
            case other:
                config['method'] = other
                return ModelKwargs(**config)  # Pass the method through


class StringKwargs(StringSerializable, ModelKwargs):
    """Dataclass for storing model keyword arguments as a string."""
    def __repr__(self):
        return str(self.data)

    def __str__(self):
        def format_value(value):
            if isinstance(value, str):
                return f'"{value}"'
            else:
                return str(value)

        kw_str = ", ".join([f"{key}={format_value(value)}" for key, value in self.items()])
        return f"ModelKwargs({kw_str})"


class IndexSet(set, Serializable):
    """Dataclass that maintains a set of multi-indices. Overrides basic `set` functionality to ensure
    elements are formatted correctly as `(alpha, beta)`; that is, as a tuple of `alpha` and
    `beta`, which are themselves instances of a [`MultiIndex`][amisc.typing.MultiIndex] tuple.

    !!! Example "An example index set"
        $\\mathcal{I} = [(\\alpha, \\beta)_1, (\\alpha, \\beta)_2, (\\alpha, \\beta)_3, ...]$ would be specified
        as `I = [((0, 0), (0, 0, 0)), ((0, 1), (0, 1, 0)), ...]`.
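
    !!! Example "Adding indices (a sketch)"
        ```python
        s = IndexSet()
        s.add(((0, 0), (0, 0, 0)))         # tuples are validated into `MultiIndex` pairs
        s.update(["((0, 1), (0, 1, 0))"])  # strings are parsed via `ast.literal_eval`
        ```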
    """
    def __init__(self, s=()):
        s = [self._validate_element(ele) for ele in s]
        super().__init__(s)

    def __str__(self):
        return str(list(self))

    def __repr__(self):
        return self.__str__()

    def add(self, __element):
        super().add(self._validate_element(__element))

    def update(self, __elements):
        super().update([self._validate_element(ele) for ele in __elements])

    @classmethod
    def _validate_element(cls, element):
        """Validate that the element is a tuple of two multi-indices."""
        alpha, beta = ast.literal_eval(element) if isinstance(element, str) else tuple(element)
        return MultiIndex(alpha), MultiIndex(beta)

    @classmethod
    def _wrap_methods(cls, names):
        """Make sure set operations return an `IndexSet` object."""
        def wrap_method_closure(name):
            def inner(self, *args):
                result = getattr(super(cls, self), name)(*args)
                if isinstance(result, set):
                    result = cls(result)
                return result
            inner.fn_name = name
            setattr(cls, name, inner)

        for name in names:
            wrap_method_closure(name)

    def serialize(self) -> list[str]:
        """Return a list of each multi-index in the set serialized to a string."""
        return [str(ele) for ele in self]

    @classmethod
    def deserialize(cls, serialized_data: list[str]) -> IndexSet:
        """Deserialize a list of strings to an `IndexSet`."""
        return cls(serialized_data)


IndexSet._wrap_methods(['__ror__', 'difference_update', '__isub__', 'symmetric_difference', '__rsub__', '__and__',
                        '__rand__', 'intersection', 'difference', '__iand__', 'union', '__ixor__',
                        'symmetric_difference_update', '__or__', 'copy', '__rxor__', 'intersection_update', '__xor__',
                        '__ior__', '__sub__'
                        ])


class MiscTree(UserDict, Serializable):
    """Dataclass that maintains MISC data in a `dict` tree, indexed by `alpha` and `beta`. Overrides
    basic `dict` functionality to ensure elements are formatted correctly as `(alpha, beta) -> data`.
    Used to store MISC coefficients, model costs, and interpolator states.

    The underlying data structure is: `dict[MultiIndex, dict[MultiIndex, float | InterpolatorState]]`.
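
    !!! Example "Two-level indexing (a sketch)"
        ```python
        tree = MiscTree()
        tree[(0,), (1, 0)] = 1.5    # `tree[alpha, beta] = value` access
        value = tree[(0,), (1, 0)]  # -> 1.5
        ```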
    """
    SERIALIZER_KEY = 'state_serializer'

    def __init__(self, data: dict = None, **kwargs):
        data_dict = data or {}
        if isinstance(data_dict, MiscTree):
            data_dict = data_dict.data
        data_dict.update(kwargs)
        super().__init__(self._validate_data(data_dict))

    def serialize(self, *args, keep_yaml_objects=False, **kwargs) -> dict:
        """Serialize `alpha, beta` indices to string and return a `dict` of internal data.

        :param args: extra serialization arguments for internal `InterpolatorState`
        :param keep_yaml_objects: whether to keep `YamlSerializable` instances in the serialization
        :param kwargs: extra serialization keyword arguments for internal `InterpolatorState`
        """
        ret_dict = {}
        if state_serializer := self.state_serializer(self.data):
            ret_dict[self.SERIALIZER_KEY] = state_serializer.obj if keep_yaml_objects else state_serializer.serialize()
        for alpha, beta, data in self:
            ret_dict.setdefault(str(alpha), dict())
            serialized_data = data.serialize(*args, **kwargs) if isinstance(data, InterpolatorState) else float(data)
            ret_dict[str(alpha)][str(beta)] = serialized_data
        return ret_dict

    @classmethod
    def deserialize(cls, serialized_data: dict) -> MiscTree:
        """Deserialize a `dict` to a `MiscTree`.

        :param serialized_data: the data to deserialize to a `MiscTree` object
        """
        return cls(serialized_data)

    @classmethod
    def state_serializer(cls, data: dict) -> YamlSerializable | None:
        """Infer and return the interpolator state serializer from the `MiscTree` data (if possible). If no
        `InterpolatorState` instance could be found, return `None`.
        """
        serializer = data.get(cls.SERIALIZER_KEY, None)  # if `data` is serialized
        if serializer is None:  # Otherwise search for an InterpolatorState
            for alpha, beta_dict in data.items():
                if alpha == cls.SERIALIZER_KEY:
                    continue
                for beta, value in beta_dict.items():
                    if isinstance(value, InterpolatorState):
                        serializer = type(value)
                        break
                if serializer is not None:
                    break
        return cls._validate_state_serializer(serializer)

    @classmethod
    def _validate_state_serializer(cls, state_serializer: Optional[str | type[Serializable] | YamlSerializable]
                                   ) -> YamlSerializable | None:
        if state_serializer is None:
            return None
        elif isinstance(state_serializer, YamlSerializable):
            return state_serializer
        elif isinstance(state_serializer, str):
            return YamlSerializable.deserialize(state_serializer)  # Load the serializer type from string
        else:
            return YamlSerializable(obj=state_serializer)

    @classmethod
    def _validate_data(cls, serialized_data: dict) -> dict:
        state_serializer = cls.state_serializer(serialized_data)
        ret_dict = {}
        for alpha, beta_dict in serialized_data.items():
            if alpha == cls.SERIALIZER_KEY:
                continue
            alpha_tup = MultiIndex(alpha)
            ret_dict.setdefault(alpha_tup, dict())
            for beta, data in beta_dict.items():
                beta_tup = MultiIndex(beta)
                if isinstance(data, InterpolatorState):
                    pass
                elif state_serializer is not None:
                    data = state_serializer.obj.deserialize(data)
                else:
                    data = float(data)
                assert isinstance(data, InterpolatorState | float)
                ret_dict[alpha_tup][beta_tup] = data
        return ret_dict

    @staticmethod
    def _is_alpha_beta_access(key):
        """Check that the key is of the format `(alpha, beta)`."""
        return (isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], str | tuple)
                and isinstance(key[1], str | tuple))

    def get(self, key, default=None) -> float | InterpolatorState:
        try:
            return self.__getitem__(key)
        except Exception:
            return default

    def update(self, data_dict: dict = None, **kwargs):
        """Force `dict.update()` through the validator."""
        data_dict = data_dict or dict()
        data_dict.update(kwargs)
        super().update(self._validate_data(data_dict))

    def __setitem__(self, key: tuple | MultiIndex, value: float | InterpolatorState):
        """Allows `misc_tree[alpha, beta] = value` usage."""
        if self._is_alpha_beta_access(key):
            alpha, beta = MultiIndex(key[0]), MultiIndex(key[1])
            self.data.setdefault(alpha, dict())
            self.data[alpha][beta] = value
        else:
            super().__setitem__(MultiIndex(key), value)

    def __getitem__(self, key: tuple | MultiIndex) -> float | InterpolatorState:
        """Allows `value = misc_tree[alpha, beta]` usage."""
        if self._is_alpha_beta_access(key):
            alpha, beta = MultiIndex(key[0]), MultiIndex(key[1])
            return self.data[alpha][beta]
        else:
            return super().__getitem__(MultiIndex(key))

    def clear(self):
        """Clear the `MiscTree` data."""
        for key in list(self.data.keys()):
            del self.data[key]

    def __eq__(self, other):
        if isinstance(other, MiscTree):
            try:
                for alpha, beta, data in self:
                    if other[alpha, beta] != data:
                        return False
                return True
            except KeyError:
                return False
        else:
            return False

    def __iter__(self) -> Iterable[tuple[tuple, tuple, float | InterpolatorState]]:
        for alpha, beta_dict in self.data.items():
            if alpha == self.SERIALIZER_KEY:
                continue
            for beta, data in beta_dict.items():
                yield alpha, beta, data


class ComponentSerializers(TypedDict, total=False):
    """Type hint for the `Component` class data serializers.

    :ivar model_kwargs: the model kwarg object class
    :ivar interpolator: the interpolator object class
    :ivar training_data: the training data object class
    """
    model_kwargs: str | type[Serializable] | YamlSerializable
    interpolator: str | type[Serializable] | YamlSerializable
    training_data: str | type[Serializable] | YamlSerializable


class Component(BaseModel, Serializable):
    """A `Component` wrapper around a single discipline model. It manages MISC surrogate construction and a
    hierarchy of modeling fidelities.

    A `Component` can be constructed by specifying a model, input and output variables, and additional configurations
    such as the maximum fidelity levels, the interpolator type, and the training data type. If `model_fidelity`,
    `data_fidelity`, and `surrogate_fidelity` are all left empty, then the `Component` will not use a surrogate model,
    instead calling the underlying model directly. The `Component` can be serialized to a YAML file and deserialized
    back into a Python object.

    !!! Example "A simple `Component`"
        ```python
        from amisc import Component, Variable

        x = Variable(domain=(0, 1))
        y = Variable()
        model = lambda x: {'y': x['x']**2}
        comp = Component(model=model, inputs=[x], outputs=[y])
        ```

    Each fidelity index in $\\alpha$ increases in refinement from $0$ up to `model_fidelity`. Each fidelity index
    in $\\beta$ increases from $0$ up to `(data_fidelity, surrogate_fidelity)`. From the `Component`'s perspective,
    the concatenation of $(\\alpha, \\beta)$ fully specifies a single fidelity "level". The `Component`
    forms an approximation of the model by summing up over many of these concatenated sets of $(\\alpha, \\beta)$.
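
    !!! Example "Specifying fidelity levels (an illustrative sketch)"
        A hypothetical component with two model fidelity levels and up to two training data refinements:

        ```python
        comp = Component(model=model, inputs=[x], outputs=[y],
                         model_fidelity=(1,), data_fidelity=(2,))
        ```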

    :ivar name: the name of the `Component`
    :ivar model: the model or function that is to be approximated, callable as `y = f(x)`
    :ivar inputs: the input variables to the model
    :ivar outputs: the output variables from the model
    :ivar model_kwargs: extra keyword arguments to pass to the model
    :ivar model_fidelity: the maximum level of refinement for each fidelity index in $\\alpha$ for model fidelity
    :ivar data_fidelity: the maximum level of refinement for each fidelity index in $\\beta$ for training data
    :ivar surrogate_fidelity: the max level of refinement for each fidelity index in $\\beta$ for the surrogate
    :ivar interpolator: the interpolator to use as the underlying surrogate model
    :ivar vectorized: whether the model supports vectorized input/output (i.e. datasets with arbitrary shape `(...,)`)
    :ivar call_unpacked: whether the model expects unpacked input arguments (i.e. `func(x1, x2, ...)`)
    :ivar ret_unpacked: whether the model returns unpacked output arguments (i.e. `func() -> (y1, y2, ...)`)

    :ivar active_set: the current active set of multi-indices in the MISC approximation
    :ivar candidate_set: all neighboring multi-indices that are candidates for inclusion in `active_set`
    :ivar misc_states: the interpolator states for each multi-index in the MISC approximation
    :ivar misc_costs: the computational cost associated with each multi-index in the MISC approximation
    :ivar misc_coeff_train: the combination technique coefficients for the active set multi-indices
    :ivar misc_coeff_test: the combination technique coefficients for the active and candidate set multi-indices
    :ivar model_costs: the tracked average single fidelity model costs for each $\\alpha$
    :ivar model_evals: the tracked number of evaluations for each $\\alpha$
    :ivar training_data: the training data storage structure for the surrogate model

    :ivar serializers: the custom serializers for the `[model_kwargs, interpolator, training_data]`
        `Component` attributes -- these should be the _types_ of the serializer objects, which will
        be inferred from the data passed in if not explicitly set
    :ivar _logger: the logger for the `Component`
    """
    yaml_tag: ClassVar[str] = u'!Component'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True,
                              protected_namespaces=(), extra='allow')
    # Configuration
    serializers: Optional[ComponentSerializers] = None
    name: Optional[str] = None
    model: str | Callable[[dict | Dataset, ...], dict | Dataset]
    model_kwargs: str | dict | ModelKwargs = {}
    inputs: _VariableLike
    outputs: _VariableLike
    model_fidelity: str | tuple = MultiIndex()
    data_fidelity: str | tuple = MultiIndex()
    surrogate_fidelity: str | tuple = MultiIndex()
    interpolator: Any | Interpolator = Lagrange()
    vectorized: bool = False
    call_unpacked: Optional[bool] = None  # If the model expects unpacked inputs/outputs: `func(x1, x2, ...) -> (y1, y2, ...)`
    ret_unpacked: Optional[bool] = None

    # Data storage/states for a MISC component
    active_set: list | set | IndexSet = IndexSet()  # set of active (alpha, beta) multi-indices
    candidate_set: list | set | IndexSet = IndexSet()  # set of candidate (alpha, beta) multi-indices
    misc_states: dict | MiscTree = MiscTree()  # (alpha, beta) -> Interpolator state
    misc_costs: dict | MiscTree = MiscTree()  # (alpha, beta) -> Added computational cost for this multi-index
    misc_coeff_train: dict | MiscTree = MiscTree()  # (alpha, beta) -> c_[alpha, beta] (active set only)
    misc_coeff_test: dict | MiscTree = MiscTree()  # (alpha, beta) -> c_[alpha, beta] (including candidate set)
    model_costs: dict = dict()  # Average single fidelity model costs (for each alpha)
    model_evals: dict = dict()  # Number of evaluations for each alpha
    training_data: Any | TrainingData = SparseGrid()  # Stores surrogate training data

    # Internal
    _logger: Optional[logging.Logger] = None
    _model_start_time: float = -1.0  # Temporarily store the most recent model start timestamp from call_model
    _model_end_time: float = -1.0  # Temporarily store the most recent model end timestamp from call_model
    _cache: dict = dict()  # Temporary cache for faster access to training data and similar

    def __init__(self, /, model, *args, inputs=None, outputs=None, name=None, **kwargs):
        if name is None:
            name = _inspect_assignment('Component')  # try to assign the name from inspection
        name = name or model.__name__ or "Component_" + "".join(random.choices(string.digits, k=3))

        # Determine how the model expects to be called and gather inputs/outputs
        _ = self._validate_model_signature(model, args, inputs, outputs, kwargs.get('call_unpacked', None),
                                           kwargs.get('ret_unpacked', None))
        model, inputs, outputs, call_unpacked, ret_unpacked = _
        kwargs['call_unpacked'] = call_unpacked
        kwargs['ret_unpacked'] = ret_unpacked

        # Gather all model kwargs (anything else passed in for kwargs is assumed to be a model kwarg)
        model_kwargs = kwargs.get('model_kwargs', {})
        for key in kwargs.keys() - self.model_fields.keys():
            model_kwargs[key] = kwargs.pop(key)
        kwargs['model_kwargs'] = model_kwargs

        # Gather data serializers from type checks (if not passed in as a kwarg)
        serializers = kwargs.get('serializers', {})  # directly passing serializers will override type checks
        for key in ComponentSerializers.__annotations__.keys():
            field = kwargs.get(key, None)
            if isinstance(field, dict):
                field_super = next(filter(lambda x: issubclass(x, Serializable),
                                          typing.get_args(self.model_fields[key].annotation)), None)
                field = field_super.from_dict(field) if field_super is not None else field
                kwargs[key] = field
            if not serializers.get(key, None):
                serializers[key] = type(field) if isinstance(field, Serializable) else (
                    type(self.model_fields[key].default))
        kwargs['serializers'] = serializers

        super().__init__(model=model, inputs=inputs, outputs=outputs, name=name, **kwargs)  # Runs pydantic validation

        # Set internal properties
        assert self.is_downward_closed(self.active_set.union(self.candidate_set))
        self.set_logger()

    @classmethod
    def _validate_model_signature(cls, model, args=(), inputs=None, outputs=None,
                                  call_unpacked=None, ret_unpacked=None):
        """Parse the model signature and decide how the model expects to be called based on what input/output
        information is provided or inspected from the model signature.
        """
        if inputs is not None:
            inputs = cls._validate_variables(inputs)
        if outputs is not None:
            outputs = cls._validate_variables(outputs)
        model = cls._validate_model(model)

        # Default to `dict` (i.e. packed) model call/return signatures
        if call_unpacked is None:
            call_unpacked = False
        if ret_unpacked is None:
            ret_unpacked = False
        inputs_inspect, outputs_inspect = _inspect_function(model)
        call_unpacked = call_unpacked or (len(inputs_inspect) > 1)  # Assume multiple inputs require unpacking
        ret_unpacked = ret_unpacked or (len(outputs_inspect) > 1)  # Assume multiple outputs require unpacking

        # Extract inputs/outputs from args
        arg_inputs = ()
        arg_outputs = ()
        if len(args) > 0:
            if call_unpacked:
                if isinstance(args[0], dict | str | Variable):
                    arg_inputs = args[:len(inputs_inspect)]
                    arg_outputs = args[len(inputs_inspect):]
                else:
                    arg_inputs = args[0]
                    arg_outputs = args[1:]
            else:
                arg_inputs = args[0]  # Assume first arg is a single or list of inputs
                arg_outputs = args[1:]  # Assume rest are outputs

        # Resolve inputs
        inputs = inputs or []
        inputs = VariableList.merge(inputs, arg_inputs)
        if len(inputs) == 0:
            inputs = inputs_inspect
            call_unpacked = True
            if len(inputs) == 0:
                raise ValueError("Could not infer input variables from model signature. Either your model does not "
                                 "accept input arguments or an error occurred during inspection.\nPlease provide the "
                                 "inputs directly as `Component(inputs=[...])` or fix the model signature.")
        if call_unpacked:
            if not all([var == inputs_inspect[i] for i, var in enumerate(inputs)]):
                warnings.warn(f"Mismatch between provided inputs: {inputs.values()} and inputs inferred from "
                              f"model signature: {inputs_inspect}. This may cause unexpected results.")
        else:
            if len(inputs_inspect) > 1:
                warnings.warn(f"Model signature expects multiple input arguments: {inputs_inspect}. "
                              f"Please set `call_unpacked=True` to use this model signature for multiple "
                              f"inputs.\nOtherwise, move all inputs into a single `dict` argument and all "
                              f"extra arguments into the `model_kwargs` field.")

            # Can't assume unpacked for single input/output, so warn user if they may be trying to do so
            if len(inputs) == 1 and len(inputs_inspect) == 1 and str(inputs[0]) == str(inputs_inspect[0]):
                warnings.warn(f"Single input argument: {inputs[0]} provided to model with input signature: "
                              f"{inputs_inspect}.\nIf you intended to use a single input argument, set "
                              f"`call_unpacked=True` to use this model signature.\nOtherwise, the first input will "
                              f"be passed to your model as a `dict`.\nIf you are expecting a `dict` input already, "
                              f"change the name of the input to not exactly "
                              f"match {inputs_inspect} in order to silence this warning.")

        # Resolve outputs
        outputs = outputs or []
        outputs = VariableList.merge(outputs, *arg_outputs)
        if len(outputs) == 0:
            outputs = outputs_inspect
            ret_unpacked = True
            if len(outputs) == 0:
                raise ValueError("Could not infer output variables from model inspection. Either your model does not "
                                 "return outputs or an error occurred during inspection.\nPlease provide the "
                                 "outputs directly as `Component(outputs=[...])` or fix the model return values.")
        if ret_unpacked:
            if not all([var == outputs_inspect[i] for i, var in enumerate(outputs)]):
                warnings.warn(f"Mismatch between provided outputs: {outputs.values()} and outputs inferred "
                              f"from model: {outputs_inspect}. This may cause unexpected results.")
        else:
            if len(outputs_inspect) > 1:
                warnings.warn(f"Model expects multiple return values: {outputs_inspect}. Please set "
                              f"`ret_unpacked=True` to use this model signature for multiple outputs.\n"
                              f"Otherwise, move all outputs into a single `dict` return value.")

            if len(outputs) == 1 and len(outputs_inspect) == 1 and str(outputs[0]) == str(outputs_inspect[0]):
                warnings.warn(f"Single output: {outputs[0]} provided to model with single expected return: "
                              f"{outputs_inspect}.\nIf you intended to output a single return value, set "
                              f"`ret_unpacked=True` to use this model signature.\nOtherwise, the output should "
                              f"be returned from your model as a `dict`.\nIf you are returning a `dict` already, "
                              f"then change its name to not exactly match {outputs_inspect} in order to silence "
                              f"this warning.")
        return model, inputs, outputs, call_unpacked, ret_unpacked

    def __repr__(self):
        s = f'---- {self.name} ----\n'
        s += f'Inputs: {self.inputs}\n'
        s += f'Outputs: {self.outputs}\n'
        s += f'Model: {self.model}'
        return s

    def __str__(self):
        return self.__repr__()

    @field_validator('serializers')
    @classmethod
    def _validate_serializers(cls, serializers: ComponentSerializers) -> ComponentSerializers:
        """Make sure custom serializer object types are themselves serializable as `YamlSerializable`."""
        for key, serializer in serializers.items():
            if serializer is None:
                serializers[key] = None
            elif isinstance(serializer, YamlSerializable):
                serializers[key] = serializer
            elif isinstance(serializer, str):
                serializers[key] = YamlSerializable.deserialize(serializer)
            else:
                serializers[key] = YamlSerializable(obj=serializer)
        return serializers

    @field_validator('model')
    @classmethod
    def _validate_model(cls, model: str | Callable) -> Callable:
        """Expects the model as a callable or a YAML `!!python/name` string representation."""
        if isinstance(model, str):
            return YamlSerializable.deserialize(model).obj
        else:
            return model

    @field_validator('inputs', 'outputs')
    @classmethod
    def _validate_variables(cls, variables: _VariableLike) -> VariableList:
        if isinstance(variables, VariableList):
            return variables
        else:
            return VariableList.deserialize(variables)

    @field_validator('model_fidelity', 'data_fidelity', 'surrogate_fidelity')
    @classmethod
    def _validate_indices(cls, multi_index) -> MultiIndex:
        return MultiIndex(multi_index)

    @field_validator('active_set', 'candidate_set')
    @classmethod
    def _validate_index_set(cls, index_set) -> IndexSet:
        return IndexSet.deserialize(index_set)

    @field_validator('misc_states', 'misc_costs', 'misc_coeff_train', 'misc_coeff_test')
    @classmethod
    def _validate_misc_tree(cls, misc_tree) -> MiscTree:
        return MiscTree.deserialize(misc_tree)

    @field_validator('model_costs')
    @classmethod
    def _validate_model_costs(cls, model_costs: dict) -> dict:
        return {MultiIndex(key): float(value) for key, value in model_costs.items()}

    @field_validator('model_evals')
    @classmethod
    def _validate_model_evals(cls, model_evals: dict) -> dict:
        return {MultiIndex(key): int(value) for key, value in model_evals.items()}

    @field_validator('model_kwargs', 'interpolator', 'training_data')
    @classmethod
    def _validate_arbitrary_serializable(cls, data: Any, info: ValidationInfo) -> Any:
        """Use the stored custom serialization classes to deserialize arbitrary objects."""
        serializer = info.data.get('serializers').get(info.field_name).obj
        if isinstance(data, Serializable):
            return data
        else:
            return serializer.deserialize(data)

    @property
    def xdim(self) -> int:
        return len(self.inputs)

    @property
    def ydim(self) -> int:
        return len(self.outputs)

    @property
    def max_alpha(self) -> MultiIndex:
        """The maximum model fidelity multi-index (alias for `model_fidelity`)."""
        return self.model_fidelity

    @property
    def max_beta(self) -> MultiIndex:
        """The maximum surrogate fidelity multi-index is a combination of training and interpolator indices."""
        return self.data_fidelity + self.surrogate_fidelity

    @property
    def has_surrogate(self) -> bool:
        """The component has no surrogate model if there are no fidelity indices."""
        return (len(self.max_alpha) + len(self.max_beta)) > 0

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @logger.setter
    def logger(self, logger: logging.Logger):
        self._logger = logger

    def __eq__(self, other):
        if isinstance(other, Component):
            return (self.model.__code__.co_code == other.model.__code__.co_code and self.inputs == other.inputs
                    and self.outputs == other.outputs and self.name == other.name
                    and self.model_kwargs.data == other.model_kwargs.data
                    and self.model_fidelity == other.model_fidelity and self.max_beta == other.max_beta
                    and self.interpolator == other.interpolator
                    and self.active_set == other.active_set and self.candidate_set == other.candidate_set
                    and self.misc_states == other.misc_states and self.misc_costs == other.misc_costs
                    )
        else:
            return False

    def _neighbors(self, alpha: MultiIndex, beta: MultiIndex, active_set: IndexSet = None, forward: bool = True):
        """Get all possible forward or backward multi-index neighbors (a distance of one unit vector away).

        :param alpha: the model fidelity index
        :param beta: the surrogate fidelity index
        :param active_set: the set of active multi-indices
        :param forward: whether to get forward or backward neighbors
        :returns: a set of multi-indices that are neighbors of the input multi-index pair `(alpha, beta)`
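
        !!! Example "Forward neighbors (illustrative)"
            With `max_alpha=(1,)` and `max_beta=(1,)`, the forward neighbors of `((0,), (0,))` would be
            `((1,), (0,))` and `((0,), (1,))` -- each one unit step away, kept only if downward-closedness holds.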
        """
        active_set = active_set or self.active_set
        ind = np.array(alpha + beta)
        max_ind = np.array(self.max_alpha + self.max_beta)
        new_candidates = IndexSet()
        for i in range(len(ind)):
            ind_new = ind.copy()
            ind_new[i] += 1 if forward else -1

            # Don't add if we surpass a refinement limit or lower bound
            if np.any(ind_new > max_ind) or np.any(ind_new < 0):
                continue

            # Add the new index if it maintains downward-closedness
            down_closed = True
            for j in range(len(ind)):
                ind_check = ind_new.copy()
                ind_check[j] -= 1
                if ind_check[j] >= 0:
                    tup_check = (MultiIndex(ind_check[:len(alpha)]), MultiIndex(ind_check[len(alpha):]))
                    if tup_check not in active_set and tup_check != (alpha, beta):
                        down_closed = False
                        break
            if down_closed:
                new_candidates.add((ind_new[:len(alpha)], ind_new[len(alpha):]))

        return new_candidates

    def _surrogate_outputs(self):
        """Helper function to get the names of the surrogate outputs (including latent variables)."""
        y_vars = []
        for var in self.outputs:
            if var.compression is not None:
                for i in range(var.compression.latent_size()):
                    y_vars.append(f'{var.name}{LATENT_STR_ID}{i}')
            else:
                y_vars.append(var.name)
        return y_vars

    def _match_index_set(self, index_set, misc_coeff):
        """Helper function to grab the correct data structures for the given index set and MISC coefficients."""
        if misc_coeff is None:
            match index_set:
                case 'train':
                    misc_coeff = self.misc_coeff_train
                case 'test':
                    misc_coeff = self.misc_coeff_test
                case other:
                    raise ValueError(f"Index set must be 'train' or 'test' if you do not provide `misc_coeff`. "
                                     f"{other} not recognized.")
        if isinstance(index_set, str):
            match index_set:
                case 'train':
                    index_set = self.active_set
                case 'test':
                    index_set = self.active_set.union(self.candidate_set)
                case other:
                    raise ValueError(f"Index set must be 'train' or 'test'. {other} not recognized.")

        return index_set, misc_coeff

    def cache(self, kind: list | Literal["training"] = "training"):
        """Cache data for quicker access. Only `"training"` is supported.

        :param kind: the type(s) of data to cache (only "training" is supported). This will cache the
            surrogate training data with NaNs removed.
        """
        if not isinstance(kind, list):
            kind = [kind]

        if "training" in kind:
            self._cache.setdefault("training", {})
            y_vars = self._surrogate_outputs()
            for alpha, beta in self.active_set.union(self.candidate_set):
                self._cache["training"].setdefault(alpha, {})

                if beta not in self._cache["training"][alpha]:
                    self._cache["training"][alpha][beta] = self.training_data.get(alpha, beta[:len(self.data_fidelity)],
                                                                                  y_vars=y_vars, skip_nan=True)

    def clear_cache(self):
        """Clear cached data."""
        self._cache.clear()

    def get_training_data(self, alpha: Literal['best', 'worst'] | MultiIndex = 'best',
                          beta: Literal['best', 'worst'] | MultiIndex = 'best',
                          y_vars: list = None,
                          cached: bool = False) -> tuple[Dataset, Dataset]:
        """Get all training data for a given multi-index pair `(alpha, beta)`.

        :param alpha: the model fidelity index (defaults to the maximum available model fidelity)
        :param beta: the surrogate fidelity index (defaults to the maximum available surrogate fidelity)
        :param y_vars: the training data to return (defaults to all stored data)
        :param cached: if True, will get cached training data if available (this will ignore `y_vars` and
            only grab whatever is in the cache, which is surrogate outputs only and no NaNs)
        :returns: `(xtrain, ytrain)` -- the training data for the given multi-indices
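
        !!! Example "A usage sketch"
            Assuming a trained component `comp`:

            ```python
            xtrain, ytrain = comp.get_training_data()  # highest available (alpha, beta) by default
            ```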
        """
        # Find the best alpha
        if alpha == 'best':
            alpha_best = ()
            for a, _ in self.active_set.union(self.candidate_set):
                if sum(a) > sum(alpha_best):
                    alpha_best = a
            alpha = alpha_best
        elif alpha == 'worst':
            alpha = (0,) * len(self.max_alpha)

        # Find the best beta for the given alpha
        if beta == 'best':
            beta_best = ()
            for a, b in self.active_set.union(self.candidate_set):
                if a == alpha and sum(b) > sum(beta_best):
                    beta_best = b
            beta = beta_best
        elif beta == 'worst':
            beta = (0,) * len(self.max_beta)

        try:
            if cached and (data := self._cache.get("training", {}).get(alpha, {}).get(beta)) is not None:
                return data
            else:
                return self.training_data.get(alpha, beta[:len(self.data_fidelity)], y_vars=y_vars, skip_nan=True)
        except Exception as e:
            self.logger.error(f"Error getting training data for alpha={alpha}, beta={beta}.")
            raise e

    def call_model(self, inputs: dict | Dataset,
                   model_fidelity: Literal['best', 'worst'] | tuple | list = None,
                   output_path: str | Path = None,
                   executor: Executor = None,
                   track_costs: bool = False,
                   **kwds) -> Dataset:
        """Wrapper function for calling the underlying component model.

        This function formats the input data, calls the model, and processes the output data.
        It supports vectorized calls, parallel execution using an executor, or serial execution. These options are
        checked in that order, with the first available method used. Set `Component.vectorized=True` if the
        model supports input arrays of the form `(N,)` or even arbitrary shape `(...,)`.

        !!! Warning "Parallel Execution"
            The underlying model must be defined in a global module scope if `pickle` is the serialization method for
            the provided `Executor`.

        !!! Note "Additional return values"
            The model can return additional items that are not part of `Component.outputs`. These items are returned
            as object arrays in the output `dict`. Two special return values are `model_cost` and `output_path`.
            Returning `model_cost` will store the computational cost of a single model evaluation (which is used by
            `amisc` adaptive surrogate training). Returning `output_path` will store the output file name if the model
            wrote any files to disk.
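
        !!! Example "Reporting model cost (a sketch)"
            A hypothetical model that returns its own computational cost alongside its outputs:

            ```python
            def my_model(inputs):
                y = inputs['x'] ** 2
                return {'y': y, 'model_cost': 1.0}  # hypothetical: 1 CPU-second per evaluation
            ```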

        !!! Note "Handling errors"
            If the underlying component model raises an exception, the error is stored in `output_dict['errors']` with
            the index of the input data that caused the error. The output data for that index is set to `np.nan`
            for each output variable.

        :param inputs: The input data for the model, formatted as a `dict` with a key for each input variable and
            a corresponding value that is an array of the input data. If specified as a plain list, then the
            order is assumed the same as `Component.inputs`.
        :param model_fidelity: Fidelity indices to tune the model fidelity (model must request this
            in its keyword arguments).
        :param output_path: Directory to save model output files (model must request this in its keyword arguments).
        :param executor: Executor for parallel execution if the model is not vectorized (optional).
        :param track_costs: Whether to track the computational cost of each model evaluation.
        :param kwds: Additional keyword arguments to pass to the model (model must request these in its keyword args).
        :returns: The output data from the model, formatted as a `dict` with a key for each output variable and a
            corresponding value that is an array of the output data.
        """
        # Format inputs to a common loop shape (fail if missing any)
        if len(inputs) == 0:
            return {}  # nothing to do for empty inputs
        if isinstance(inputs, list | np.ndarray):
            inputs = np.atleast_1d(inputs)
            inputs = {var.name: inputs[..., i] for i, var in enumerate(self.inputs)}

        var_shape = {}
        for var in self.inputs:
            s = None
            if (arr := kwds.get(f'{var.name}{COORDS_STR_ID}')) is not None:
                if not np.issubdtype(arr.dtype, np.object_):  # if not object array, then it's a single coordinate set
                    s = arr.shape if len(arr.shape) == 1 else arr.shape[:-1]  # skip the coordinate dim (last axis)
            if var.compression is not None:
                for field in var.compression.fields:
                    var_shape[field] = s
            else:
                var_shape[var.name] = s
        inputs, loop_shape = format_inputs(inputs, var_shape=var_shape)

        N = int(np.prod(loop_shape))
        list_alpha = isinstance(model_fidelity, list | np.ndarray)
        alpha_requested = self.model_kwarg_requested('model_fidelity')
        for var in self.inputs:
            if var.compression is not None:
                for field in var.compression.fields:
                    if field not in inputs:
                        raise ValueError(f"Missing field '{field}' for input variable '{var}'.")
            elif var.name not in inputs:
                raise ValueError(f"Missing input variable '{var.name}'.")

        # Pass extra requested items to the model kwargs
        kwargs = copy.deepcopy(self.model_kwargs.data)
        if self.model_kwarg_requested('output_path'):
            kwargs['output_path'] = output_path
        if self.model_kwarg_requested('input_vars'):
            kwargs['input_vars'] = self.inputs
        if self.model_kwarg_requested('output_vars'):
            kwargs['output_vars'] = self.outputs
        if alpha_requested:
            if not list_alpha:
                model_fidelity = [model_fidelity] * N
            for i in range(N):
                if model_fidelity[i] == 'best':
                    model_fidelity[i] = self.max_alpha
                elif model_fidelity[i] == 'worst':
                    model_fidelity[i] = (0,) * len(self.model_fidelity)

        for k, v in kwds.items():
            if self.model_kwarg_requested(k):
                kwargs[k] = v

        # Compute model (vectorized, executor parallel, or serial)
        errors = {}
        if self.vectorized:
            if alpha_requested:
                kwargs['model_fidelity'] = np.atleast_1d(model_fidelity).reshape((N, -1))

            self._model_start_time = time.time()
            output_dict = self.model(*[inputs[var.name] for var in self.inputs], **kwargs) if self.call_unpacked \
                else self.model(inputs, **kwargs)
            self._model_end_time = time.time()

            if self.ret_unpacked:
                output_dict = (output_dict,) if not isinstance(output_dict, tuple) else output_dict
                output_dict = {out_var.name: output_dict[i] for i, out_var in enumerate(self.outputs)}
        else:
            self._model_start_time = time.time()
            if executor is None:  # Serial
                results = deque(maxlen=N)
                for i in range(N):
                    try:
                        if alpha_requested:
                            kwargs['model_fidelity'] = model_fidelity[i]
                        ret = self.model(*[{k: v[i] for k, v in inputs.items()}[var.name] for var in self.inputs],
                                         **kwargs) if self.call_unpacked else (
                            self.model({k: v[i] for k, v in inputs.items()}, **kwargs))
                        if self.ret_unpacked:
                            ret = (ret,) if not isinstance(ret, tuple) else ret
                            ret = {out_var.name: ret[i] for i, out_var in enumerate(self.outputs)}
                        results.append(ret)
                    except Exception:
                        results.append({'inputs': {k: v[i] for k, v in inputs.items()}, 'index': i,
                                        'model_kwargs': kwargs.copy(), 'error': traceback.format_exc()})
            else:  # Parallel
                results = deque(maxlen=N)
                futures = []
                for i in range(N):
                    if alpha_requested:
                        kwargs['model_fidelity'] = model_fidelity[i]
                    fs = executor.submit(self.model,
                                         *[{k: v[i] for k, v in inputs.items()}[var.name] for var in self.inputs],
                                         **kwargs) if self.call_unpacked else (
                        executor.submit(self.model, {k: v[i] for k, v in inputs.items()}, **kwargs))
                    futures.append(fs)
                wait(futures, timeout=None, return_when=ALL_COMPLETED)

                for i, fs in enumerate(futures):
                    try:
                        if alpha_requested:
                            kwargs['model_fidelity'] = model_fidelity[i]
                        ret = fs.result()
                        if self.ret_unpacked:
                            ret = (ret,) if not isinstance(ret, tuple) else ret
                            ret = {out_var.name: ret[i] for i, out_var in enumerate(self.outputs)}
                        results.append(ret)
                    except Exception:
                        results.append({'inputs': {k: v[i] for k, v in inputs.items()}, 'index': i,
                                        'model_kwargs': kwargs.copy(), 'error': traceback.format_exc()})
            self._model_end_time = time.time()

        # Collect parallel/serial results
        output_dict = {}
        for i in range(N):
            res = results.popleft()
            if 'error' in res:
                errors[i] = res
            else:
                for key, val in res.items():
                    # Save this component's variables
                    is_component_var = False
                    for var in self.outputs:
                        if var.compression is not None:  # field quantity return values (save as object arrays)
                            if key in var.compression.fields or key == f'{var}{COORDS_STR_ID}':
                                if output_dict.get(key) is None:
                                    output_dict.setdefault(key, np.full((N,), None, dtype=object))
                                output_dict[key][i] = np.atleast_1d(val)
                                is_component_var = True
                                break
                        elif key == var:
                            if output_dict.get(key) is None:
                                output_dict.setdefault(key, np.full((N, *np.atleast_1d(val).shape), np.nan))
                            output_dict[key][i, ...] = np.atleast_1d(val)
                            is_component_var = True
                            break

                    # Otherwise, save other objects
                    if not is_component_var:
                        # Save singleton numeric values as numeric arrays (model costs, etc.)
                        _val = np.atleast_1d(val)
                        if key == 'model_cost' or (np.issubdtype(_val.dtype, np.number)
                                                   and len(_val.shape) == 1 and _val.shape[0] == 1):
                            if output_dict.get(key) is None:
                                output_dict.setdefault(key, np.full((N,), np.nan))
                            output_dict[key][i] = _val[0]
                        else:
                            # Otherwise save into a generic object array
                            if output_dict.get(key) is None:
                                output_dict.setdefault(key, np.full((N,), None, dtype=object))
                            output_dict[key][i] = val

        # Save average model costs for each alpha fidelity
        if track_costs:
            if model_fidelity is not None and output_dict.get('model_cost') is not None:
                alpha_costs = {}
                for i, cost in enumerate(output_dict['model_cost']):
                    alpha_costs.setdefault(MultiIndex(model_fidelity[i]), [])
                    alpha_costs[MultiIndex(model_fidelity[i])].append(cost)
                for a, costs in alpha_costs.items():
                    self.model_evals.setdefault(a, 0)
                    self.model_costs.setdefault(a, 0.0)
                    num_evals_prev = self.model_evals.get(a)
                    num_evals_new = len(costs)
                    prev_avg = self.model_costs.get(a)
                    costs = np.nan_to_num(costs, nan=prev_avg)
                    new_avg = (np.sum(costs) + prev_avg * num_evals_prev) / (num_evals_prev + num_evals_new)
                    self.model_evals[a] += num_evals_new
                    self.model_costs[a] = float(new_avg)

        # Reshape loop dimensions to match the original input shape
        output_dict = format_outputs(output_dict, loop_shape)

        for var in self.outputs:
            if var.compression is not None:
                for field in var.compression.fields:
                    if field not in output_dict:
                        self.logger.warning(f"Model return missing field '{field}' for output variable '{var}'. "
                                            f"This may indicate an error during model evaluation. Returning NaNs...")
                        output_dict.setdefault(field, np.full((N,), np.nan))
            elif var.name not in output_dict:
                self.logger.warning(f"Model return missing output variable '{var.name}'. This may indicate "
                                    f"an error during model evaluation. Returning NaNs...")
                output_dict[var.name] = np.full((N,), np.nan)

        # Return the output dictionary and any errors
        if errors:
            output_dict['errors'] = errors
        return output_dict

    def predict(self, inputs: dict | Dataset,
                use_model: Literal['best', 'worst'] | tuple = None,
                model_dir: str | Path = None,
                index_set: Literal['train', 'test'] | IndexSet = 'test',
                misc_coeff: MiscTree = None,
                incremental: bool = False,
                executor: Executor = None,
                **kwds) -> Dataset:
        """Evaluate the MISC surrogate approximation at new inputs `x`.

        !!! Note "Using the underlying model"
            By default, this will predict the MISC surrogate approximation; all inputs are assumed to be in a
            compressed and normalized form. If the component does not have a surrogate (i.e. it is analytical), then
            the inputs will be converted to model form and the underlying model will be called in place. If you
            instead want to override the surrogate, passing `use_model` will call the underlying model directly. In
            that case, the inputs should be passed in already in model form (i.e. full fields, denormalized).
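
        !!! Example "Surrogate vs. model prediction (a sketch)"
            Assuming a trained component `comp` and test inputs `xtest`:

            ```python
            y_surr = comp.predict({'x': xtest})                    # MISC surrogate approximation
            y_true = comp.predict({'x': xtest}, use_model='best')  # bypass the surrogate entirely
            ```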

        :param inputs: `dict` of input arrays for each variable input
        :param use_model: 'best'=high-fidelity, 'worst'=low-fidelity, tuple=a specific `alpha`, None=surrogate (default)
        :param model_dir: directory to save output files if `use_model` is specified, ignored otherwise
        :param index_set: the active index set, defaults to `self.active_set` if `'train'` or both
            `self.active_set + self.candidate_set` if `'test'`
        :param misc_coeff: the data structure holding the MISC coefficients to use, which defaults to the
            training or testing coefficients depending on the `index_set` parameter.
        :param incremental: a special flag to use if the provided `index_set` is an incremental update to the active
            index set. A temporary copy of the internal `misc_coeff` data structure will be updated
            and used to incorporate the new indices.
        :param executor: executor for parallel execution if the model is not vectorized (optional), will use the
            executor for looping over MISC coefficients if evaluating the surrogate rather than the model
        :param kwds: additional keyword arguments to pass to the model (if using the underlying model)
        :returns: the surrogate approximation of the model (or the model return itself if `use_model`)
        """
        # Use raw model inputs/outputs
        if use_model is not None:
            outputs = self.call_model(inputs, model_fidelity=use_model, output_path=model_dir, executor=executor,
                                      **kwds)
            return {str(var): outputs[var] for var in outputs}

        # Convert inputs/outputs to/from model if no surrogate (i.e. analytical models)
        if not self.has_surrogate:
            field_coords = {f'{var}{COORDS_STR_ID}':
                            self.model_kwargs.get(f'{var}{COORDS_STR_ID}', kwds.get(f'{var}{COORDS_STR_ID}', None))
                            for var in self.inputs}
            inputs, field_coords = to_model_dataset(inputs, self.inputs, del_latent=True, **field_coords)
            field_coords.update(kwds)
            outputs = self.call_model(inputs, model_fidelity=use_model or 'best', output_path=model_dir,
                                      executor=executor, **field_coords)
            outputs, _ = to_surrogate_dataset(outputs, self.outputs, del_fields=True, **field_coords)
            return {str(var): outputs[var] for var in outputs}

        # Choose the correct index set and misc_coeff data structures
        if incremental:
            misc_coeff = copy.deepcopy(self.misc_coeff_train)
            self.update_misc_coeff(index_set, self.active_set, misc_coeff)
            index_set = self.active_set.union(index_set)
        else:
            index_set, misc_coeff = self._match_index_set(index_set, misc_coeff)

        # Format inputs for surrogate prediction (all scalars at this point, including latent coeffs)
        inputs, loop_shape = format_inputs(inputs)  # {'x': (N,)}
        outputs = {}

        # Handle prediction with empty active set (return nan)
        if len(index_set) == 0:
            self.logger.warning(f"Component '{self.name}' has an empty active set. "
                                f"Has the surrogate been trained yet? Returning NaNs...")
            for var in self.outputs:
                outputs[var.name] = np.full(loop_shape, np.nan)
            return outputs

        y_vars = self._surrogate_outputs()  # Only request this component's specified outputs (ignore all extras)

        # Combination technique MISC surrogate prediction
        results = []
        coeffs = []
        for alpha, beta in index_set:
            comb_coeff = misc_coeff[alpha, beta]
            if np.abs(comb_coeff) > 0:
                coeffs.append(comb_coeff)
                args = (self.misc_states.get((alpha, beta)),
                        self.get_training_data(alpha, beta, y_vars=y_vars, cached=True))

                results.append(self.interpolator.predict(inputs, *args) if executor is None else
                               executor.submit(self.interpolator.predict, inputs, *args))

        if executor is not None:
            wait(results, timeout=None, return_when=ALL_COMPLETED)
            results = [future.result() for future in results]

        for coeff, interp_pred in zip(coeffs, results):
            for var, arr in interp_pred.items():
                if outputs.get(var) is None:
                    outputs[str(var)] = coeff * arr
                else:
                    outputs[str(var)] += coeff * arr

        return format_outputs(outputs, loop_shape)

    def update_misc_coeff(self, new_indices: IndexSet, index_set: Literal['test', 'train'] | IndexSet = 'train',
                          misc_coeff: MiscTree = None):
        """Update MISC coefficients incrementally resulting from the addition of new indices to an index set.

        !!! Warning "Incremental updates"
            This function is used to update the MISC coefficients stored in `misc_coeff` after adding new indices
            to the given `index_set`. If a custom `index_set` or `misc_coeff` are provided, the user is responsible
            for ensuring the data structures are consistent. Since this is an incremental update, this means all
            existing coefficients for every index in `index_set` should be precomputed and stored in `misc_coeff`.
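
        For each new index $\\mathbf{k} = (\\alpha, \\beta)$, every index $\\mathbf{j}$ in the set (including
        $\\mathbf{k}$ itself) that satisfies $\\mathbf{k} - \\mathbf{j} \\in \\{0, 1\\}^d$ has its coefficient
        incremented by $(-1)^{\\lVert \\mathbf{k} - \\mathbf{j} \\rVert_1}$, which is the standard
        combination-technique update rule (see the loop below).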

        :param new_indices: a set of $(\\alpha, \\beta)$ tuples that are being added to the `index_set`
        :param index_set: the active index set, defaults to `self.active_set` if `'train'` or both
            `self.active_set + self.candidate_set` if `'test'`
        :param misc_coeff: the data structure holding the MISC coefficients to update, which defaults to the
            training or testing coefficients depending on the `index_set` parameter. This data structure
            is modified in place.
        """
        index_set, misc_coeff = self._match_index_set(index_set, misc_coeff)

        for new_alpha, new_beta in new_indices:
            new_ind = np.array(new_alpha + new_beta)

            # Update all existing/new coefficients if they are a distance of [0, 1] "below" the new index
            # Note that new indices can only be [0, 1] away from themselves -- not any other new indices
            for old_alpha, old_beta in itertools.chain(index_set, [(new_alpha, new_beta)]):
                old_ind = np.array(old_alpha + old_beta)
                diff = new_ind - old_ind
                if np.all(np.isin(diff, [0, 1])):
                    if misc_coeff.get((old_alpha, old_beta)) is None:
                        misc_coeff[old_alpha, old_beta] = 0
                    misc_coeff[old_alpha, old_beta] += (-1) ** int(np.sum(np.abs(diff)))

    def activate_index(self, alpha: MultiIndex, beta: MultiIndex, model_dir: str | Path = None,
                       executor: Executor = None, weight_fcns: dict[str, callable] | Literal['pdf'] | None = 'pdf'):
        """Add a multi-index to the active set and all neighbors to the candidate set.

        !!! Warning
            The user of this function is responsible for ensuring that the index set maintains downward-closedness.
            That is, only activate indices that are neighbors of the current active set.

        :param alpha: a multi-index specifying model fidelity
        :param beta: a multi-index specifying surrogate fidelity
        :param model_dir: directory to save model output files
        :param executor: executor for parallel execution of the model on training data if the model is not vectorized
        :param weight_fcns: dictionary of weight functions for each input variable (defaults to the variable PDFs);
            each function should be callable as `fcn(x: np.ndarray) -> np.ndarray`, where the input
            is an array of normalized input data and the output is an array of weights. If None, then
            no weighting is applied.
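
        !!! Example "A usage sketch"
            Activating the base index for a hypothetical component with `model_fidelity=(1,)` and
            `data_fidelity=(1, 1)`:

            ```python
            comp.activate_index((0,), (0, 0))  # forward neighbors are added to the candidate set
            ```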
        """
        if (alpha, beta) in self.active_set:
            self.logger.warning(f'Multi-index {(alpha, beta)} is already in the active index set. Ignoring...')
            return
        if (alpha, beta) not in self.candidate_set and (sum(alpha) + sum(beta)) > 0:
            # Can only activate the initial index (0, 0, ... 0) without it being in the candidate set
            self.logger.warning(f'Multi-index {(alpha, beta)} is not a neighbor of the active index set, so it '
                                f'cannot be activated. Please only add multi-indices from the candidate set. '
                                f'Ignoring...')
            return

        # Collect all neighbor candidate indices; sort by largest model cost first
        neighbors = self._neighbors(alpha, beta, forward=True)
        indices = list(itertools.chain([(alpha, beta)] if (alpha, beta) not in self.candidate_set else [], neighbors))
        indices.sort(key=lambda ele: self.model_costs.get(ele[0], sum(ele[0])), reverse=True)

        # Refine and collect all new model inputs (i.e. training points) requested by the new candidates
        alpha_list = []    # keep track of model fidelities
        design_list = []   # keep track of training data coordinates/locations/indices
        model_inputs = {}  # concatenate all model inputs
        field_coords = {f'{var}{COORDS_STR_ID}': self.model_kwargs.get(f'{var}{COORDS_STR_ID}', None)
                        for var in self.inputs}
        domains = self.inputs.get_domains()

        if weight_fcns == 'pdf':
            weight_fcns = self.inputs.get_pdfs()

        for a, b in indices:
            if ((a, b[:len(self.data_fidelity)] + (0,) * len(self.surrogate_fidelity)) in
                    self.active_set.union(self.candidate_set)):
                # Don't refine training data if only updating surrogate fidelity indices
                # Training data is the same for all surrogate fidelity indices, given constant data fidelity
                design_list.append([])
                continue

            design_coords, design_pts = self.training_data.refine(a, b[:len(self.data_fidelity)],
                                                                  domains, weight_fcns)
            design_pts, fc = to_model_dataset(design_pts, self.inputs, del_latent=True, **field_coords)

            # Remove duplicate (alpha, coords) pairs -- so you don't evaluate the model twice for the same input
            i = 0
            del_idx = []
            for other_design in design_list:
                for other_coord in other_design:
                    for j, curr_coord in enumerate(design_coords):
                        if curr_coord == other_coord and a == alpha_list[i] and j not in del_idx:
                            del_idx.append(j)
                    i += 1
            design_coords = [design_coords[j] for j in range(len(design_coords)) if j not in del_idx]
            design_pts = {var: np.delete(arr, del_idx, axis=0) for var, arr in design_pts.items()}

            alpha_list.extend([tuple(a)] * len(design_coords))
            design_list.append(design_coords)
            field_coords.update(fc)
            for var in design_pts:
                model_inputs[var] = design_pts[var] if model_inputs.get(var) is None else (
                    np.concatenate((model_inputs[var], design_pts[var]), axis=0))

        # Evaluate model at designed training points
        if len(alpha_list) > 0:
            self.logger.info(f"Running {len(alpha_list)} total model evaluations for component "
                             f"'{self.name}' new candidate indices: {indices}...")
            model_outputs = self.call_model(model_inputs, model_fidelity=alpha_list, output_path=model_dir,
                                            executor=executor, track_costs=True, **field_coords)
            self.logger.info(f"Model evaluations complete for component '{self.name}'.")
            errors = model_outputs.pop('errors', {})
        else:
            self._model_start_time = -1.0
            self._model_end_time = -1.0
1252 # Unpack model outputs and update states
1253 start_idx = 0
1254 for i, (a, b) in enumerate(indices):
1255 num_train_pts = len(design_list[i])
1256 end_idx = start_idx + num_train_pts # Ensure loop dim of 1 gets its own axis (might have been squeezed)
1258 if num_train_pts > 0:
1259 yi_dict = {var: arr[np.newaxis, ...] if len(alpha_list) == 1 and arr.shape[0] != 1 else
1260 arr[start_idx:end_idx, ...] for var, arr in model_outputs.items()}
1262 # Check for errors and store
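# `errors` maps a global evaluation index to its error info; pop out the entries belonging to this
# candidate's slice [start_idx, end_idx) and convert them to local indices within the slice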
1263 err_coords = []
1264 err_list = []
1265 for idx in list(errors.keys()):
1266 if idx < end_idx:
1267 err_info = errors.pop(idx)
1268 err_info['index'] = idx - start_idx
1269 err_coords.append(design_list[i][idx - start_idx])
1270 err_list.append(err_info)
1271 if len(err_list) > 0:
1272 self.logger.warning(f"Model errors occurred while adding candidate ({a}, {b}) for component "
1273 f"'{self.name}'. Leaving NaN values in training data...")
1274 self.training_data.set_errors(a, b[:len(self.data_fidelity)], err_coords, err_list)
1276 # Compress field quantities and normalize
1277 yi_dict, y_vars = to_surrogate_dataset(yi_dict, self.outputs, del_fields=False, **field_coords)
1279 # Store training data, computational cost, and new interpolator state
1280 self.training_data.set(a, b[:len(self.data_fidelity)], design_list[i], yi_dict)
1281 self.training_data.impute_missing_data(a, b[:len(self.data_fidelity)])
1283 else:
1284 y_vars = self._surrogate_outputs()
1286 self.misc_costs[a, b] = num_train_pts
1287 self.misc_states[a, b] = self.interpolator.refine(b[len(self.data_fidelity):],
1288 self.training_data.get(a, b[:len(self.data_fidelity)],
1289 y_vars=y_vars, skip_nan=True),
1290 self.misc_states.get((alpha, beta)),
1291 domains)
1292 start_idx = end_idx
1294 # Move to the active index set
1295 s = set()
1296 s.add((alpha, beta))
1297 self.update_misc_coeff(IndexSet(s), index_set='train')
1298 if (alpha, beta) in self.candidate_set:
1299 self.candidate_set.remove((alpha, beta))
1300 else:
1301 # Only for initial index which didn't come from the candidate set
1302 self.update_misc_coeff(IndexSet(s), index_set='test')
1303 self.active_set.update(s)
1305 self.update_misc_coeff(neighbors, index_set='test') # neighbors will only ever pass through here once
1306 self.candidate_set.update(neighbors)
1308 def gradient(self, inputs: dict | Dataset,
1309 index_set: Literal['train', 'test'] | IndexSet = 'test',
1310 misc_coeff: MiscTree = None,
1311 derivative: Literal['first', 'second'] = 'first',
1312 executor: Executor = None) -> Dataset:
1313 """Evaluate the Jacobian or Hessian of the MISC surrogate approximation at new `inputs`, i.e.
1314 the first or second derivatives, respectively.
1316 :param inputs: `dict` of input arrays, one for each input variable
1317 :param index_set: the index set to evaluate over: `self.active_set` if `'train'`, or
1318 `self.active_set + self.candidate_set` if `'test'`
1319 :param misc_coeff: the data structure holding the MISC coefficients to use; defaults to the
1320 training or testing coefficients depending on `index_set`
1321 :param derivative: whether to compute the first or second derivative (i.e. Jacobian or Hessian)
1322 :param executor: executor for looping over MISC coefficients (optional)
1323 :returns: a `dict` of the Jacobian or Hessian of the surrogate approximation for each output variable
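
!!! Example
    A minimal sketch of evaluating surrogate derivatives (`comp` and the input name `x` are
    illustrative placeholders for a trained component and one of its inputs):

    ```python
    import numpy as np

    inputs = {'x': np.linspace(0, 1, 50)}
    jac = comp.gradient(inputs)                        # first derivatives (Jacobian)
    hess = comp.gradient(inputs, derivative='second')  # second derivatives (Hessian)
    ```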
1324 """
1325 if not self.has_surrogate:
1326 self.logger.warning("No surrogate model available for gradient computation.")
1327 return None
1329 index_set, misc_coeff = self._match_index_set(index_set, misc_coeff)
1330 inputs, loop_shape = format_inputs(inputs) # {'x': (N,)}
1331 outputs = {}
1333 if len(index_set) == 0:
1334 for var in self.outputs:
1335 outputs[var] = np.full(loop_shape, np.nan)
1336 return outputs
1337 y_vars = self._surrogate_outputs()
1339 # Combination technique MISC gradient prediction
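# The surrogate derivative is a weighted sum over the index set: grad(f) ~ sum_k [ c_k * grad(f_k) ],
# where c_k are the combination-technique coefficients and f_k are the individual interpolants.
# Terms with a zero coefficient are skipped since they do not contribute to the sum.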
1340 results = []
1341 coeffs = []
1342 for alpha, beta in index_set:
1343 comb_coeff = misc_coeff[alpha, beta]
1344 if np.abs(comb_coeff) > 0:
1345 coeffs.append(comb_coeff)
1346 func = self.interpolator.gradient if derivative == 'first' else self.interpolator.hessian
1347 args = (self.misc_states.get((alpha, beta)),
1348 self.get_training_data(alpha, beta, y_vars=y_vars, cached=True))
1350 results.append(func(inputs, *args) if executor is None else executor.submit(func, inputs, *args))
1352 if executor is not None:
1353 wait(results, timeout=None, return_when=ALL_COMPLETED)
1354 results = [future.result() for future in results]
1356 for coeff, interp_pred in zip(coeffs, results):
1357 for var, arr in interp_pred.items():
1358 if outputs.get(var) is None:
1359 outputs[str(var)] = coeff * arr
1360 else:
1361 outputs[str(var)] += coeff * arr
1363 return format_outputs(outputs, loop_shape)
1365 def hessian(self, *args, **kwargs):
1366 """Alias for `Component.gradient(*args, derivative='second', **kwargs)`."""
1367 return self.gradient(*args, derivative='second', **kwargs)
1369 def model_kwarg_requested(self, kwarg_name: str) -> bool:
1370 """Return whether the underlying component model requested this `kwarg_name`. Special kwargs include:
1372 - `output_path` — a save directory created by `amisc` will be passed to the model for saving model output files.
1373 - `alpha` — a tuple or list of model fidelity indices will be passed to the model to adjust fidelity.
1374 - `input_vars` — a list of `Variable` objects will be passed to the model for input variable information.
1375 - `output_vars` — a list of `Variable` objects will be passed to the model for output variable information.
1377 :param kwarg_name: the keyword argument to look for in the underlying component model's function signature
1378 :returns: whether the component model requests this `kwarg_name` argument
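
!!! Example
    A sketch of how kwargs are detected (`my_model` is a hypothetical user model; only parameters
    declared with a default value count as requested kwargs):

    ```python
    def my_model(inputs, output_path=None):
        ...

    comp = Component(my_model)
    assert comp.model_kwarg_requested('output_path')      # declared with a default
    assert not comp.model_kwarg_requested('num_workers')  # not in the signature
    ```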
1379 """
1380 signature = inspect.signature(self.model)
1381 for param in signature.parameters.values():
1382 if param.name == kwarg_name and param.default is not param.empty:
1383 return True
1384 return False
1386 def set_logger(self, log_file: str | Path = None, stdout: bool = None, logger: logging.Logger = None,
1387 level: int = logging.INFO):
1388 """Set a new `logging.Logger` object.
1390 :param log_file: log to file (if provided)
1391 :param stdout: whether to connect the logger to console (defaults to whatever is currently set or False)
1392 :param logger: the logging object to use; if provided, this overrides the `log_file` and
1393 `stdout` arguments, otherwise a new logger is created
1394 :param level: the logging level to set (default is `logging.INFO`)
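
!!! Example
    For example, to log to both a file and the console (`comp` is an existing `Component`):

    ```python
    import logging

    comp.set_logger(log_file='component.log', stdout=True, level=logging.DEBUG)
    ```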
1395 """
1396 if stdout is None:
1397 stdout = False
1398 if self._logger is not None:
1399 for handler in self._logger.handlers:
1400 if isinstance(handler, logging.StreamHandler):
1401 stdout = True
1402 break
1403 self._logger = logger or get_logger(self.name, log_file=log_file, stdout=stdout, level=level)
1405 def update_model(self, new_model: callable = None, model_kwargs: dict = None, **kwargs):
1406 """Update the underlying component model or its kwargs."""
1407 if new_model is not None:
1408 self.model = new_model
1409 new_kwargs = self.model_kwargs.data
1410 new_kwargs.update(model_kwargs or {})
1411 new_kwargs.update(kwargs)
1412 self.model_kwargs = new_kwargs
1414 def get_cost(self, alpha: MultiIndex, beta: MultiIndex) -> int:
1415 """Return the total cost (i.e. number of model evaluations) required to add $(\\alpha, \\beta)$ to the
1416 MISC approximation.
1418 :param alpha: A multi-index specifying model fidelity
1419 :param beta: A multi-index specifying surrogate fidelity
1420 :returns: the total number of model evaluations required for adding this multi-index to the MISC approximation
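
!!! Example
    ```python
    # alpha and beta here are illustrative multi-indices; returns 0 if the pair has no recorded cost
    num_evals = comp.get_cost(alpha=(0,), beta=(1,))
    ```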
1421 """
1422 try:
1423 return self.misc_costs[alpha, beta]
1424 except Exception:
1425 return 0
1427 def get_model_timestamps(self):
1428 """Return a tuple with the (start, end) timestamps for the most recent call to `call_model`. This
1429 is useful for tracking the duration of model evaluations. Will return (None, None) if no model has been called.
1430 """
1430 """
1431 if self._model_start_time < 0 or self._model_end_time < 0:
1432 return None, None
1433 else:
1434 return self._model_start_time, self._model_end_time
1436 @staticmethod
1437 def is_downward_closed(indices: IndexSet) -> bool:
1438 """Return if a list of $(\\alpha, \\beta)$ multi-indices is downward-closed.
1440 MISC approximations require a downward-closed set in order to use the combination-technique formula for the
1441 coefficients (as implemented by `Component.update_misc_coeff()`).
1443 !!! Example
1444 The list `[( (0,), (0,) ), ( (1,), (0,) ), ( (1,), (1,) )]` is downward-closed. You can visualize this as
1445 building a stack of cubes: in order to place a cube, all adjacent cubes must be present (does the logo
1446 make sense now?).
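
    As a code sketch (assuming `IndexSet` can be built from an iterable of `(alpha, beta)` tuples):

    ```python
    s = IndexSet({((0,), (0,)), ((1,), (0,)), ((1,), (1,))})
    assert Component.is_downward_closed(s)
    assert not Component.is_downward_closed(IndexSet({((1,), (1,))}))  # missing the smaller indices
    ```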
1448 :param indices: `IndexSet` of (`alpha`, `beta`) multi-indices
1449 :returns: whether the set of indices is downward-closed
1450 """
1451 # Iterate over every multi-index
1452 for alpha, beta in indices:
1453 # Every smaller multi-index must also be included in the indices list
1454 sub_sets = [np.arange(tuple(alpha + beta)[i] + 1) for i in range(len(alpha) + len(beta))]
1455 for ele in itertools.product(*sub_sets):
1456 tup = (MultiIndex(ele[:len(alpha)]), MultiIndex(ele[len(alpha):]))
1457 if tup not in indices:
1458 return False
1459 return True
1461 def clear(self):
1462 """Clear the component of all training data, index sets, and MISC states."""
1463 self.active_set.clear()
1464 self.candidate_set.clear()
1465 self.misc_states.clear()
1466 self.misc_costs.clear()
1467 self.misc_coeff_train.clear()
1468 self.misc_coeff_test.clear()
1469 self.model_costs.clear()
1470 self.model_evals.clear()
1471 self.training_data.clear()
1472 self._model_start_time = -1.0
1473 self._model_end_time = -1.0
1474 self.clear_cache()
1476 def serialize(self, keep_yaml_objects: bool = False, serialize_args: dict[str, tuple] = None,
1477 serialize_kwargs: dict[str, dict] = None) -> dict:
1478 """Convert to a `dict` with only standard Python types as fields and values.
1480 :param keep_yaml_objects: whether to keep `Variable` or other yaml serializable objects instead of
1481 also serializing them (default is False)
1482 :param serialize_args: additional arguments to pass to the `serialize` method of each `Component` attribute;
1483 specify as a `dict` mapping attribute names to tuples of positional arguments
1484 :param serialize_kwargs: additional keyword arguments to pass to the `serialize` method of each
1485 `Component` attribute
1486 :returns: a `dict` representation of the `Component` object
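
!!! Example
    A round-trip sketch (`comp` is an existing `Component`):

    ```python
    d = comp.serialize()              # plain-Python dict representation
    comp2 = Component.deserialize(d)  # reconstruct an equivalent Component
    ```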
1487 """
1488 serialize_args = serialize_args or dict()
1489 serialize_kwargs = serialize_kwargs or dict()
1490 d = {}
1491 for key, value in self.__dict__.items():
1492 if value is not None and not key.startswith('_'):
1493 if key == 'serializers':
1494 # Update the serializers
1495 serializers = self._validate_serializers({k: type(getattr(self, k)) for k in value.keys()})
1496 d[key] = {k: (v.obj if keep_yaml_objects else v.serialize()) for k, v in serializers.items()}
1497 elif key in ['inputs', 'outputs'] and not keep_yaml_objects:
1498 d[key] = value.serialize(**serialize_kwargs.get(key, {}))
1499 elif key == 'model' and not keep_yaml_objects:
1500 d[key] = YamlSerializable(obj=value).serialize()
1501 elif key in ['data_fidelity', 'surrogate_fidelity', 'model_fidelity']:
1502 if len(value) > 0:
1503 d[key] = str(value)
1504 elif key in ['active_set', 'candidate_set']:
1505 if len(value) > 0:
1506 d[key] = value.serialize()
1507 elif key in ['misc_costs', 'misc_coeff_train', 'misc_coeff_test', 'misc_states']:
1508 if len(value) > 0:
1509 d[key] = value.serialize(keep_yaml_objects=keep_yaml_objects)
1510 elif key in ['model_costs']:
1511 if len(value) > 0:
1512 d[key] = {str(k): float(v) for k, v in value.items()}
1513 elif key in ['model_evals']:
1514 if len(value) > 0:
1515 d[key] = {str(k): int(v) for k, v in value.items()}
1516 elif key in ComponentSerializers.__annotations__.keys():
1517 if key in ['training_data', 'interpolator'] and not self.has_surrogate:
1518 continue
1519 else:
1520 d[key] = value.serialize(*serialize_args.get(key, ()), **serialize_kwargs.get(key, {}))
1521 else:
1522 d[key] = value
1523 return d
1525 @classmethod
1526 def deserialize(cls, serialized_data: dict, search_paths: list[str | Path] = None,
1527 search_keys: list[str] = None) -> Component:
1528 """Return a `Component` from `data`. Let pydantic handle field validation and conversion. If any component
1529 data has been saved to file and the save file doesn't exist, then the loader will search for the file
1530 in the current working directory and any additional search paths provided.
1532 :param serialized_data: the serialized data to construct the object from
1533 :param search_paths: paths to try and find any save files (i.e. if they moved since they were serialized),
1534 will always search in the current working directory by default
1535 :param search_keys: keys to search for save files in each component (default is all keys in
1536 [`ComponentSerializers`][amisc.component.ComponentSerializers], in addition to variable
1537 inputs and outputs)
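
!!! Example
    Since a bare callable is accepted, a component can be built directly from a function
    (`my_model` is a hypothetical model; `data_fidelity` then defaults to `(2,)` per inspected input):

    ```python
    def my_model(inputs):
        return {'y': inputs['x'] ** 2}

    comp = Component.deserialize(my_model)
    ```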
1538 """
1539 if isinstance(serialized_data, Component):
1540 return serialized_data
1541 elif callable(serialized_data):
1542 # try to construct a component from a raw function (assume data fidelity is (2,) for each inspected input)
1543 return cls(serialized_data, data_fidelity=(2,) * len(_inspect_function(serialized_data)[0]))
1545 search_paths = search_paths or []
1546 search_keys = search_keys or []
1547 search_keys.extend(ComponentSerializers.__annotations__.keys())
1548 comp = serialized_data
1550 for key in search_keys:
1551 if (filename := comp.get(key, None)) is not None:
1552 comp[key] = search_for_file(filename, search_paths=search_paths) # returns the original string if not found
1554 for key in ['inputs', 'outputs']:
1555 for var in comp.get(key, []):
1556 if isinstance(var, dict):
1557 if (compression := var.get('compression', None)) is not None:
1558 var['compression'] = search_for_file(compression, search_paths=search_paths)
1560 return cls(**comp)
1562 @staticmethod
1563 def _yaml_representer(dumper: yaml.Dumper, comp: Component) -> yaml.MappingNode:
1564 """Convert a single `Component` object (`data`) to a yaml MappingNode (i.e. a `dict`)."""
1565 save_path, save_file = _get_yaml_path(dumper)
1566 serialize_kwargs = {}
1567 for key, serializer in comp.serializers.items():
1568 if issubclass(serializer.obj, PickleSerializable):
1569 filename = save_path / f'{save_file}_{comp.name}_{key}.pkl'
1570 serialize_kwargs[key] = {'save_path': filename} # filename already includes save_path
1571 return dumper.represent_mapping(Component.yaml_tag, comp.serialize(serialize_kwargs=serialize_kwargs,
1572 keep_yaml_objects=True))
1574 @staticmethod
1575 def _yaml_constructor(loader: yaml.Loader, node):
1576 """Convert the `!Component` tag in yaml to a `Component` object."""
1577 # Add a file search path in the same directory as the yaml file being loaded from
1578 save_path, save_file = _get_yaml_path(loader)
1579 if isinstance(node, yaml.SequenceNode):
1580 return [ele if isinstance(ele, Component) else Component.deserialize(ele, search_paths=[save_path])
1581 for ele in loader.construct_sequence(node, deep=True)]
1582 elif isinstance(node, yaml.MappingNode):
1583 return Component.deserialize(loader.construct_mapping(node, deep=True), search_paths=[save_path])
1584 else:
1585 raise NotImplementedError(f'The "{Component.yaml_tag}" yaml tag can only be used on a yaml sequence or '
1586 f'mapping, not a "{type(node)}".')