Coverage for src/amisc/variable.py: 90%
421 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-24 04:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-24 04:51 +0000
"""Provides an object-oriented interface for model inputs/outputs, random variables, scalars, and field quantities.

Includes:

- `Variable` — an object that stores information about a variable and includes methods for sampling, pdf evaluation,
  normalization, compression, loading from file, etc. Variables can mostly be treated as strings
  that have some additional information and utilities attached to them.
- `VariableList` — a container for `Variables` that provides dict-like access of `Variables` by `name` along with normal
  indexing and slicing.

The preferred serialization of `Variable` and `VariableList` is to/from yaml. This is done by default with the
`!Variable` and `!VariableList` yaml tags.
"""
14from __future__ import annotations
16import ast
17import random
18import string
19from collections import OrderedDict
20from pathlib import Path
21from typing import ClassVar, Optional, Union
23import numpy as np
24import yaml
25from numpy.typing import ArrayLike
26from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
28from amisc.compression import Compression
29from amisc.distribution import Distribution, LogUniform, Normal, Uniform
30from amisc.serialize import Serializable
31from amisc.transform import Minmax, Transform, Zscore
32from amisc.typing import LATENT_STR_ID, CompressionData
33from amisc.utils import _get_yaml_path, _inspect_assignment, search_for_file
# Public names exported by `from amisc.variable import *`
__all__ = ["Variable", "VariableList"]

# Values accepted for `Variable.norm` (handed to `Transform.from_string` during validation):
# a transform spec string, a `Transform` instance, or a list mixing the two
_TransformLike = Union[str, Transform, list[str | Transform]]
class Variable(BaseModel, Serializable):
    """Object for storing information about variables and providing methods for pdf evaluation, sampling, etc.
    All fields will undergo pydantic validation and conversion to the correct types.

    A simple variable object can be created with `var = Variable()`. All initialization options are optional and will
    be given good defaults. You should probably at the very least give a memorable `name` and a `domain`. Variables
    can mostly be treated as strings with some extra information/utilities attached.

    With the `pyyaml` library installed, all `Variable` objects can be saved or loaded directly from a `.yml` file by
    using the `!Variable` yaml tag (which is loaded by default with `amisc`).

    - Use `Variable.distribution` to specify PDFs, such as for random variables. See the `Distribution` classes.
    - Use `Variable.norm` to specify a transformed-space that is more amenable to surrogate construction
      (e.g. mapping to the range (0,1)). See the `Transform` classes.
    - Use `Variable.compression` to specify high-dimensional, coordinate-based field quantities,
      such as from the output of many simulation software programs. See the `Compression` classes.
    - Use `Variable.category` as an additional layer for using Variables in different ways (e.g. set a "calibration"
      category for Bayesian inference).

    !!! Example
        ```python
        # Random variable
        temp = Variable(name='T', description='Temperature', units='K', distribution='Uniform(280, 320)')
        samples = temp.sample(100)
        pdf = temp.pdf(samples)

        # Field quantity
        vel = Variable(name='u', description='Velocity', units='m/s', compression={'fields': ['ux', 'uy', 'uz']})
        vel_data = ...  # from a simulation
        reduced_vel = vel.compress(vel_data)
        ```

    !!! Warning
        Changes to collection fields (like `Variable.norm`) should completely reassign the _whole_
        collection to trigger the correct validation, rather than editing particular entries. For example, reassign
        `norm=['log', 'linear(2, 2)']` rather than editing norm via `norm.append('linear(2, 2)')`.

    :ivar name: an identifier for the variable, can compare variables directly with strings for indexing purposes
    :ivar nominal: a typical value for this variable
    :ivar description: a lengthier description of the variable
    :ivar units: assumed units for the variable (if applicable)
    :ivar category: an additional descriptor for how this variable is used, e.g. calibration, operating, design, etc.
    :ivar tex: latex format for the variable, i.e. "$x_i$"
    :ivar compression: specifies field quantities and links to relevant compression data
    :ivar distribution: a string specifier of a probability distribution function (see the `Distribution` types)
    :ivar domain: the explicit domain bounds of the variable (limits of where you expect to use it);
                  for field quantities, this is a list of domains for each latent dimension
    :ivar norm: specifier of a map to a transformed-space for surrogate construction (see the `Transform` types)
    """
    yaml_tag: ClassVar[str] = '!Variable'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True)

    # NOTE: field order matters -- validators below read earlier fields from `info.data`
    name: Optional[str] = None
    nominal: Optional[float] = None
    description: Optional[str] = None
    units: Optional[str] = None
    category: Optional[str] = None
    tex: Optional[str] = None
    compression: Optional[str | dict | Compression] = None
    distribution: Optional[str | Distribution] = None
    domain: Optional[str | tuple[float, float] | list] = None
    norm: Optional[_TransformLike] = None

    def __init__(self, /, name=None, **kwargs):
        # Try to set the variable name if instantiated as "x = Variable()"
        if name is None:
            name = _inspect_assignment('Variable')
        # Fall back to a random name like "X_042" if the assignment could not be inspected
        name = name or "X_" + "".join(random.choices(string.digits, k=3))
        super().__init__(name=name, **kwargs)

    @field_validator('tex')
    @classmethod
    def _validate_tex(cls, tex: str | None) -> str | None:
        """Wrap the latex string in `$...$` math delimiters if either delimiter is missing."""
        if tex is None:
            return tex
        if not tex.startswith('$'):
            tex = f'${tex}'
        # A lone '$' both starts and ends with '$' but still needs its closing delimiter
        if len(tex) == 1 or not tex.endswith('$'):
            tex = f'{tex}$'
        return tex

    @field_validator('compression')
    @classmethod
    def _validate_compression(cls, compression: str | dict | Compression, info: ValidationInfo) -> Compression | None:
        """Convert a compression spec (file path string, `dict`, or `Compression`) to a `Compression` object.

        If no `fields` are given, the compression defaults to a single field named after this variable.
        """
        if compression is None:
            return compression
        # `.get` so a failed `name` validation surfaces as pydantic's error rather than a stray KeyError
        default_fields = [info.data.get('name')]
        if isinstance(compression, str):
            return Compression.deserialize(compression)
        if isinstance(compression, dict):
            # Build a new dict so the caller's spec is not mutated
            compression = {**compression, 'fields': compression.get('fields', None) or default_fields}
            return Compression.from_dict(compression)
        compression.fields = compression.fields or default_fields
        return compression

    @field_validator('distribution')
    @classmethod
    def _validate_dist(cls, dist: str | Distribution) -> Distribution | None:
        """Convert a distribution string (e.g. 'Uniform(0, 1)') to a `Distribution` object."""
        if dist is None:
            return dist
        if isinstance(dist, Distribution):
            return dist
        elif isinstance(dist, str):
            return Distribution.from_string(dist)
        else:
            raise ValueError(f'Cannot convert {dist} to a Distribution object.')

    @field_validator('domain')
    @classmethod
    def _validate_domain(cls, domain: list | tuple | str, info: ValidationInfo) -> tuple | list | None:
        """Try to extract the domain from the distribution if not provided, or convert from a string.
        Returns a list of domains for each latent dimension if this is a field quantity with compression.

        :raises ValueError: if any domain is not a `(lower_bound, upper_bound)` pair with `upper_bound > lower_bound`
        """
        def _as_bounds(d):
            """Convert one latent-coefficient domain (string, list, or tuple) to a tuple."""
            if isinstance(d, str):
                return tuple(ast.literal_eval(d.strip()))
            if isinstance(d, list):
                return tuple(d)
            return d

        if domain is None:
            if dist := info.data.get('distribution'):
                domain = dist.domain()
            elif compression := info.data.get('compression'):
                if (ranges := compression.estimate_latent_ranges()) is not None:
                    domain = [tuple(map(float, val)) for val in ranges]
        elif isinstance(domain, str):
            domain = tuple(ast.literal_eval(domain.strip()))
        elif isinstance(domain, list):
            if len(domain) == 2 and isinstance(domain[0], float | int) and isinstance(domain[1], float | int):
                domain = tuple(domain)  # allow lists of 2 elements to be interpreted as a scalar variable domain
            else:
                domain = [_as_bounds(d) for d in domain]  # field qty: one domain per latent coefficient

        if domain is None:
            return domain

        # Explicit raises (not asserts) so the checks survive `python -O`; pydantic wraps them in a ValidationError
        for bounds in (domain if isinstance(domain, list) else [domain]):
            if not (isinstance(bounds, tuple) and len(bounds) == 2):
                raise ValueError(f'Domain must be a (lower_bound, upper_bound) tuple, got "{bounds}".')
            if not bounds[1] > bounds[0]:
                raise ValueError('Domain must be specified as (lower_bound, upper_bound)')

        return domain

    @field_validator('norm')
    @classmethod
    def _validate_norm(cls, norm: _TransformLike, info: ValidationInfo) -> list[Transform] | None:
        """Convert `norm` to a list of `Transform` objects and fill in default transform arguments.

        `Minmax` bounds that are NaN default to the scalar `domain`. `Zscore` arguments that are NaN default to
        the `(mu, std)` of a `Normal` distribution.
        """
        if norm is None:
            return norm
        norm = Transform.from_string(norm)

        # Only a scalar (lb, ub) domain can supply minmax bounds; a list of latent-coefficient domains cannot
        # (this matches `normalize`, which also ignores list domains)
        domain = info.data.get('domain')
        if not isinstance(domain, tuple):
            domain = None
        dist = info.data.get('distribution')
        normal_args = dist.dist_args if isinstance(dist, Normal) else None

        for transform in norm:
            if isinstance(transform, Minmax):
                if domain and np.any(np.isnan(transform.transform_args[0:2])):
                    transform.update(lb=domain[0], ub=domain[1])
            elif isinstance(transform, Zscore):
                if normal_args and np.any(np.isnan(transform.transform_args)):
                    transform.update(mu=normal_args[0], std=normal_args[1])

        return norm

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        """Allows variables to be used as keys in dictionaries and to be considered equal to their string
        representations.
        """
        return hash(self.name)

    def __eq__(self, other):
        """Consider two `Variables` equal if they share the same string name.

        Also returns true when checking if this `Variable` is equal to a string by itself.
        """
        if isinstance(other, Variable):
            return self.name == other.name
        elif isinstance(other, str):
            return self.name == other
        else:
            return False

    def get_tex(self, units: bool = False, symbol: bool = True) -> str:
        """Return a raw string that is well-formatted for plotting (with latex).

        :param units: whether to include the units in the string
        :param symbol: just latex symbol if true, otherwise the full description
        :returns: the latex formatted string
        """
        s = (self.tex if symbol else self.description) or self.name
        return r'{} [{}]'.format(s, self.units or '-') if units else r'{}'.format(s)

    def get_nominal(self) -> float | list | None:
        """Return the nominal value of the variable. Defaults to the distribution's nominal value or the
        center of the domain if `var.nominal` is not specified. Returns a list of nominal values for each latent
        dimension if this is a field quantity with compression.

        :returns: the nominal value(s)
        """
        if self.nominal is not None:
            return self.nominal
        if dist := self.distribution:
            return float(dist.nominal())
        if domain := self.get_domain():
            if isinstance(domain, list):
                return [float(np.mean(d)) for d in domain]
            return float(np.mean(domain))
        return None

    def get_domain(self) -> tuple | list | None:
        """Return a tuple of the defined domain of this variable. Returns a list of domains for each latent dimension
        if this is a field quantity with compression.

        :returns: the domain(s) of this variable
        :raises ValueError: if a scalar domain is given with `compression` but the latent size cannot be determined
        """
        if self.domain is None:
            return None
        elif isinstance(self.domain, list):
            return self.domain
        elif self.compression is not None:
            # Try to infer a list of domains from compression latent size
            try:
                return [self.domain] * self.compression.latent_size()
            except Exception as e:
                raise ValueError(f'Variables with `compression` data should return a list of domains, one '
                                 f'for each latent coefficient. Could not infer domain for "{self.name}".') from e
        else:
            return self.domain

    def sample_domain(self, shape: tuple | int) -> np.ndarray:
        """Return an array of the given `shape` for uniform samples over the domain of this variable. Returns
        samples for each latent dimension if this is a field quantity with compression.

        Will always sample uniformly over the normalized surrogate domain if `norm` is specified, and will return
        samples in the original unnormalized domain.

        !!! Note
            The last dim of the returned samples will be the latent space size for field quantities.

        :param shape: the shape of samples to return
        :returns: the random samples over the domain of the variable
        :raises RuntimeError: if the variable has no domain
        """
        if isinstance(shape, int):
            shape = (shape, )
        domain = self.get_domain()
        if not domain:
            raise RuntimeError(f'Variable "{self.name}" does not have a domain specified.')

        if isinstance(domain, list):
            lb = np.atleast_1d([d[0] for d in domain])
            ub = np.atleast_1d([d[1] for d in domain])
            # Draw each latent coefficient independently; a trailing dim of 1 would broadcast one random
            # fraction across all coefficients and make them perfectly correlated
            return np.random.rand(*shape, len(lb)) * (ub - lb) + lb

        lb, ub = self.normalize(domain)
        norm_samples = np.random.rand(*shape) * (ub - lb) + lb
        return self.denormalize(norm_samples)

    def update_domain(self, domain: tuple[float, float] | list[tuple], override: bool = False):
        """Update the domain of this variable by taking the minimum or maximum of the new domain with the current domain
        for the lower and upper bounds, respectively. Will attempt to update the domain of each latent dimension
        if this is a field quantity with compression. If the variable has a `Uniform` distribution and a scalar
        domain, this will update the distribution's bounds too.

        :param domain: the new domain(s) to update with
        :param override: will simply set the domain to the new values rather than update against the current domain;
                         (default `False`)
        """
        def _update_domain(new_bds, curr_bds):
            """Merge one (lb, ub) pair with the current one (or replace it if overriding)."""
            lb, ub = new_bds
            if not override and curr_bds is not None:
                lb, ub = min(lb, curr_bds[0]), max(ub, curr_bds[1])
            return float(lb), float(ub)

        curr_domain = self.get_domain()
        if isinstance(domain, list) or isinstance(curr_domain, list):
            # Broadcast whichever side is a single domain to the other side's list of latent domains
            new_domains = domain if isinstance(domain, list) else [domain] * len(curr_domain)
            curr_domains = curr_domain if isinstance(curr_domain, list) else [curr_domain] * len(new_domains)
            self.domain = [_update_domain(d, curr_domains[i]) for i, d in enumerate(new_domains)]
        else:
            self.domain = _update_domain(domain, curr_domain)

        # Keep Uniform dist in sync; only a scalar (lb, ub) domain is a valid set of Uniform bounds
        if isinstance(self.distribution, Uniform | LogUniform) and isinstance(self.domain, tuple):
            self.distribution.dist_args = self.domain

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Compute the PDF of the Variable at the given `x` locations. Will just return ones if the variable
        does not have a distribution.

        :param x: locations to compute the PDF at (any array-like)
        :returns: the PDF evaluations at `x`
        """
        if dist := self.distribution:
            return dist.pdf(x)
        else:
            return np.ones(np.shape(x))  # No pdf if no dist is specified; np.shape also accepts lists/scalars

    def sample(self, shape: tuple | int, nominal: float | np.ndarray = None) -> np.ndarray:
        """Draw samples from this `Variable's` distribution. Just returns the nominal value of the given shape if
        this `Variable` has no distribution.

        :param shape: the shape of the returned samples
        :param nominal: a nominal value to use if applicable (i.e. a center for relative, tolerance, or normal)
        :returns: samples from the PDF of this `Variable's` distribution
        :raises ValueError: if there is neither a distribution nor a nominal value
        """
        if isinstance(shape, int):
            shape = (shape, )
        if nominal is None:
            nominal = self.get_nominal()

        if dist := self.distribution:
            return dist.sample(shape, nominal)
        else:
            # Variable's with no distribution
            if nominal is None:
                raise ValueError(f'Cannot sample "{self.name}" with no dist or nominal value specified.')
            elif isinstance(nominal, list | np.ndarray):
                return np.ones(shape + (len(nominal),)) * np.atleast_1d(nominal)  # For field quantities
            else:
                return np.ones(shape) * nominal

    def normalize(self, values: ArrayLike, denorm: bool = False) -> ArrayLike | None:
        """Normalize `values` based on this `Variable's` `norm` method(s). See `Transform` for available norm methods.

        !!! Note
            If this Variable's `self.norm` was specified as a list of norm methods, then each will be applied in
            sequence in the original order (and in reverse for `denorm=True`). When `self.distribution` is involved in
            the transforms (only for `minmax` and `zscore`), the `dist_args` will get normalized too at each
            transform before applying the next transform.

        :param values: the values to normalize (array-like)
        :param denorm: whether to denormalize instead using the inverse of the original normalization method
        :returns: the normalized (or unnormalized) values
        """
        if not self.norm or values is None:
            return values

        dist = self.distribution
        dist_args = dist.dist_args if isinstance(dist, Normal) else []

        def _normalize_single(values, transform, inverse, domain, dist_args):
            """Do a single transform. Might need to override transform_args depending on the transform."""
            transform_args = None
            if isinstance(transform, Minmax) and domain:
                transform_args = domain + transform.transform_args[2:]  # Update minmax bounds
            elif isinstance(transform, Zscore) and dist_args:
                transform_args = dist_args  # Update N(mu, std)

            return transform.transform(values, inverse=inverse, transform_args=transform_args)

        domain = self.get_domain() or ()
        if isinstance(domain, list):
            domain = ()  # For field quantities, domain is not used in normalization

        # Hyperparams are stacked as (*domain, *dist_args). Split on the actual domain length (0 or 2) so an empty
        # domain does not cause the Normal (mu, std) to be mistaken for minmax bounds.
        n_dom = len(domain)

        def _split(hp):
            return tuple(hp[:n_dom]), tuple(hp[n_dom:])

        if denorm:
            # First, send domain and dist_args through the forward norm list (up until the last norm)
            hyperparams = [np.hstack((domain, dist_args))]
            for i, transform in enumerate(self.norm):
                dom, args = _split(hyperparams[i])
                hyperparams.append(_normalize_single(hyperparams[i], transform, False, dom, args))

            # Now denormalize in reverse; each transform uses the hyperparams as they were *before* it was applied
            for hp, transform in zip(reversed(hyperparams[:-1]), reversed(self.norm)):
                dom, args = _split(hp)
                values = _normalize_single(values, transform, True, dom, args)
        else:
            # Normalize values and hyperparams through the forward norm list
            hyperparams = np.hstack((domain, dist_args))
            for transform in self.norm:
                dom, args = _split(hyperparams)
                values = _normalize_single(values, transform, False, dom, args)
                hyperparams = _normalize_single(hyperparams, transform, False, dom, args)

        return values

    def denormalize(self, values):
        """Alias for `normalize(denorm=True)`. See `normalize` for more details."""
        return self.normalize(values, denorm=True)

    def compress(self, values: CompressionData, coords: np.ndarray = None,
                 reconstruct: bool = False) -> CompressionData:
        """Compress or reconstruct field quantity values using this `Variable's` compression info.

        !!! Note "Specifying compression values"
            If only one field quantity is associated with this variable, then
            specify `values` as `dict(coords=..., name=...)` for this Variable's `name`. If `coords` is not specified,
            then this assumes the locations are the same as the reconstruction data (and skips interpolation).

        !!! Info "Compression workflow"
            Generally, compression follows `interpolate -> normalize -> compress` to take raw values into the compressed
            "latent" space. The interpolation step is required to make sure `values` align with the coordinates used
            when building the compression map in the first place (such as through SVD).

        :param values: a `dict` with a key for each field qty of shape `(..., qty.shape)` and a `coords` key of shape
                      `(qty.shape, dim)` that gives the coordinates of each point. Only a single `latent` key should
                      be given instead if `reconstruct=True`. This `dict` is not modified.
        :param coords: the coordinates of each point in `values` if `values` did not contain a `coords` key;
                       defaults to the compression grid coordinates
        :param reconstruct: whether to reconstruct values instead of compress
        :returns: the compressed values with key `latent` and shape `(..., latent_size)`; if `reconstruct=True`,
                  then the reconstructed values with shape `(..., qty.shape)` for each `qty` key are returned.
                  The return `dict` also has a `coords` key with shape `(qty.shape, dim)`.
        :raises ValueError: if there is no compression, no computed map, or no `latent` key when reconstructing
        """
        if not self.compression:
            raise ValueError(f'Compression is not supported for variable "{self.name}". Please specify a compression'
                             f' method for this variable.')
        if not self.compression.map_exists:
            raise ValueError(f'Compression map not computed yet for "{self.name}".')

        # Shallow copy so popping 'coords' does not strip it from the caller's dict (repeat calls would then
        # silently fall back to the compression grid coordinates)
        values = dict(values)

        # Default field coordinates to the compression coordinates if they are not provided
        field_coords = values.pop('coords', coords)
        if field_coords is None:
            field_coords = self.compression.coords
        ret_dict = {'coords': field_coords}

        if reconstruct:
            # For reconstruction: decompress -> denormalize -> interpolate
            try:
                states = np.atleast_1d(values['latent'])  # (..., rank)
            except KeyError as e:
                raise ValueError('Must pass values["latent"] in for reconstruction.') from e
            states = self.compression.reconstruct(states)  # (..., dof)
            states = self.denormalize(states)  # (..., dof)
            states = self.compression.interpolate_from_grid(states, field_coords)
            ret_dict.update(states)
        else:
            # For compression: interpolate -> normalize -> compress
            states = self.compression.interpolate_to_grid(field_coords, values)
            states = self.normalize(states)  # (..., dof)
            states = self.compression.compress(states)  # (..., rank)
            ret_dict['latent'] = states

        return ret_dict

    def reconstruct(self, values, coords=None):
        """Alias for `compress(reconstruct=True)`. See `compress` for more details."""
        return self.compress(values, coords=coords, reconstruct=True)

    def serialize(self, save_path: str | Path = '.') -> dict:
        """Convert a `Variable` to a `dict` with only standard Python types
        (i.e. convert custom objects like `dist` and `norm` to strings and save `compression` to a `.pkl`).

        :param save_path: the path to save the compression data to (defaults to current directory)
        :returns: the serialized `dict` of the `Variable` object
        """
        d = {}
        for key, value in self.__dict__.items():
            if value is not None and not key.startswith('_'):
                if key == 'domain':
                    d[key] = [str(v) for v in value] if isinstance(value, list) else str(value)
                elif key == 'distribution':
                    d[key] = str(value)
                elif key == 'norm':
                    d[key] = [str(transform) for transform in value]
                elif key == 'compression':
                    fname = f'{self.name}_compression.pkl'
                    d[key] = value.serialize(save_path=Path(save_path) / fname)
                else:
                    d[key] = value
        return d

    @classmethod
    def deserialize(cls, data: dict, search_paths: list[str | Path] = None) -> Variable:
        """Convert a `dict` to a `Variable` object. Let `pydantic` handle validation and conversion of fields.

        :param data: the `dict` to convert to a `Variable` (not modified); a `Variable` is returned as-is and a
                     `str` is used as the variable name
        :param search_paths: the paths to search for compression files (if necessary)
        :returns: the `Variable` object
        """
        if isinstance(data, Variable):
            return data
        if isinstance(data, str):
            return cls(name=data)

        data = dict(data)  # don't mutate the caller's dict when resolving the compression file path
        if isinstance(compression := data.get('compression', None), str):
            data['compression'] = search_for_file(compression, search_paths=search_paths)
        return cls(**data)

    @staticmethod
    def _yaml_representer(dumper: yaml.Dumper, data: Variable) -> yaml.MappingNode:
        """Convert a single `Variable` object (`data`) to a yaml MappingNode (i.e. a `dict`)."""
        save_path, _ = _get_yaml_path(dumper)
        return dumper.represent_mapping(Variable.yaml_tag, data.serialize(save_path=save_path))

    @staticmethod
    def _yaml_constructor(loader: yaml.Loader, node):
        """Convert the `!Variable` tag in yaml to a single `Variable` object (or a list of `Variables`)."""
        save_path, _ = _get_yaml_path(loader)
        if isinstance(node, yaml.SequenceNode):
            return [ele if isinstance(ele, Variable) else Variable.deserialize(ele, search_paths=[save_path]) for ele in
                    loader.construct_sequence(node, deep=True)]
        elif isinstance(node, yaml.MappingNode):
            return Variable.deserialize(loader.construct_mapping(node, deep=True), search_paths=[save_path])
        else:
            raise NotImplementedError(f'The "{Variable.yaml_tag}" yaml tag can only be used on a yaml sequence or '
                                      f'mapping, not a "{type(node)}".')
class VariableList(OrderedDict, Serializable):
    """Store `Variables` as `str(var) : Variable` in the order they were passed in. You can:

    - Initialize/update from a single `Variable` or a list of `Variables`
    - Get/set a `Variable` directly or by name via `my_vars[var]` or `my_vars[str(var)]` etc.
    - Retrieve the original order of insertion by `list(my_vars.items())`
    - Access/delete elements by order of insertion using integer/slice indexing (i.e. `my_vars[1:3]`)
    - Save/load from yaml file using the `!VariableList` tag
    """
    yaml_tag = '!VariableList'

    def __init__(self, data: list[Variable] | Variable | OrderedDict | dict = None, **kwargs):
        """Initialize a collection of `Variable` objects (see `update` for the accepted inputs)."""
        super().__init__()
        self.update(data, **kwargs)

    def __iter__(self):
        """Iterate over the `Variable` values (unlike a plain `dict`, which iterates over keys)."""
        yield from self.values()

    def __eq__(self, other):
        """Two `VariableLists` are equal if they hold equal `Variables` (by name) in the same order."""
        if not isinstance(other, VariableList):
            return False
        # Compare lengths first: `zip` alone stops at the shorter list, so a list would equal any longer list
        # that it is a prefix of
        return len(self) == len(other) and all(v1 == v2 for v1, v2 in zip(self.values(), other.values()))

    def __ne__(self, other):
        # OrderedDict provides its own `__ne__`; keep `!=` consistent with the `__eq__` override above
        return not self.__eq__(other)

    def append(self, data: Variable):
        """Add a single `Variable` (list-like alias for `update`)."""
        self.update(data)

    def extend(self, data: list[Variable]):
        """Add several `Variables` (list-like alias for `update`)."""
        self.update(data)

    def index(self, key):
        """Return the insertion position of `key` (a `Variable` or its name).

        :raises ValueError: if `key` is not in the list
        """
        for i, k in enumerate(self.keys()):
            if k == key:
                return i
        raise ValueError(f"'{key}' is not in list")

    def get_domains(self, norm: bool = True):
        """Get normalized variable domains (expand latent coefficient domains for field quantities). Assume a
        domain of `(0, 1)` for variables if their domain is not specified.

        :param norm: whether to normalize the domains using `Variable.norm` (useful for getting bds for surrogate);
                     latent coefficient domains do not get normalized
        :returns: a `dict` of variables to their normalized domains; field quantities return a domain for each
                  of their latent coefficients
        """
        domains = {}
        for var in self:
            var_domain = var.get_domain()
            if isinstance(var_domain, list):  # only field qtys return a list of domains, one for each latent coeff
                for i, domain in enumerate(var_domain):
                    domains[f'{var.name}{LATENT_STR_ID}{i}'] = domain
            elif var_domain is None:
                domains[var.name] = (0, 1)
            else:
                domains[var.name] = var.normalize(var_domain) if norm else var_domain
        return domains

    def get_pdfs(self, norm: bool = True):
        """Get callable pdfs for all variables (skipping field quantities for now).

        :param norm: whether values passed to the pdf functions are normalized and should be denormed first
                     before pdf evaluation (useful for surrogate construction where samples are gathered in the
                     normalized space)
        :returns: a `dict` of variables to callable pdf functions; field quantities are skipped.
        """
        def _get_pdf(var, norm):
            # Factory so each lambda binds its own `var` (avoids the late-binding closure pitfall in the loop)
            return lambda z: var.pdf(var.denormalize(z) if norm else z)

        pdf_fcns = {}
        for var in self:
            if isinstance(var.get_domain(), list):
                # Field quantities (one domain per latent coefficient) are skipped.
                # TODO: Implement pdfs for the latent coefficients of field quantities
                continue
            pdf_fcns[var.name] = _get_pdf(var, norm)
        return pdf_fcns

    def update(self, data: list[Variable | str] | str | Variable | OrderedDict | dict = None, **kwargs):
        """Update from a list or dict of `Variable` objects, or from `key=value` pairs.

        Strings are converted to `Variable(name=...)` by `__setitem__`.
        """
        if data:
            if isinstance(data, OrderedDict | dict):
                for key, value in data.items():
                    self.__setitem__(key, value)
            else:
                data = [data] if not isinstance(data, list | tuple) else data
                for variable in data:
                    self.__setitem__(str(variable), variable)
        for key, value in kwargs.items():
            self.__setitem__(key, value)

    def get(self, key, default=None):
        """Look up `key` through `__getitem__()` (so names, `Variables`, indices and slices all work).

        :returns: the matching `Variable`(s), or `default` if `key` is missing, out of range, or of an invalid type
        """
        try:
            return self.__getitem__(key)
        except (KeyError, IndexError, TypeError):
            # Only the lookup failures `__getitem__` can raise; anything else is a real bug and should propagate
            return default

    def __setitem__(self, key, value):
        """Only allow `str(var): Variable` items. Or normal list indexing via `my_vars[0] = var`.

        :raises TypeError: if `key` is not a `str`/`Variable`/`int`, or `value` is not a `str`/`Variable`
        """
        if isinstance(key, int):
            k = list(self.keys())[key]
            self.__setitem__(k, value)
            return
        if isinstance(value, str):
            value = Variable(name=value)
        if not isinstance(key, str | Variable):
            raise TypeError(f'VariableList key "{key}" is not a Variable or string.')
        if not isinstance(value, Variable):
            raise TypeError(f'VariableList value "{value}" is not a Variable.')
        super().__setitem__(str(key), value)

    def __getitem__(self, key):
        """Allow accessing variable(s) directly via `my_vars[var]` or by index/slicing.

        :raises TypeError: if `key` is not a `str`, `Variable`, `int`, `slice`, or a list/tuple of those
        """
        if isinstance(key, list | tuple):
            return [self.__getitem__(ele) for ele in key]
        elif isinstance(key, int | slice):
            return list(self.values())[key]
        elif isinstance(key, str | Variable):
            return super().__getitem__(str(key))
        else:
            raise TypeError(f'VariableList key "{key}" is not valid.')

    def __delitem__(self, key):
        """Allow deleting variable(s) directly or by index/slicing.

        :raises TypeError: if `key` is not a `str`, `Variable`, `int`, `slice`, or a list/tuple of those
        """
        if isinstance(key, list | tuple):
            for ele in key:
                self.__delitem__(ele)
        elif isinstance(key, int | slice):
            selected = list(self.keys())[key]  # a single key for an int, a list of keys for a slice
            for k in (selected if isinstance(selected, list) else [selected]):
                super().__delitem__(k)
        elif isinstance(key, str | Variable):
            super().__delitem__(str(key))
        else:
            raise TypeError(f'VariableList key "{key}" is not valid.')

    def __str__(self):
        return str(list(self.values()))

    def __repr__(self):
        return self.__str__()

    def serialize(self, save_path='.') -> list[dict]:
        """Convert to a list of `dict` objects for each `Variable` in the list.

        :param save_path: the path to save the compression data to (defaults to current directory)
        :returns: one serialized `dict` per `Variable`, in order
        """
        return [var.serialize(save_path=save_path) for var in self.values()]

    @classmethod
    def merge(cls, *variable_lists) -> VariableList:
        """Merge multiple sets of variables into a single `VariableList` object.

        !!! Note
            Variables with the same name will be merged by keeping the one with the most information provided
            (i.e. the most non-`None` fields; ties keep the earlier one).

        :param variable_lists: the variables/lists to merge
        :returns: the merged `VariableList` object
        """
        merged_vars = cls()

        def _get_best_variable(var1, var2):
            n1 = sum(value is not None for value in var1.__dict__.values())
            n2 = sum(value is not None for value in var2.__dict__.values())
            return var1 if n1 >= n2 else var2

        for var_list in variable_lists:
            for var in cls(var_list):
                if var.name in merged_vars:
                    merged_vars[var.name] = _get_best_variable(merged_vars[var.name], var)
                else:
                    merged_vars[var.name] = var

        return merged_vars

    @classmethod
    def deserialize(cls, data: dict | list[dict], search_paths=None) -> VariableList:
        """Convert a `dict` or list/tuple of `dict` objects to a `VariableList` object. Let `pydantic` handle
        validation.
        """
        if not isinstance(data, list | tuple):
            data = [data]
        return cls([Variable.deserialize(d, search_paths=search_paths) for d in data])

    @staticmethod
    def _yaml_representer(dumper: yaml.Dumper, data: VariableList) -> yaml.SequenceNode:
        """Convert a single `VariableList` object (`data`) to a yaml SequenceNode (i.e. a list)."""
        save_path, _ = _get_yaml_path(dumper)
        return dumper.represent_sequence(VariableList.yaml_tag, data.serialize(save_path=save_path))

    @staticmethod
    def _yaml_constructor(loader: yaml.Loader, node):
        """Convert the `!VariableList` tag in yaml to a `VariableList` object."""
        save_path, _ = _get_yaml_path(loader)
        if isinstance(node, yaml.SequenceNode):
            return VariableList.deserialize(loader.construct_sequence(node, deep=True), search_paths=[save_path])
        elif isinstance(node, yaml.MappingNode):
            # deep=True so nested sequences/mappings (e.g. `norm`, `domain` lists) are fully constructed,
            # matching the `!Variable` constructor
            return VariableList.deserialize(loader.construct_mapping(node, deep=True), search_paths=[save_path])
        else:
            raise NotImplementedError(f'The "{VariableList.yaml_tag}" yaml tag can only be used on a yaml sequence or '
                                      f'mapping, not a "{type(node)}".')