Coverage for src/amisc/typing.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-01-24 04:51 +0000

1"""Module with type hints for the AMISC package. 

2 

3Includes: 

4 

5- `MultiIndex` — tuples of integers or similar string representations 

6- `Dataset` — a type hint for the input/output `dicts` handled by `Component.model` 

7- `TrainIteration` — the results of a single training iteration 

8- `CompressionData` — a dictionary spec for passing data to/from `Variable.compress()` 

9- `LATENT_STR_ID` — a string identifier for latent coefficients of field quantities 

10- `COORDS_STR_ID` — a string identifier for coordinate locations of field quantities 

11""" 

12import ast as _ast 

13from pathlib import Path as _Path 

14from typing import Optional as _Optional 

15 

16import numpy as _np 

17from numpy.typing import ArrayLike as _ArrayLike 

18from typing_extensions import TypedDict as _TypedDict 

19 

# Public names re-exported by `from <this module> import *`.
__all__ = [
    "MultiIndex",
    "Dataset",
    "TrainIteration",
    "CompressionData",
    "LATENT_STR_ID",
    "COORDS_STR_ID",
]

21 

# Identifier string marking the latent coefficients of field quantities.
LATENT_STR_ID = "_LATENT"

# Identifier string marking the coordinate locations of field quantities.
COORDS_STR_ID = "_coords"

24 

25 

class MultiIndex(tuple):
    """A multi-index is a tuple of integers, can be converted from a string.

    Accepted inputs:

    - any iterable of integer-convertible values, e.g. `[1, 2]` or `(1, 2)`;
    - a single integer, e.g. `3`, which becomes the 1-tuple `(3,)`;
    - a string holding a Python literal of either of the above, e.g. `"(1, 2)"`, `"[1, 2]"`, `"3"` or `"(3)"`.

    Since a `MultiIndex` prints like a plain tuple, `MultiIndex(str(idx)) == idx` round-trips.
    Every element is passed through `int()`, so non-integer values such as `2.7` are truncated to `2`.
    """

    def __new__(cls, __tuple=()):
        # Strings are parsed with `ast.literal_eval`, which only evaluates literals (no arbitrary code execution).
        value = _ast.literal_eval(__tuple) if isinstance(__tuple, str) else __tuple

        # A bare integer (e.g. from `3`, `"3"` or `"(3)"`) is treated as a 1-tuple instead of
        # failing inside `map` with "'int' object is not iterable".
        if isinstance(value, (int, _np.integer)):
            value = (value,)

        return super().__new__(cls, map(int, value))

33 

34 

class Dataset(_TypedDict, total=False):
    """Type hint for the input/output `dicts` of a call to `Component.model`. The keys are the variable names and the
    values are the corresponding `np.ndarrays`. There are also a few special keys that can be returned by the model
    that are described below.

    The model can return additional items that are not part of `Component.outputs`. These items are returned as object
    arrays in the output.

    Because this `TypedDict` is declared with `total=False`, every key (including the special keys below) is optional.

    This data structure is very similar to the `Dataset` class in the `xarray` package. Later versions might
    consider migrating to `xarray` for more advanced data manipulation.

    :ivar model_cost: the computational cost (seconds of CPU time) of a single model evaluation
    :ivar output_path: the path to the output file or directory written by the model
    :ivar errors: a `dict` with the indices where the model evaluation failed with context about the errors
    """
    # scalar, list, or array-like; presumably one cost per evaluation when a list/array — TODO confirm
    model_cost: float | list | _ArrayLike
    # accepted either as a plain string or a `pathlib.Path`
    output_path: str | _Path
    # indices of failed evaluations together with context about each error
    errors: dict

53 

54 

class TrainIteration(_TypedDict):
    """Gives the results of a single training iteration.

    Every key is required (this `TypedDict` uses the default `total=True`); only the *value* of `test_error` is
    allowed to be `None`.

    :ivar component: the name of the component selected for refinement at this iteration
    :ivar alpha: the selected candidate model fidelity multi-index
    :ivar beta: the selected candidate surrogate fidelity multi-index
    :ivar num_evals: the number of model evaluations performed during this iteration
    :ivar added_cost: the total added computational cost of the new model evaluations (CPU time in seconds)
    :ivar added_error: the error/difference between the refined surrogate and the previous surrogate
    :ivar test_error: the error of the refined surrogate on the test set as a `dict` mapping `str` keys to `float`
                      errors, or `None` (presumably when no test set is available — TODO confirm)
    :ivar overhead_s: the algorithmic overhead wall time in seconds for the training iteration
    :ivar model_s: the total wall time in seconds for the model evaluations for the training iteration
    """
    component: str
    alpha: MultiIndex
    beta: MultiIndex
    num_evals: int
    added_cost: float
    added_error: float
    # the key is always present; `Optional` only allows its value to be None
    test_error: _Optional[dict[str, float]]
    overhead_s: float
    model_s: float

77 

78 

class CompressionData(_TypedDict, total=False):
    """Configuration `dict` for passing compression data to/from `Variable.compress()`.

    Because this `TypedDict` is declared with `total=False`, every key is optional.

    !!! Info "Field quantity shapes"
        Field quantity data can take on any arbitrary shape, which we indicate with `qty.shape`. For example, a 3d
        structured grid might have `qty.shape = (10, 15, 10)`. Unstructured data might just have `qty.shape = (N,)`
        for $N$ points in an unstructured grid. Regardless, `Variable.compress()` will flatten this and compress
        to a single latent vector of size `latent_size`. That is, `qty.shape` → `latent_size`.

    !!! Note "Compression coordinates"
        Field quantity data must be specified along with its coordinate locations. If the coordinate locations are
        different from what was used when building the compression map (i.e. the SVD data matrix), then they will be
        interpolated to/from the SVD coordinates.

    :ivar coords: `(qty.shape, dim)` the coordinate locations of the qty data; coordinates exist in `dim` space (e.g.
                  `dim=2` for 2d Cartesian coordinates). Defaults to the coordinates used when building the
                  compression map (i.e. the coordinates of the data in the SVD data matrix)
    :ivar latent: `(..., latent_size)` array of latent space coefficients for a field quantity; this is what is
                  _returned_ by `Variable.compress()` and what is _expected_ as input by `Variable.reconstruct()`.
    :ivar qty: `(..., qty.shape)` array of uncompressed field quantity data for this qty within
               the `fields` list. Each qty in this list will be its own `key:value` pair in the
               `CompressionData` structure
    """
    coords: _np.ndarray
    latent: _np.ndarray
    # placeholder key: per the docstring, each field quantity appears under its own key with this value type
    qty: _np.ndarray