Coverage for src/amisc/variable.py: 90%

421 statements  

coverage.py v7.6.1, created at 2025-01-24 04:51 +0000

1"""Provides an object-oriented interface for model inputs/outputs, random variables, scalars, and field quantities. 

2 

3Includes: 

4 

5- `Variable` — an object that stores information about a variable and includes methods for sampling, pdf evaluation, 

6 normalization, compression, loading from file, etc. Variables can mostly be treated as strings 

7 that have some additional information and utilities attached to them. 

8- `VariableList` — a container for `Variables` that provides dict-like access of `Variables` by `name` along with normal 

9 indexing and slicing. 

10 

11The preferred serialization of `Variable` and `VariableList` is to/from yaml. This is done by default with the 

12`!Variable` and `!VariableList` yaml tags. 

13""" 

14from __future__ import annotations 

15 

16import ast 

17import random 

18import string 

19from collections import OrderedDict 

20from pathlib import Path 

21from typing import ClassVar, Optional, Union 

22 

23import numpy as np 

24import yaml 

25from numpy.typing import ArrayLike 

26from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator 

27 

28from amisc.compression import Compression 

29from amisc.distribution import Distribution, LogUniform, Normal, Uniform 

30from amisc.serialize import Serializable 

31from amisc.transform import Minmax, Transform, Zscore 

32from amisc.typing import LATENT_STR_ID, CompressionData 

33from amisc.utils import _get_yaml_path, _inspect_assignment, search_for_file 

34 

35__all__ = ['Variable', 'VariableList'] 

36_TransformLike = Union[str, Transform, list[str | Transform]] # something that can be converted to a Transform 

37 

38 


class Variable(BaseModel, Serializable):
    """Object for storing information about variables and providing methods for pdf evaluation, sampling, etc.
    All fields will undergo pydantic validation and conversion to the correct types.

    A simple variable object can be created with `var = Variable()`. All initialization options are optional and will
    be given good defaults. You should probably at the very least give a memorable `name` and a `domain`. Variables
    can mostly be treated as strings with some extra information/utilities attached.

    With the `pyyaml` library installed, all `Variable` objects can be saved or loaded directly from a `.yml` file by
    using the `!Variable` yaml tag (which is loaded by default with `amisc`).

    - Use `Variable.distribution` to specify PDFs, such as for random variables. See the `Distribution` classes.
    - Use `Variable.norm` to specify a transformed space that is more amenable to surrogate construction
      (e.g. mapping to the range (0, 1)). See the `Transform` classes.
    - Use `Variable.compression` to specify high-dimensional, coordinate-based field quantities,
      such as from the output of many simulation software programs. See the `Compression` classes.
    - Use `Variable.category` as an additional layer for using `Variables` in different ways (e.g. set a
      "calibration" category for Bayesian inference).

    !!! Example
        ```python
        # Random variable
        temp = Variable(name='T', description='Temperature', units='K', distribution='Uniform(280, 320)')
        samples = temp.sample(100)
        pdf = temp.pdf(samples)

        # Field quantity
        vel = Variable(name='u', description='Velocity', units='m/s', compression={'fields': ['ux', 'uy', 'uz']})
        vel_data = ...  # from a simulation
        reduced_vel = vel.compress(vel_data)
        ```

    !!! Warning
        Changes to collection fields (like `Variable.norm`) should completely reassign the _whole_
        collection to trigger the correct validation, rather than editing particular entries. For example, reassign
        `norm=['log', 'linear(2, 2)']` rather than editing norm via `norm.append('linear(2, 2)')`.

    :ivar name: an identifier for the variable; variables can be compared directly with strings for indexing purposes
    :ivar nominal: a typical value for this variable
    :ivar description: a lengthier description of the variable
    :ivar units: assumed units for the variable (if applicable)
    :ivar category: an additional descriptor for how this variable is used, e.g. calibration, operating, design, etc.
    :ivar tex: latex format for the variable, i.e. "$x_i$"
    :ivar compression: specifies field quantities and links to relevant compression data
    :ivar distribution: a string specifier of a probability distribution function (see the `Distribution` types)
    :ivar domain: the explicit domain bounds of the variable (limits of where you expect to use it);
          for field quantities, this is a list of domains for each latent dimension
    :ivar norm: specifier of a map to a transformed space for surrogate construction (see the `Transform` types)
    """
    yaml_tag: ClassVar[str] = u'!Variable'
    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True, validate_default=True)

    name: Optional[str] = None
    nominal: Optional[float] = None
    description: Optional[str] = None
    units: Optional[str] = None
    category: Optional[str] = None
    tex: Optional[str] = None
    compression: Optional[str | dict | Compression] = None
    distribution: Optional[str | Distribution] = None
    domain: Optional[str | tuple[float, float] | list] = None
    norm: Optional[_TransformLike] = None

    def __init__(self, /, name=None, **kwargs):
        # Try to set the variable name if instantiated as "x = Variable()"
        if name is None:
            name = _inspect_assignment('Variable')
        name = name or "X_" + "".join(random.choices(string.digits, k=3))
        super().__init__(name=name, **kwargs)

    @field_validator('tex')
    @classmethod
    def _validate_tex(cls, tex: str) -> str | None:
        if tex is None:
            return tex
        if not tex.startswith('$'):
            tex = rf'${tex}'
        if not tex[-1] == '$':
            tex = rf'{tex}$'
        return tex

    @field_validator('compression')
    @classmethod
    def _validate_compression(cls, compression: str | dict | Compression, info: ValidationInfo) -> Compression | None:
        if compression is None:
            return compression
        elif isinstance(compression, str):
            return Compression.deserialize(compression)
        elif isinstance(compression, dict):
            compression['fields'] = compression.get('fields', None) or [info.data['name']]
            return Compression.from_dict(compression)
        else:
            compression.fields = compression.fields or [info.data['name']]
            return compression

    @field_validator('distribution')
    @classmethod
    def _validate_dist(cls, dist: str | Distribution) -> Distribution | None:
        if dist is None:
            return dist
        if isinstance(dist, Distribution):
            return dist
        elif isinstance(dist, str):
            return Distribution.from_string(dist)
        else:
            raise ValueError(f'Cannot convert {dist} to a Distribution object.')

    @field_validator('domain')
    @classmethod
    def _validate_domain(cls, domain: list | tuple | str, info: ValidationInfo) -> tuple | list | None:
        """Try to extract the domain from the distribution if not provided, or convert from a string.
        Returns a list of domains for each latent dimension if this is a field quantity with compression.
        """
        if domain is None:
            if dist := info.data['distribution']:
                domain = dist.domain()
            elif compression := info.data['compression']:
                if (ranges := compression.estimate_latent_ranges()) is not None:
                    domain = [tuple(map(float, val)) for val in ranges]
        elif isinstance(domain, str):
            domain = tuple(ast.literal_eval(domain.strip()))
        elif isinstance(domain, list):
            if len(domain) == 2 and isinstance(domain[0], float | int) and isinstance(domain[1], float | int):
                domain = tuple(domain)  # allow lists of 2 elements to be interpreted as a scalar variable domain
            else:
                domain = [tuple(ast.literal_eval(d.strip())) if isinstance(d, str) else d for d in domain]  # field qty

        if domain is None:
            return domain

        if isinstance(domain, list):
            for d in domain:
                assert isinstance(d, tuple) and len(d) == 2
                assert d[1] > d[0], 'Domain must be specified as (lower_bound, upper_bound)'
        else:
            assert isinstance(domain, tuple) and len(domain) == 2
            assert domain[1] > domain[0], 'Domain must be specified as (lower_bound, upper_bound)'

        return domain

    @field_validator('norm')
    @classmethod
    def _validate_norm(cls, norm: _TransformLike, info: ValidationInfo) -> list[Transform] | None:
        if norm is None:
            return norm
        norm = Transform.from_string(norm)

        # Set default values for minmax and zscore transforms
        domain = info.data['domain']
        normal_args = None
        if dist := info.data['distribution']:
            if isinstance(dist, Normal):
                normal_args = dist.dist_args
        for transform in norm:
            if isinstance(transform, Minmax):
                if domain and np.any(np.isnan(transform.transform_args[0:2])):
                    transform.update(lb=domain[0], ub=domain[1])
            elif isinstance(transform, Zscore):
                if normal_args and np.any(np.isnan(transform.transform_args)):
                    transform.update(mu=normal_args[0], std=normal_args[1])

        return norm

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        """Allows variables to be used as keys in dictionaries and to be considered equal to their string
        representations.
        """
        return hash(self.name)

    def __eq__(self, other):
        """Consider two `Variables` equal if they share the same string name.

        Also returns true when checking if this `Variable` is equal to a string by itself.
        """
        if isinstance(other, Variable):
            return self.name == other.name
        elif isinstance(other, str):
            return self.name == other
        else:
            return False
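
    # Example sketch (editor's note, not part of the original module): because `__eq__` and `__hash__`
    # key off `name`, a `Variable` can stand in for its own string name (the variable here is hypothetical):
    #
    #     T = Variable(name='T')
    #     assert T == 'T' and str(T) == 'T'
    #     lookup = {T: 300.0}          # hashes the same as the plain string key
    #     assert lookup['T'] == 300.0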

    def get_tex(self, units: bool = False, symbol: bool = True) -> str:
        """Return a raw string that is well-formatted for plotting (with latex).

        :param units: whether to include the units in the string
        :param symbol: just the latex symbol if true, otherwise the full description
        :returns: the latex formatted string
        """
        s = (self.tex if symbol else self.description) or self.name
        return r'{} [{}]'.format(s, self.units or '-') if units else r'{}'.format(s)

    def get_nominal(self) -> float | list | None:
        """Return the nominal value of the variable. Defaults to the mean for a normal distribution or the
        center of the domain if `var.nominal` is not specified. Returns a list of nominal values for each latent
        dimension if this is a field quantity with compression.

        :returns: the nominal value(s)
        """
        nominal = self.nominal
        if nominal is None:
            if dist := self.distribution:
                nominal = float(dist.nominal())
            elif domain := self.get_domain():
                nominal = [np.mean(d) for d in domain] if isinstance(domain, list) else float(np.mean(domain))

        return nominal

    def get_domain(self) -> tuple | list | None:
        """Return a tuple of the defined domain of this variable. Returns a list of domains for each latent dimension
        if this is a field quantity with compression.

        :returns: the domain(s) of this variable
        """
        if self.domain is None:
            return None
        elif isinstance(self.domain, list):
            return self.domain
        elif self.compression is not None:
            # Try to infer a list of domains from compression latent size
            try:
                return [self.domain] * self.compression.latent_size()
            except Exception as e:
                raise ValueError(f'Variables with `compression` data should return a list of domains, one '
                                 f'for each latent coefficient. Could not infer domain for "{self.name}".') from e
        else:
            return self.domain

    def sample_domain(self, shape: tuple | int) -> np.ndarray:
        """Return an array of the given `shape` of uniform samples over the domain of this variable. Returns
        samples for each latent dimension if this is a field quantity with compression.

        Will always sample uniformly over the normalized surrogate domain if `norm` is specified, and will return
        samples in the original unnormalized domain.

        !!! Note
            The last dim of the returned samples will be the latent space size for field quantities.

        :param shape: the shape of samples to return
        :returns: the random samples over the domain of the variable
        """
        if isinstance(shape, int):
            shape = (shape, )
        if domain := self.get_domain():
            if isinstance(domain, list):
                lb = np.atleast_1d([d[0] for d in domain])
                ub = np.atleast_1d([d[1] for d in domain])
                return np.random.rand(*shape, 1) * (ub - lb) + lb
            else:
                lb, ub = self.normalize(domain)
                norm_samples = np.random.rand(*shape) * (ub - lb) + lb
                return self.denormalize(norm_samples)
        else:
            raise RuntimeError(f'Variable "{self.name}" does not have a domain specified.')
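
    # Example sketch (editor's note, not part of the original module): uniform domain sampling for a
    # scalar variable. With `norm` set (here a 'log' transform, one of the specifiers shown in the class
    # docstring), samples are drawn uniformly in the normalized space and returned denormalized, so they
    # still land inside the original bounds:
    #
    #     x = Variable(name='x', domain=(1.0, 10.0), norm='log')
    #     s = x.sample_domain(500)     # uniform in log-space, returned in (1, 10)
    #     assert s.shape == (500,) and np.all((s >= 1.0) & (s <= 10.0))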

    def update_domain(self, domain: tuple[float, float] | list[tuple], override: bool = False):
        """Update the domain of this variable by taking the minimum or maximum of the new domain with the current
        domain for the lower and upper bounds, respectively. Will attempt to update the domain of each latent
        dimension if this is a field quantity with compression. If the variable has a `Uniform` (or `LogUniform`)
        distribution, this will update the distribution's bounds too.

        :param domain: the new domain(s) to update with
        :param override: will simply set the domain to the new values rather than update against the current domain;
                         (default `False`)
        """
        def _update_domain(domain, curr_domain):
            lb, ub = domain
            ret = (lb, ub) if override else (min(lb, curr_domain[0]) if curr_domain is not None else lb,
                                             max(ub, curr_domain[1]) if curr_domain is not None else ub)
            return tuple(map(float, ret))

        curr_domain = self.get_domain()
        if isinstance(domain, list):
            if not isinstance(curr_domain, list):
                curr_domain = [curr_domain] * len(domain)
            self.domain = [_update_domain(d, curr_domain[i]) for i, d in enumerate(domain)]
        elif isinstance(curr_domain, list):
            if not isinstance(domain, list):
                domain = [domain] * len(curr_domain)
            self.domain = [_update_domain(d, curr_domain[i]) for i, d in enumerate(domain)]
        else:
            self.domain = _update_domain(domain, curr_domain)
            if (dist := self.distribution) is not None and isinstance(dist, Uniform | LogUniform):
                dist.dist_args = self.domain  # keep the Uniform/LogUniform dist in sync
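
    # Example sketch (editor's note, not part of the original module): growing a scalar domain.
    # The default update takes the union of the current and new bounds; `override=True` replaces them:
    #
    #     p = Variable(name='p', domain=(0.0, 1.0))
    #     p.update_domain((0.5, 2.0))                  # domain becomes (0.0, 2.0)
    #     p.update_domain((0.2, 0.8), override=True)   # domain becomes (0.2, 0.8)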

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """Compute the PDF of the Variable at the given `x` locations. Will just return ones if the variable
        does not have a distribution.

        :param x: locations to compute the PDF at
        :returns: the PDF evaluations at `x`
        """
        if dist := self.distribution:
            return dist.pdf(x)
        else:
            return np.ones(x.shape)  # No pdf if no dist is specified

    def sample(self, shape: tuple | int, nominal: float | np.ndarray = None) -> np.ndarray:
        """Draw samples from this `Variable's` distribution. Just returns an array of the given shape filled with
        the nominal value if this `Variable` has no distribution.

        :param shape: the shape of the returned samples
        :param nominal: a nominal value to use if applicable (i.e. a center for relative, tolerance, or normal)
        :returns: samples from the PDF of this `Variable's` distribution
        """
        if isinstance(shape, int):
            shape = (shape, )
        if nominal is None:
            nominal = self.get_nominal()

        if dist := self.distribution:
            return dist.sample(shape, nominal)
        else:
            # Variables with no distribution
            if nominal is None:
                raise ValueError(f'Cannot sample "{self.name}" with no dist or nominal value specified.')
            elif isinstance(nominal, list | np.ndarray):
                return np.ones(shape + (len(nominal),)) * np.atleast_1d(nominal)  # For field quantities
            else:
                return np.ones(shape) * nominal
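
    # Example sketch (editor's note, not part of the original module): sampling with and without a
    # distribution, using the 'Uniform' string spec from the class docstring:
    #
    #     T = Variable(name='T', distribution='Uniform(280, 320)')
    #     draws = T.sample((1000,))             # 1000 draws from U(280, 320)
    #
    #     c = Variable(name='c', nominal=2.5)   # no distribution
    #     const = c.sample(10)                  # array of ten values, all equal to 2.5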

    def normalize(self, values: ArrayLike, denorm: bool = False) -> ArrayLike | None:
        """Normalize `values` based on this `Variable's` `norm` method(s). See `Transform` for available norm methods.

        !!! Note
            If this Variable's `self.norm` was specified as a list of norm methods, then each will be applied in
            sequence in the original order (and in reverse for `denorm=True`). When `self.distribution` is involved in
            the transforms (only for `minmax` and `zscore`), the `dist_args` will get normalized too at each
            transform before applying the next transform.

        :param values: the values to normalize (array-like)
        :param denorm: whether to denormalize instead using the inverse of the original normalization method
        :returns: the normalized (or unnormalized) values
        """
        if not self.norm or values is None:
            return values
        if dist := self.distribution:
            normal_dist = isinstance(dist, Normal)
        else:
            normal_dist = False

        def _normalize_single(values, transform, inverse, domain, dist_args):
            """Do a single transform. Might need to override transform_args depending on the transform."""
            transform_args = None
            if isinstance(transform, Minmax) and domain:
                transform_args = domain + transform.transform_args[2:]  # Update minmax bounds
            elif isinstance(transform, Zscore) and dist_args:
                transform_args = dist_args  # Update N(mu, std)

            return transform.transform(values, inverse=inverse, transform_args=transform_args)

        domain = self.get_domain() or ()
        dist_args = self.distribution.dist_args if normal_dist else []
        if isinstance(domain, list):
            domain = ()  # For field quantities, domain is not used in normalization

        if denorm:
            # First, send domain and dist_args through the forward norm list (up until the last norm)
            hyperparams = [np.hstack((domain, dist_args))]
            for i, transform in enumerate(self.norm):
                domain, dist_args = tuple(hyperparams[i][:2]), tuple(hyperparams[i][2:])
                hyperparams.append(_normalize_single(hyperparams[i], transform, False, domain, dist_args))

            # Now denormalize in reverse
            hp_idx = -2
            for transform in reversed(self.norm):
                domain, dist_args = tuple(hyperparams[hp_idx][:2]), tuple(hyperparams[hp_idx][2:])
                values = _normalize_single(values, transform, True, domain, dist_args)
                hp_idx -= 1
        else:
            # Normalize values and hyperparams through the forward norm list
            hyperparams = np.hstack((domain, dist_args))
            for transform in self.norm:
                domain, dist_args = tuple(hyperparams[:2]), tuple(hyperparams[2:])
                values = _normalize_single(values, transform, denorm, domain, dist_args)
                hyperparams = _normalize_single(hyperparams, transform, denorm, domain, dist_args)

        return values
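
    # Example sketch (editor's note, not part of the original module): round-tripping values through a
    # single 'log' transform; `denormalize` (defined just below) applies the inverse of `normalize`:
    #
    #     x = Variable(name='x', domain=(2.0, 6.0), norm='log')
    #     z = x.normalize(np.array([2.0, 4.0, 6.0]))   # values in log-space
    #     back = x.denormalize(z)                      # approximately [2., 4., 6.] again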

    def denormalize(self, values):
        """Alias for `normalize(denorm=True)`. See `normalize` for more details."""
        return self.normalize(values, denorm=True)

    def compress(self, values: CompressionData, coords: np.ndarray = None,
                 reconstruct: bool = False) -> CompressionData:
        """Compress or reconstruct field quantity values using this `Variable's` compression info.

        !!! Note "Specifying compression values"
            If only one field quantity is associated with this variable, then
            specify `values` as `dict(coords=..., name=...)` for this Variable's `name`. If `coords` is not specified,
            then this assumes the locations are the same as the reconstruction data (and skips interpolation).

        !!! Info "Compression workflow"
            Generally, compression follows `interpolate -> normalize -> compress` to take raw values into the
            compressed "latent" space. The interpolation step is required to make sure `values` align with the
            coordinates used when building the compression map in the first place (such as through SVD).

        :param values: a `dict` with a key for each field qty of shape `(..., qty.shape)` and a `coords` key of shape
                       `(qty.shape, dim)` that gives the coordinates of each point. Only a single `latent` key should
                       be given instead if `reconstruct=True`.
        :param coords: the coordinates of each point in `values` if `values` did not contain a `coords` key;
                       defaults to the compression grid coordinates
        :param reconstruct: whether to reconstruct values instead of compress
        :returns: the compressed values with key `latent` and shape `(..., latent_size)`; if `reconstruct=True`,
                  then the reconstructed values with shape `(..., qty.shape)` for each `qty` key are returned.
                  The return `dict` also has a `coords` key with shape `(qty.shape, dim)`.
        """
        if not self.compression:
            raise ValueError(f'Compression is not supported for variable "{self.name}". Please specify a compression'
                             f' method for this variable.')
        if not self.compression.map_exists:
            raise ValueError(f'Compression map not computed yet for "{self.name}".')

        # Default field coordinates to the compression coordinates if they are not provided
        field_coords = values.pop('coords', coords)
        if field_coords is None:
            field_coords = self.compression.coords
        ret_dict = {'coords': field_coords}

        # For reconstruction: decompress -> denormalize -> interpolate
        if reconstruct:
            try:
                states = np.atleast_1d(values['latent'])  # (..., rank)
            except KeyError as e:
                raise ValueError('Must pass values["latent"] in for reconstruction.') from e
            states = self.compression.reconstruct(states)  # (..., dof)
            states = self.denormalize(states)  # (..., dof)
            states = self.compression.interpolate_from_grid(states, field_coords)
            ret_dict.update(states)

        # For compression: interpolate -> normalize -> compress
        else:
            states = self.compression.interpolate_to_grid(field_coords, values)
            states = self.normalize(states)  # (..., dof)
            states = self.compression.compress(states)  # (..., rank)
            ret_dict['latent'] = states

        return ret_dict

    def reconstruct(self, values, coords=None):
        """Alias for `compress(reconstruct=True)`. See `compress` for more details."""
        return self.compress(values, coords=coords, reconstruct=True)
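
    # Example sketch (editor's note, not part of the original module): the compress/reconstruct round trip,
    # assuming this variable's compression map has already been constructed (`map_exists` is True) for the
    # fields 'ux', 'uy', 'uz'; the value arrays and grid coordinates below are hypothetical:
    #
    #     field = {'ux': ux_vals, 'uy': uy_vals, 'uz': uz_vals, 'coords': grid_coords}
    #     latent = vel.compress(field)['latent']           # shape (..., latent_size)
    #     rebuilt = vel.reconstruct({'latent': latent})    # 'ux'/'uy'/'uz' back on the grid, plus 'coords'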

    def serialize(self, save_path: str | Path = '.') -> dict:
        """Convert a `Variable` to a `dict` with only standard Python types
        (i.e. convert custom objects like `dist` and `norm` to strings and save `compression` to a `.pkl`).

        :param save_path: the path to save the compression data to (defaults to current directory)
        :returns: the serialized `dict` of the `Variable` object
        """
        d = {}
        for key, value in self.__dict__.items():
            if value is not None and not key.startswith('_'):
                if key == 'domain':
                    d[key] = [str(v) for v in value] if isinstance(value, list) else str(value)
                elif key == 'distribution':
                    d[key] = str(value)
                elif key == 'norm':
                    d[key] = [str(transform) for transform in value]
                elif key == 'compression':
                    fname = f'{self.name}_compression.pkl'
                    d[key] = value.serialize(save_path=Path(save_path) / fname)
                else:
                    d[key] = value
        return d

    @classmethod
    def deserialize(cls, data: dict, search_paths: list[str | Path] = None) -> Variable:
        """Convert a `dict` to a `Variable` object. Let `pydantic` handle validation and conversion of fields.

        :param data: the `dict` to convert to a `Variable`
        :param search_paths: the paths to search for compression files (if necessary)
        :returns: the `Variable` object
        """
        if isinstance(data, Variable):
            return data
        elif isinstance(data, str):
            return cls(name=data)
        else:
            if (compression := data.get('compression', None)) is not None:
                if isinstance(compression, str):
                    data['compression'] = search_for_file(compression, search_paths=search_paths)
            return cls(**data)
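
    # Example sketch (editor's note, not part of the original module): a serialize/deserialize round trip
    # with plain Python types (the exact string forms of `distribution`/`domain` come from their `str()`):
    #
    #     alpha = Variable(name='alpha', units='deg', distribution='Uniform(-5, 5)')
    #     d = alpha.serialize()              # e.g. {'name': 'alpha', 'units': 'deg', 'distribution': '...', ...}
    #     same = Variable.deserialize(d)     # pydantic re-validates the string fields
    #     assert same == alpha               # equality is by name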

    @staticmethod
    def _yaml_representer(dumper: yaml.Dumper, data: Variable) -> yaml.MappingNode:
        """Convert a single `Variable` object (`data`) to a yaml MappingNode (i.e. a `dict`)."""
        save_path, save_file = _get_yaml_path(dumper)
        return dumper.represent_mapping(Variable.yaml_tag, data.serialize(save_path=save_path))

    @staticmethod
    def _yaml_constructor(loader: yaml.Loader, node):
        """Convert the `!Variable` tag in yaml to a single `Variable` object (or a list of `Variables`)."""
        save_path, save_file = _get_yaml_path(loader)
        if isinstance(node, yaml.SequenceNode):
            return [ele if isinstance(ele, Variable) else Variable.deserialize(ele, search_paths=[save_path])
                    for ele in loader.construct_sequence(node, deep=True)]
        elif isinstance(node, yaml.MappingNode):
            return Variable.deserialize(loader.construct_mapping(node, deep=True), search_paths=[save_path])
        else:
            raise NotImplementedError(f'The "{Variable.yaml_tag}" yaml tag can only be used on a yaml sequence or '
                                      f'mapping, not a "{type(node)}".')


class VariableList(OrderedDict, Serializable):
    """Store `Variables` as `str(var) : Variable` in the order they were passed in. You can:

    - Initialize/update from a single `Variable` or a list of `Variables`
    - Get/set a `Variable` directly or by name via `my_vars[var]` or `my_vars[str(var)]` etc.
    - Retrieve the original order of insertion by `list(my_vars.items())`
    - Access/delete elements by order of insertion using integer/slice indexing (i.e. `my_vars[1:3]`)
    - Save/load from yaml file using the `!VariableList` tag
    """
    yaml_tag = '!VariableList'

    def __init__(self, data: list[Variable] | Variable | OrderedDict | dict = None, **kwargs):
        """Initialize a collection of `Variable` objects."""
        super().__init__()
        self.update(data, **kwargs)

    def __iter__(self):
        yield from self.values()

    def __eq__(self, other):
        if isinstance(other, VariableList):
            for v1, v2 in zip(self.values(), other.values()):
                if v1 != v2:
                    return False
            return True
        else:
            return False

    def append(self, data: Variable):
        self.update(data)

    def extend(self, data: list[Variable]):
        self.update(data)

    def index(self, key):
        for i, k in enumerate(self.keys()):
            if k == key:
                return i
        raise ValueError(f"'{key}' is not in list")

    def get_domains(self, norm: bool = True):
        """Get normalized variable domains (expand latent coefficient domains for field quantities). Assume a
        domain of `(0, 1)` for variables if their domain is not specified.

        :param norm: whether to normalize the domains using `Variable.norm` (useful for getting bounds for a
                     surrogate); latent coefficient domains do not get normalized
        :returns: a `dict` of variables to their normalized domains; field quantities return a domain for each
                  of their latent coefficients
        """
        domains = {}
        for var in self:
            var_domain = var.get_domain()
            if isinstance(var_domain, list):  # only field qtys return a list of domains, one for each latent coeff
                for i, domain in enumerate(var_domain):
                    domains[f'{var.name}{LATENT_STR_ID}{i}'] = domain
            elif var_domain is None:
                domains[var.name] = (0, 1)
            else:
                domains[var.name] = var.normalize(var_domain) if norm else var_domain
        return domains
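
    # Example sketch (editor's note, not part of the original module): collecting bounds for a mixed set of
    # scalar variables; one with an explicit domain and one that falls back to the (0, 1) default:
    #
    #     inputs = VariableList([Variable(name='x', domain=(0, 10)), Variable(name='y')])
    #     inputs.get_domains(norm=False)     # {'x': (0, 10), 'y': (0, 1)}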

    def get_pdfs(self, norm: bool = True):
        """Get callable pdfs for all variables (skipping field quantities for now).

        :param norm: whether values passed to the pdf functions are normalized and should be denormed first
                     before pdf evaluation (useful for surrogate construction where samples are gathered in the
                     normalized space)
        :returns: a `dict` of variables to callable pdf functions; field quantities are skipped.
        """
        def _get_pdf(var, norm):
            return lambda z: var.pdf(var.denormalize(z) if norm else z)

        pdf_fcns = {}
        for var in self:
            var_domain = var.get_domain()
            if isinstance(var_domain, list):  # only field qtys return a list of domains, one for each latent coeff
                # for i, domain in enumerate(var_domain):
                #     pdf_fcns[f'{var.name}{LATENT_STR_ID}{i}'] = var.latent_pdfs[i]  TODO: Implement latent pdfs
                pass
            else:
                pdf_fcns[var.name] = _get_pdf(var, norm)
        return pdf_fcns

    def update(self, data: list[Variable | str] | str | Variable | OrderedDict | dict = None, **kwargs):
        """Update from a list or dict of `Variable` objects, or from `key=value` pairs."""
        if data:
            if isinstance(data, OrderedDict | dict):
                for key, value in data.items():
                    self.__setitem__(key, value)
            else:
                data = [data] if not isinstance(data, list | tuple) else data
                for variable in data:
                    self.__setitem__(str(variable), variable)
        if kwargs:
            for key, value in kwargs.items():
                self.__setitem__(key, value)

    def get(self, key, default=None):
        """Make sure this passes through `__getitem__()`."""
        try:
            return self.__getitem__(key)
        except Exception:
            return default

    def __setitem__(self, key, value):
        """Only allow `str(var): Variable` items. Or normal list indexing via `my_vars[0] = var`."""
        if isinstance(key, int):
            k = list(self.keys())[key]
            self.__setitem__(k, value)
            return
        if isinstance(value, str):
            value = Variable(name=value)
        if not isinstance(key, str | Variable):
            raise TypeError(f'VariableList key "{key}" is not a Variable or string.')
        if not isinstance(value, Variable):
            raise TypeError(f'VariableList value "{value}" is not a Variable.')
        super().__setitem__(str(key), value)

    def __getitem__(self, key):
        """Allow accessing variable(s) directly via `my_vars[var]` or by index/slicing."""
        if isinstance(key, list | tuple):
            return [self.__getitem__(ele) for ele in key]
        elif isinstance(key, int | slice):
            return list(self.values())[key]
        elif isinstance(key, str | Variable):
            return super().__getitem__(str(key))
        else:
            raise TypeError(f'VariableList key "{key}" is not valid.')

    def __delitem__(self, key):
        """Allow deleting variable(s) directly or by index/slicing."""
        if isinstance(key, list | tuple):
            for ele in key:
                self.__delitem__(ele)
        elif isinstance(key, int | slice):
            ele = list(self.keys())[key]
            if isinstance(ele, list):
                for item in ele:
                    super().__delitem__(item)
            else:
                super().__delitem__(ele)
        elif isinstance(key, str | Variable):
            super().__delitem__(str(key))
        else:
            raise TypeError(f'VariableList key "{key}" is not valid.')

    def __str__(self):
        return str(list(self.values()))

    def __repr__(self):
        return self.__str__()

    def serialize(self, save_path='.') -> list[dict]:
        """Convert to a list of `dict` objects for each `Variable` in the list.

        :param save_path: the path to save the compression data to (defaults to current directory)
        """
        return [var.serialize(save_path=save_path) for var in self.values()]

    @classmethod
    def merge(cls, *variable_lists) -> VariableList:
        """Merge multiple sets of variables into a single `VariableList` object.

        !!! Note
            Variables with the same name will be merged by keeping the one with the most information provided.

        :param variable_lists: the variables/lists to merge
        :returns: the merged `VariableList` object
        """
        merged_vars = cls()

        def _get_best_variable(var1, var2):
            var1_dict = {key: value for key, value in var1.__dict__.items() if value is not None}
            var2_dict = {key: value for key, value in var2.__dict__.items() if value is not None}
            return var1 if len(var1_dict) >= len(var2_dict) else var2

        for var_list in variable_lists:
            for var in cls(var_list):
                if var.name in merged_vars:
                    merged_vars[var.name] = _get_best_variable(merged_vars[var.name], var)
                else:
                    merged_vars[var.name] = var

        return merged_vars
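
    # Example sketch (editor's note, not part of the original module): merging keeps the most fully
    # specified duplicate (matched by name) and appends everything else:
    #
    #     a = VariableList([Variable(name='T', units='K', distribution='Uniform(280, 320)')])
    #     b = [Variable(name='T'), Variable(name='p')]
    #     merged = VariableList.merge(a, b)   # 'T' keeps the units/distribution from `a`; 'p' is added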

    @classmethod
    def deserialize(cls, data: dict | list[dict], search_paths=None) -> VariableList:
        """Convert a `dict` or list of `dict` objects to a `VariableList` object. Let `pydantic` handle validation."""
        if not isinstance(data, list):
            data = [data]
        return cls([Variable.deserialize(d, search_paths=search_paths) for d in data])

    @staticmethod
    def _yaml_representer(dumper: yaml.Dumper, data: VariableList) -> yaml.SequenceNode:
        """Convert a single `VariableList` object (`data`) to a yaml SequenceNode (i.e. a list)."""
        save_path, save_file = _get_yaml_path(dumper)
        return dumper.represent_sequence(VariableList.yaml_tag, data.serialize(save_path=save_path))

    @staticmethod
    def _yaml_constructor(loader: yaml.Loader, node):
        """Convert the `!VariableList` tag in yaml to a `VariableList` object."""
        save_path, save_file = _get_yaml_path(loader)
        if isinstance(node, yaml.SequenceNode):
            return VariableList.deserialize(loader.construct_sequence(node, deep=True), search_paths=[save_path])
        elif isinstance(node, yaml.MappingNode):
            return VariableList.deserialize(loader.construct_mapping(node), search_paths=[save_path])
        else:
            raise NotImplementedError(f'The "{VariableList.yaml_tag}" yaml tag can only be used on a yaml sequence or '
                                      f'mapping, not a "{type(node)}".')
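

# --- Editor's usage sketch (not part of the original module) ---
# A minimal, hedged end-to-end illustration of the classes defined above; the variable names and
# values are made up for demonstration, and this block only runs when the file is executed directly.
if __name__ == '__main__':
    temp = Variable(name='T', description='Temperature', units='K', distribution='Uniform(280, 320)')
    angle = Variable(name='alpha', description='Angle of attack', units='deg', domain=(-5.0, 5.0))

    inputs = VariableList([temp, angle])
    samples = {str(v): v.sample(1000) for v in inputs}      # 1000 draws (or nominal fills) per variable
    pdfs = {name: fn(samples[name]) for name, fn in inputs.get_pdfs(norm=False).items()}
    print(inputs)
    print(inputs.get_domains(norm=False))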