Coverage for src/amisc/utils.py: 94%
284 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-24 04:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-24 04:51 +0000
1"""Provides some basic utilities for the package.
3Includes:
5- `to_model_dataset` — convert surrogate input/output dataset to a form usable by the true model
6- `to_surrogate_dataset` — convert true model input/output dataset to a form usable by the surrogate
7- `constrained_lls` — solve a constrained linear least squares problem
8- `search_for_file` — search for a file in the current working directory and additional search paths
9- `format_inputs` — broadcast and reshape all inputs to the same shape
10- `format_outputs` — reshape all outputs to a common loop shape
11- `parse_function_string` — convert function-like strings to arguments and keyword-arguments
12- `relative_error` — compute the relative L2 error between two vectors
13- `get_logger` — logging utility with nice formatting
14"""
15from __future__ import annotations
17import ast
18import copy
19import inspect
20import logging
21import re
22import sys
23from pathlib import Path
24from typing import TYPE_CHECKING
26import numpy as np
27import yaml
29__all__ = ['parse_function_string', 'relative_error', 'get_logger', 'format_inputs', 'format_outputs',
30 'search_for_file', 'constrained_lls', 'to_surrogate_dataset', 'to_model_dataset']
32from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset
34if TYPE_CHECKING:
35 import amisc.variable
37LOG_FORMATTER = logging.Formatter(u"%(asctime)s — [%(levelname)s] — %(name)-15s — %(message)s")
40def _combine_latent_arrays(arr):
41 """Helper function to concatenate latent arrays into a single variable in the `arr` Dataset."""
42 for var in list(arr.keys()):
43 if LATENT_STR_ID in var: # extract latent variables from surrogate data
44 base_id = str(var).split(LATENT_STR_ID)[0]
45 arr[base_id] = arr[var][..., np.newaxis] if arr.get(base_id) is None else (
46 np.concatenate((arr[base_id], arr[var][..., np.newaxis]), axis=-1))
47 del arr[var]
50def to_surrogate_dataset(dataset: Dataset, variables: 'amisc.variable.VariableList', del_fields: bool = True,
51 **field_coords) -> tuple[Dataset, list[str]]:
52 """Convert true model input/output dataset to a form usable by the surrogate. Primarily, compress field
53 quantities and normalize.
55 :param dataset: the dataset to convert
56 :param variables: the `VariableList` containing the variable objects used in `dataset` -- these objects define
57 the normalization and compression methods to use for each variable
58 :param del_fields: whether to delete the original field quantities from the dataset after compression
59 :param field_coords: pass in extra field qty coords as f'{var}_coords' for compression (optional)
60 :returns: the compressed/normalized dataset and a list of variable names to pass to surrogate
61 """
62 surr_vars = []
63 dataset = copy.deepcopy(dataset)
64 for var in variables:
65 # Only grab scalars in the dataset or field qtys if all fields are present
66 if var in dataset or (var.compression is not None and all([f in dataset for f in var.compression.fields])):
67 if var.compression is not None:
68 coords = dataset.get(f'{var}{COORDS_STR_ID}', field_coords.get(f'{var}{COORDS_STR_ID}', None))
69 latent = var.compress({field: dataset[field] for field in
70 var.compression.fields}, coords=coords)['latent'] # all fields must be present
71 for i in range(latent.shape[-1]):
72 dataset[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i]
73 surr_vars.append(f'{var.name}{LATENT_STR_ID}{i}')
74 if del_fields:
75 for field in var.compression.fields:
76 del dataset[field]
77 if dataset.get(f'{var}{COORDS_STR_ID}', None) is not None:
78 del dataset[f'{var}{COORDS_STR_ID}']
79 else:
80 dataset[var.name] = var.normalize(dataset[var.name])
81 surr_vars.append(f'{var.name}')
83 return dataset, surr_vars
86def to_model_dataset(dataset: Dataset, variables: 'amisc.variable.VariableList', del_latent: bool = True,
87 **field_coords) -> tuple[Dataset, Dataset]:
88 """Convert surrogate input/output dataset to a form usable by the true model. Primarily, reconstruct
89 field quantities and denormalize.
91 :param dataset: the dataset to convert
92 :param variables: the `VariableList` containing the variable objects used in `dataset` -- these objects define
93 the normalization and compression methods to use for each variable
94 :param del_latent: whether to delete the latent variables from the dataset after reconstruction
95 :param field_coords: pass in extra field qty coords as f'{var}_coords' for reconstruction (optional)
96 :returns: the reconstructed/denormalized dataset and any field coordinates used during reconstruction
97 """
98 dataset = copy.deepcopy(dataset)
99 _combine_latent_arrays(dataset)
101 ret_coords = {}
102 for var in variables:
103 if var in dataset:
104 if var.compression is not None:
105 # coords = self.model_kwargs.get(f'{var.name}_coords', None)
106 coords = field_coords.get(f'{var}{COORDS_STR_ID}', None)
107 field = var.reconstruct({'latent': dataset[var]}, coords=coords)
108 if del_latent:
109 del dataset[var]
110 coords = field.pop('coords')
111 ret_coords[f'{var.name}{COORDS_STR_ID}'] = copy.deepcopy(coords)
112 dataset.update(field)
113 else:
114 dataset[var] = var.denormalize(dataset[var])
116 return dataset, ret_coords
119def constrained_lls(A: np.ndarray, b: np.ndarray, C: np.ndarray, d: np.ndarray) -> np.ndarray:
120 """Minimize $||Ax-b||_2$, subject to $Cx=d$, i.e. constrained linear least squares.
122 !!! Note
123 See [these lecture notes](http://www.seas.ucla.edu/~vandenbe/133A/lectures/cls.pdf) for more detail.
125 :param A: `(..., M, N)`, vandermonde matrix
126 :param b: `(..., M, 1)`, data
127 :param C: `(..., P, N)`, constraint operator
128 :param d: `(..., P, 1)`, constraint condition
129 :returns: `(..., N, 1)`, the solution parameter vector `x`
130 """
131 M = A.shape[-2]
132 dims = len(A.shape[:-2])
133 T_axes = tuple(np.arange(0, dims)) + (-1, -2)
134 Q, R = np.linalg.qr(np.concatenate((A, C), axis=-2))
135 Q1 = Q[..., :M, :]
136 Q2 = Q[..., M:, :]
137 Q1_T = np.transpose(Q1, axes=T_axes)
138 Q2_T = np.transpose(Q2, axes=T_axes)
139 Qtilde, Rtilde = np.linalg.qr(Q2_T)
140 Qtilde_T = np.transpose(Qtilde, axes=T_axes)
141 Rtilde_T_inv = np.linalg.pinv(np.transpose(Rtilde, axes=T_axes))
142 w = np.linalg.pinv(Rtilde) @ (Qtilde_T @ Q1_T @ b - Rtilde_T_inv @ d)
144 return np.linalg.pinv(R) @ (Q1_T @ b - Q2_T @ w)
147class _RidgeRegression:
148 """A simple class for ridge regression with closed-form solution."""
150 def __init__(self, alpha=1.0):
151 """Initialize the ridge regression model with the given regularization strength $\alpha$."""
152 self.alpha = alpha
153 self.weights = None
155 def fit(self, X: np.ndarray, y: np.ndarray):
156 """Fit the ridge regression model to the given data. Compute linear weights (with intercept)
157 of shape `(n_features + 1, n_targets)`.
159 $w = (X^T X + \alpha I)^{-1} X^T y$
161 :param X: the design matrix of shape `(n_samples, n_features)`
162 :param y: the target values of shape `(n_samples, n_targets)`
163 """
164 n_samples, n_features = X.shape
166 # Add bias term (column of ones) to the design matrix for intercept
167 X_bias = np.hstack([np.ones((n_samples, 1)), X])
169 # Regularization matrix (identity matrix with top-left value zero for intercept term)
170 identity = np.eye(n_features + 1)
171 identity[0, 0] = 0
173 # Closed-form solution (normal equation) for ridge regression
174 A = X_bias.T @ X_bias + self.alpha * identity
175 B = X_bias.T @ y
176 self.weights = np.linalg.solve(A, B)
178 def predict(self, X: np.ndarray):
179 """Compute the predicted target values for the given input data.
181 :param X: the input data of shape `(n_samples, n_features)`
182 :returns: the predicted target values of shape `(n_samples, n_targets)`
183 """
184 if self.weights is None:
185 raise ValueError("Model is not fitted yet. Call 'fit' with appropriate arguments before using this method.")
187 n_samples, n_features = X.shape
189 # Add bias term (column of ones) to the design matrix for intercept
190 X_bias = np.hstack([np.ones((n_samples, 1)), X])
192 return X_bias @ self.weights
195def _inspect_function(func):
196 """Try to inspect the inputs and outputs of a callable function.
198 !!! Example
199 ```python
200 def my_func(a, b, c, **kwargs):
201 # Do something
202 return y1, y2
204 _inspect_function(my_func)
205 # Returns (['a', 'b', 'c'], ['y1', 'y2'])
206 ```
208 :param func: The callable function to inspect.
209 :returns: A tuple of the positional arguments and return values of the function.
210 """
211 try:
212 sig = inspect.signature(func)
213 pos_args = [param.name for param in sig.parameters.values() if param.default == param.empty
214 and param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY)]
216 try:
217 source = inspect.getsource(func).strip()
218 except OSError:
219 from dill.source import getsource # inspect in IDLE using dill
220 source = getsource(func).strip()
222 tree = ast.parse(source)
224 # Find the return values
225 class ReturnVisitor(ast.NodeVisitor):
226 def __init__(self):
227 self.return_values = []
229 def visit_Return(self, node):
230 if isinstance(node.value, ast.Tuple):
231 self.return_values = [elt.id for elt in node.value.elts]
232 elif isinstance(node.value, ast.Name):
233 self.return_values = [node.value.id]
234 else:
235 self.return_values = []
237 return_visitor = ReturnVisitor()
238 return_visitor.visit(tree)
240 return pos_args, return_visitor.return_values
241 except Exception:
242 return [], []
245def _inspect_assignment(class_name: str, stack_idx: int = 2) -> str | None:
246 """Return the left-hand side of an assignment like "x = class_name(...)".
248 !!! Example
249 ```python
250 class MyClass:
251 def __init__(self):
252 self.name = _inspect_assignment('MyClass')
253 obj = MyClass()
254 print(obj.name)
255 # Output: 'obj'
256 ```
258 This function will do it's best to only return single assignments (i.e. `x = MyClass()`) and not more
259 complex expressions like list comprehension or tuple unpacking. If the assignment is not found or an error occurs
260 during inspection, it will return `None`.
262 :param class_name: the name of the class that is being constructed
263 :param stack_idx: the index of the stack frame to inspect (default is 2 since you likely call this from
264 inside the class constructor, so you need to go back one more frame from that to find
265 the original assignment caller).
266 :returns: the variable name assigned to the class constructor (or `None`)
267 """
268 variable_name = None
269 try:
270 stack = inspect.stack()
271 frame_info = stack[stack_idx]
272 code_line = frame_info.code_context[frame_info.index].strip()
273 parsed_code = ast.parse(code_line)
274 if isinstance(parsed_code.body[0], ast.Assign):
275 assignment = parsed_code.body[0]
276 if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name):
277 target_name = assignment.targets[0].id
278 if isinstance(assignment.value, ast.Call) and isinstance(assignment.value.func, ast.Name):
279 if assignment.value.func.id == class_name:
280 variable_name = target_name
281 except Exception:
282 variable_name = None
283 finally:
284 # del current_frame, caller_frame
285 return variable_name
288def _get_yaml_path(yaml_obj: yaml.Loader | yaml.Dumper):
289 """Get the path to the YAML file being loaded or dumped."""
290 try:
291 save_path = Path(yaml_obj.stream.name).parent
292 save_file = Path(yaml_obj.stream.name).with_suffix('')
293 except Exception:
294 save_path = Path('.')
295 save_file = 'yaml'
296 return save_path, save_file
299def search_for_file(filename: str | Path, search_paths=None):
300 """Search for the given filename in the current working directory and any additional search paths provided.
302 :param filename: the filename to search for
303 :param search_paths: paths to try and find the file in
304 :returns: the full path to the file if found, otherwise the original `filename`
305 """
306 if not isinstance(filename, str | Path):
307 return filename
309 search_paths = search_paths or []
310 search_paths.append('.')
312 save_file = Path(filename)
313 need_to_search = True
314 try:
315 need_to_search = ((len(save_file.parts) == 1 and len(save_file.suffix) > 0) or
316 (len(save_file.parts) > 1 and not save_file.exists()))
317 except Exception:
318 need_to_search = False
320 # Search for the save file if it was a valid path and does not exist
321 if need_to_search:
322 found_file = False
323 name = save_file.name
324 for path in search_paths:
325 if (pth := Path(path) / name).exists():
326 filename = pth.resolve().as_posix()
327 found_file = True
328 break
329 if not found_file:
330 pass # Let the caller handle the error (just return the original filename back to caller)
331 # raise FileNotFoundError(f"Could not find save file '{filename}' in paths: {search_paths}.")
333 return filename
336def format_inputs(inputs: Dataset, var_shape: dict = None) -> tuple[Dataset, tuple[int, ...]]:
337 """Broadcast and reshape all inputs to the same shape. Loop shape is inferred from broadcasting the leading dims
338 of all input arrays. Input arrays are broadcast to this shape and then flattened.
340 !!! Example
341 ```python
342 inputs = {'x': np.random.rand(10, 1, 5), 'y': np.random.rand(1, 1), 'z': np.random.rand(1, 20, 3)}
343 fmt_inputs, loop_shape = format_inputs(inputs)
344 # Output: {'x': np.ndarray(200, 5), 'y': np.ndarray(200,), 'z': np.ndarray(200, 3)}, (10, 20)
345 ```
347 :param inputs: `dict` of input arrays
348 :param var_shape: `dict` of expected input variable shapes (i.e. for field quantities); assumes all inputs are 1d
349 if None or not specified (i.e. scalar)
350 :returns: the reshaped inputs and the common loop shape
351 """
352 var_shape = var_shape or {}
354 def _common_shape(shape1, shape2):
355 """Find the common leading dimensions between two shapes (with np broadcasting rules)."""
356 min_len = min(len(shape1), len(shape2))
357 common_shape = []
358 for i in range(min_len):
359 if shape1[i] == shape2[i]:
360 common_shape.append(shape1[i])
361 elif shape1[i] == 1:
362 common_shape.append(shape2[i])
363 elif shape2[i] == 1:
364 common_shape.append(shape1[i])
365 else:
366 break
367 return tuple(common_shape)
369 def _shorten_shape(name, array):
370 """Remove extra variable dimensions from the end of the array shape (i.e. field quantity dimensions)."""
371 shape = var_shape.get(name, None)
372 if shape is not None and len(shape) > 0:
373 if len(shape) > len(array.shape):
374 raise ValueError(f"Variable '{name}' shape {shape} is longer than input array shape {array.shape}. "
375 f"The input array for '{name}' should have at least {len(shape)} dimensions.")
376 return array.shape[:-len(shape)]
377 else:
378 return array.shape
380 # Get the common "loop" dimensions from all input arrays
381 inputs = {name: np.atleast_1d(value) for name, value in inputs.items()}
382 name, array = next(iter(inputs.items()))
383 loop_shape = _shorten_shape(name, array)
384 for name, array in inputs.items():
385 array_shape = _shorten_shape(name, array)
386 loop_shape = _common_shape(loop_shape, array_shape)
387 if not loop_shape:
388 break
389 N = np.prod(loop_shape)
390 common_dim_cnt = len(loop_shape)
392 # Flatten and broadcast all inputs to the common shape
393 ret_inputs = {}
394 for var_id, array in inputs.items():
395 if common_dim_cnt > 0:
396 broadcast_shape = np.broadcast_shapes(loop_shape, array.shape[:common_dim_cnt])
397 broadcast_shape += array.shape[common_dim_cnt:]
398 ret_inputs[var_id] = np.broadcast_to(array, broadcast_shape).reshape((N, *array.shape[common_dim_cnt:]))
399 else:
400 ret_inputs[var_id] = array
402 return ret_inputs, loop_shape
405def format_outputs(outputs: Dataset, loop_shape: tuple[int, ...]) -> Dataset:
406 """Reshape all outputs to the common loop shape. Loop shape is as obtained from a call to `format_inputs`.
407 Assumes that all outputs are the same along the first dimension. This first dimension gets reshaped back into
408 the `loop_shape`. Singleton outputs are squeezed along the last dimension. A singleton loop shape is squeezed
409 along the first dimension.
411 !!! Example
412 ```python
413 outputs = {'x': np.random.rand(10, 1, 5), 'y': np.random.rand(10, 1), 'z': np.random.rand(10, 20, 3)}
414 loop_shape = (2, 5)
415 fmt_outputs = format_outputs(outputs, loop_shape)
416 # Output: {'x': np.ndarray(2, 5, 1, 5), 'y': np.ndarray(2, 5), 'z': np.ndarray(200, 3)}, (2, 5, 20, 3)
417 ```
419 :param outputs: `dict` of output arrays
420 :param loop_shape: the common leading dimensions to reshape the output arrays to
421 :returns: the reshaped outputs
422 """
423 output_dict = {}
424 for key, val in outputs.items():
425 val = np.atleast_1d(val)
426 output_shape = val.shape[1:] # Assumes (N, ...) output shape to start with
427 val = val.reshape(loop_shape + output_shape)
428 if output_shape == (1,):
429 val = np.atleast_1d(np.squeeze(val, axis=-1)) # Squeeze singleton outputs
430 if loop_shape == (1,):
431 val = np.atleast_1d(np.squeeze(val, axis=0)) # Squeeze singleton loop dimensions
432 output_dict[key] = val
433 return output_dict
436def _tokenize(args_str: str) -> list[str]:
437 """
438 Helper function to extract tokens from a string of arguments while respecting nested structures.
440 This function processes a string of arguments and splits it into individual tokens, ensuring that nested
441 structures such as parentheses, brackets, and quotes are correctly handled.
443 :param args_str: The string of arguments to tokenize
444 :return: A list of tokens extracted from the input string
446 !!! Example
447 ```python
448 args_str = "func(1, 2), {'key': 'value'}, [1, 2, 3]"
449 _tokenize(args_str)
450 # Output: ['func(1, 2)', "{'key': 'value'}", '[1, 2, 3]']
451 ```
452 """
453 if args_str is None or len(args_str) == 0:
454 return []
455 tokens = []
456 current_token = []
457 brace_depth = 0
458 in_string = False
460 i = 0
461 while i < len(args_str):
462 char = args_str[i]
463 if char in ('"', "'") and (i == 0 or args_str[i - 1] != '\\'): # Toggle string state
464 in_string = not in_string
465 current_token.append(char)
466 elif in_string:
467 current_token.append(char)
468 elif char in '([{':
469 brace_depth += 1
470 current_token.append(char)
471 elif char in ')]}':
472 brace_depth -= 1
473 current_token.append(char)
474 elif char == ',' and brace_depth == 0:
475 if current_token:
476 tokens.append(''.join(current_token).strip())
477 current_token = []
478 else:
479 current_token.append(char)
480 i += 1
482 # Add last token
483 if current_token:
484 tokens.append(''.join(current_token).strip())
486 return tokens
489def parse_function_string(call_string: str) -> tuple[str, list, dict]:
490 """Convert a function signature like `func(a, b, key=value)` to name, args, kwargs.
492 :param call_string: a function-like string to parse
493 :returns: the function name, positional arguments, and keyword arguments
494 """
495 # Regex pattern to match function name and arguments
496 pattern = r"(\w+)(?:\((.*)\))?"
497 match = re.match(pattern, call_string.strip())
499 if not match:
500 raise ValueError(f"Function string '{call_string}' is not valid.")
502 # Extracting name and arguments section
503 name = match.group(1)
504 args_str = match.group(2)
506 # Regex to split arguments respecting parentheses and quotes
507 # arg_pattern = re.compile(r'''((?:[^,'"()\[\]{}*]+|'[^']*'|"(?:\\.|[^"\\])*"|\([^)]*\)|\[[^\]]*\]|\{[^{}]*\}|\*)+|,)''') # noqa: E501
508 # pieces = [piece.strip() for piece in arg_pattern.findall(args_str) if piece.strip() != ',']
509 pieces = _tokenize(args_str)
511 args = []
512 kwargs = {}
513 keyword_only = False
515 for piece in pieces:
516 if piece == '/':
517 continue
518 elif piece == '*':
519 keyword_only = True
520 elif '=' in piece and (piece.index('=') < piece.find('{') or piece.find('{') == -1):
521 key, val = piece.split('=', 1)
522 kwargs[key.strip()] = ast.literal_eval(val.strip())
523 keyword_only = True
524 else:
525 if keyword_only:
526 raise ValueError("Positional arguments cannot follow keyword arguments.")
527 args.append(ast.literal_eval(piece))
529 return name, args, kwargs
532def relative_error(pred, targ, axis=None, skip_nan=False):
533 """Compute the relative L2 error between two vectors along the given axis.
535 :param pred: the predicted values
536 :param targ: the target values
537 :param axis: the axis along which to compute the error
538 :param skip_nan: whether to skip NaN values in the error calculation
539 :returns: the relative L2 error
540 """
541 with np.errstate(divide='ignore', invalid='ignore'):
542 sum_func = np.nansum if skip_nan else np.sum
543 err = np.sqrt(sum_func((pred - targ)**2, axis=axis) / sum_func(targ**2, axis=axis))
544 return np.nan_to_num(err, nan=np.nan, posinf=np.nan, neginf=np.nan)
547def get_logger(name: str, stdout: bool = True, log_file: str | Path = None,
548 level: int = logging.INFO) -> logging.Logger:
549 """Return a file/stdout logger with the given name.
551 :param name: the name of the logger to return
552 :param stdout: whether to add a stdout stream handler to the logger
553 :param log_file: add file logging to this file (optional)
554 :param level: the logging level to set
555 :returns: the logger
556 """
557 logger = logging.getLogger(name)
558 logger.setLevel(level)
559 logger.handlers.clear()
560 if stdout:
561 std_handler = logging.StreamHandler(sys.stdout)
562 std_handler.setFormatter(LOG_FORMATTER)
563 logger.addHandler(std_handler)
564 if log_file is not None:
565 f_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
566 f_handler.setLevel(level)
567 f_handler.setFormatter(LOG_FORMATTER)
568 logger.addHandler(f_handler)
570 return logger