Coverage for src/amisc/utils.py: 94% (284 statements)


1"""Provides some basic utilities for the package. 

2 

3Includes: 

4 

5- `to_model_dataset` — convert surrogate input/output dataset to a form usable by the true model 

6- `to_surrogate_dataset` — convert true model input/output dataset to a form usable by the surrogate 

7- `constrained_lls` — solve a constrained linear least squares problem 

8- `search_for_file` — search for a file in the current working directory and additional search paths 

9- `format_inputs` — broadcast and reshape all inputs to the same shape 

10- `format_outputs` — reshape all outputs to a common loop shape 

11- `parse_function_string` — convert function-like strings to arguments and keyword-arguments 

12- `relative_error` — compute the relative L2 error between two vectors 

13- `get_logger` — logging utility with nice formatting 

14""" 

15from __future__ import annotations 

16 

17import ast 

18import copy 

19import inspect 

20import logging 

21import re 

22import sys 

23from pathlib import Path 

24from typing import TYPE_CHECKING 

25 

26import numpy as np 

27import yaml 

28 

29__all__ = ['parse_function_string', 'relative_error', 'get_logger', 'format_inputs', 'format_outputs', 

30 'search_for_file', 'constrained_lls', 'to_surrogate_dataset', 'to_model_dataset'] 

31 

32from amisc.typing import COORDS_STR_ID, LATENT_STR_ID, Dataset 

33 

34if TYPE_CHECKING: 

35 import amisc.variable 

36 

37LOG_FORMATTER = logging.Formatter(u"%(asctime)s — [%(levelname)s] — %(name)-15s — %(message)s") 

38 

39 

def _combine_latent_arrays(arr):
    """Helper function to concatenate latent arrays into a single variable in the `arr` Dataset."""
    for var in list(arr.keys()):
        if LATENT_STR_ID in var:  # extract latent variables from surrogate data
            base_id = str(var).split(LATENT_STR_ID)[0]
            arr[base_id] = arr[var][..., np.newaxis] if arr.get(base_id) is None else (
                np.concatenate((arr[base_id], arr[var][..., np.newaxis]), axis=-1))
            del arr[var]

def to_surrogate_dataset(dataset: Dataset, variables: 'amisc.variable.VariableList', del_fields: bool = True,
                         **field_coords) -> tuple[Dataset, list[str]]:
    """Convert a true model input/output dataset to a form usable by the surrogate. Primarily, compress field
    quantities and normalize.

    :param dataset: the dataset to convert
    :param variables: the `VariableList` containing the variable objects used in `dataset` -- these objects define
                      the normalization and compression methods to use for each variable
    :param del_fields: whether to delete the original field quantities from the dataset after compression
    :param field_coords: pass in extra field quantity coordinates as f'{var}_coords' for compression (optional)
    :returns: the compressed/normalized dataset and a list of variable names to pass to the surrogate
    """
    surr_vars = []
    dataset = copy.deepcopy(dataset)
    for var in variables:
        # Only grab scalars in the dataset, or field quantities if all of their fields are present
        if var in dataset or (var.compression is not None and all([f in dataset for f in var.compression.fields])):
            if var.compression is not None:
                coords = dataset.get(f'{var}{COORDS_STR_ID}', field_coords.get(f'{var}{COORDS_STR_ID}', None))
                latent = var.compress({field: dataset[field] for field in
                                       var.compression.fields}, coords=coords)['latent']  # all fields must be present
                for i in range(latent.shape[-1]):
                    dataset[f'{var.name}{LATENT_STR_ID}{i}'] = latent[..., i]
                    surr_vars.append(f'{var.name}{LATENT_STR_ID}{i}')
                if del_fields:
                    for field in var.compression.fields:
                        del dataset[field]
                    if dataset.get(f'{var}{COORDS_STR_ID}', None) is not None:
                        del dataset[f'{var}{COORDS_STR_ID}']
            else:
                dataset[var.name] = var.normalize(dataset[var.name])
                surr_vars.append(f'{var.name}')

    return dataset, surr_vars

def to_model_dataset(dataset: Dataset, variables: 'amisc.variable.VariableList', del_latent: bool = True,
                     **field_coords) -> tuple[Dataset, Dataset]:
    """Convert a surrogate input/output dataset to a form usable by the true model. Primarily, reconstruct
    field quantities and denormalize.

    :param dataset: the dataset to convert
    :param variables: the `VariableList` containing the variable objects used in `dataset` -- these objects define
                      the normalization and compression methods to use for each variable
    :param del_latent: whether to delete the latent variables from the dataset after reconstruction
    :param field_coords: pass in extra field quantity coordinates as f'{var}_coords' for reconstruction (optional)
    :returns: the reconstructed/denormalized dataset and any field coordinates used during reconstruction
    """
    dataset = copy.deepcopy(dataset)
    _combine_latent_arrays(dataset)

    ret_coords = {}
    for var in variables:
        if var in dataset:
            if var.compression is not None:
                coords = field_coords.get(f'{var}{COORDS_STR_ID}', None)
                field = var.reconstruct({'latent': dataset[var]}, coords=coords)
                if del_latent:
                    del dataset[var]
                coords = field.pop('coords')
                ret_coords[f'{var.name}{COORDS_STR_ID}'] = copy.deepcopy(coords)
                dataset.update(field)
            else:
                dataset[var] = var.denormalize(dataset[var])

    return dataset, ret_coords
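To illustrate the round trip between the two dataset forms for the scalar (no compression) path, here is a minimal sketch. `ScalarVar` is a hypothetical stand-in that exposes only the attributes these helpers touch; amisc's real `Variable` objects carry the actual normalization and compression configuration:

```python
import numpy as np
from amisc.utils import to_model_dataset, to_surrogate_dataset

class ScalarVar:
    """Hypothetical stand-in: hashes/compares like its name string, normalizes by a factor of 100."""
    compression = None  # scalar variable -> no field compression

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __eq__(self, other):  # compare like the name string so `var in dataset` works
        return str(other) == self.name

    def __hash__(self):
        return hash(self.name)

    def normalize(self, x):
        return np.asarray(x) / 100.0

    def denormalize(self, x):
        return np.asarray(x) * 100.0

x = ScalarVar('x')
surr_ds, surr_vars = to_surrogate_dataset({'x': np.array([10.0, 50.0])}, [x])
print(surr_ds['x'], surr_vars)   # [0.1 0.5] ['x']
model_ds, coords = to_model_dataset(surr_ds, [x])
print(model_ds['x'], coords)     # [10. 50.] {}
```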

def constrained_lls(A: np.ndarray, b: np.ndarray, C: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Minimize $||Ax-b||_2$, subject to $Cx=d$, i.e. constrained linear least squares.

    !!! Note
        See [these lecture notes](http://www.seas.ucla.edu/~vandenbe/133A/lectures/cls.pdf) for more detail.

    :param A: `(..., M, N)`, vandermonde matrix
    :param b: `(..., M, 1)`, data
    :param C: `(..., P, N)`, constraint operator
    :param d: `(..., P, 1)`, constraint condition
    :returns: `(..., N, 1)`, the solution parameter vector `x`
    """
    M = A.shape[-2]
    dims = len(A.shape[:-2])
    T_axes = tuple(np.arange(0, dims)) + (-1, -2)
    Q, R = np.linalg.qr(np.concatenate((A, C), axis=-2))
    Q1 = Q[..., :M, :]
    Q2 = Q[..., M:, :]
    Q1_T = np.transpose(Q1, axes=T_axes)
    Q2_T = np.transpose(Q2, axes=T_axes)
    Qtilde, Rtilde = np.linalg.qr(Q2_T)
    Qtilde_T = np.transpose(Qtilde, axes=T_axes)
    Rtilde_T_inv = np.linalg.pinv(np.transpose(Rtilde, axes=T_axes))
    w = np.linalg.pinv(Rtilde) @ (Qtilde_T @ Q1_T @ b - Rtilde_T_inv @ d)

    return np.linalg.pinv(R) @ (Q1_T @ b - Q2_T @ w)
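As a quick sanity check, here is a small runnable example (illustrative values) that fits three parameters subject to the constraint that they sum to one; the constraint should hold to machine precision:

```python
import numpy as np
from amisc.utils import constrained_lls

rng = np.random.default_rng(42)
A = rng.random((10, 3))          # (M, N) vandermonde/design matrix
b = rng.random((10, 1))          # (M, 1) data
C = np.ones((1, 3))              # (P, N) constraint operator: sum of coefficients
d = np.array([[1.0]])            # (P, 1) constraint condition

x = constrained_lls(A, b, C, d)  # (N, 1) solution
print(np.allclose(C @ x, d))     # expected: True (constraint satisfied)
```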

class _RidgeRegression:
    """A simple class for ridge regression with a closed-form solution."""

    def __init__(self, alpha=1.0):
        r"""Initialize the ridge regression model with the given regularization strength $\alpha$."""
        self.alpha = alpha
        self.weights = None

    def fit(self, X: np.ndarray, y: np.ndarray):
        r"""Fit the ridge regression model to the given data. Compute linear weights (with intercept)
        of shape `(n_features + 1, n_targets)`.

        $w = (X^T X + \alpha I)^{-1} X^T y$

        :param X: the design matrix of shape `(n_samples, n_features)`
        :param y: the target values of shape `(n_samples, n_targets)`
        """
        n_samples, n_features = X.shape

        # Add bias term (column of ones) to the design matrix for the intercept
        X_bias = np.hstack([np.ones((n_samples, 1)), X])

        # Regularization matrix (identity with a zero in the top-left so the intercept is not penalized)
        identity = np.eye(n_features + 1)
        identity[0, 0] = 0

        # Closed-form solution (normal equations) for ridge regression
        A = X_bias.T @ X_bias + self.alpha * identity
        B = X_bias.T @ y
        self.weights = np.linalg.solve(A, B)

    def predict(self, X: np.ndarray):
        """Compute the predicted target values for the given input data.

        :param X: the input data of shape `(n_samples, n_features)`
        :returns: the predicted target values of shape `(n_samples, n_targets)`
        """
        if self.weights is None:
            raise ValueError("Model is not fitted yet. Call 'fit' with appropriate arguments before using this method.")

        n_samples, n_features = X.shape

        # Add bias term (column of ones) to match the fitted weights
        X_bias = np.hstack([np.ones((n_samples, 1)), X])

        return X_bias @ self.weights
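A short usage sketch (illustrative data): with a near-zero `alpha`, the closed form recovers an exact linear relationship, intercept first:

```python
import numpy as np
from amisc.utils import _RidgeRegression  # private helper, imported here for demonstration

rng = np.random.default_rng(0)
X = rng.random((50, 2))
y = X @ np.array([[3.0], [-1.0]]) + 2.0              # exact linear targets, shape (50, 1)

model = _RidgeRegression(alpha=1e-8)
model.fit(X, y)
print(model.weights.ravel())                         # approximately [ 2.  3. -1.] (intercept, then feature weights)
print(np.allclose(model.predict(X), y, atol=1e-6))   # True
```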

def _inspect_function(func):
    """Try to inspect the inputs and outputs of a callable function.

    !!! Example
        ```python
        def my_func(a, b, c, **kwargs):
            # Do something
            return y1, y2

        _inspect_function(my_func)
        # Returns (['a', 'b', 'c'], ['y1', 'y2'])
        ```

    :param func: The callable function to inspect.
    :returns: A tuple of the positional arguments and return values of the function.
    """
    try:
        sig = inspect.signature(func)
        pos_args = [param.name for param in sig.parameters.values() if param.default == param.empty
                    and param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY)]

        try:
            source = inspect.getsource(func).strip()
        except OSError:
            from dill.source import getsource  # inspect in IDLE using dill
            source = getsource(func).strip()

        tree = ast.parse(source)

        # Find the return values
        class ReturnVisitor(ast.NodeVisitor):
            def __init__(self):
                self.return_values = []

            def visit_Return(self, node):
                if isinstance(node.value, ast.Tuple):
                    self.return_values = [elt.id for elt in node.value.elts]
                elif isinstance(node.value, ast.Name):
                    self.return_values = [node.value.id]
                else:
                    self.return_values = []

        return_visitor = ReturnVisitor()
        return_visitor.visit(tree)

        return pos_args, return_visitor.return_values
    except Exception:
        return [], []
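Note that the return-value names are recovered by parsing the function's source, so this only works where `inspect.getsource` (or `dill`) can find it, e.g. for functions defined in a file. A quick demonstration with a hypothetical function:

```python
from amisc.utils import _inspect_function

def my_model(x, alpha, config=None):
    y = x * alpha
    return y

print(_inspect_function(my_model))   # (['x', 'alpha'], ['y']) -- 'config' is skipped since it has a default
```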

def _inspect_assignment(class_name: str, stack_idx: int = 2) -> str | None:
    """Return the left-hand side of an assignment like "x = class_name(...)".

    !!! Example
        ```python
        class MyClass:
            def __init__(self):
                self.name = _inspect_assignment('MyClass')

        obj = MyClass()
        print(obj.name)
        # Output: 'obj'
        ```

    This function will do its best to only return single assignments (i.e. `x = MyClass()`) and not more
    complex expressions like list comprehensions or tuple unpacking. If the assignment is not found or an error
    occurs during inspection, it will return `None`.

    :param class_name: the name of the class that is being constructed
    :param stack_idx: the index of the stack frame to inspect (default is 2, since you likely call this from
                      inside the class constructor, so you need to go back one more frame from that to find
                      the original assignment caller)
    :returns: the variable name assigned to the class constructor (or `None`)
    """
    variable_name = None
    try:
        stack = inspect.stack()
        frame_info = stack[stack_idx]
        code_line = frame_info.code_context[frame_info.index].strip()
        parsed_code = ast.parse(code_line)
        if isinstance(parsed_code.body[0], ast.Assign):
            assignment = parsed_code.body[0]
            if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name):
                target_name = assignment.targets[0].id
                if isinstance(assignment.value, ast.Call) and isinstance(assignment.value.func, ast.Name):
                    if assignment.value.func.id == class_name:
                        variable_name = target_name
    except Exception:
        variable_name = None
    finally:
        return variable_name

def _get_yaml_path(yaml_obj: yaml.Loader | yaml.Dumper):
    """Get the path to the YAML file being loaded or dumped."""
    try:
        save_path = Path(yaml_obj.stream.name).parent
        save_file = Path(yaml_obj.stream.name).with_suffix('')
    except Exception:
        save_path = Path('.')
        save_file = 'yaml'
    return save_path, save_file

def search_for_file(filename: str | Path, search_paths=None):
    """Search for the given filename in the current working directory and any additional search paths provided.

    :param filename: the filename to search for
    :param search_paths: paths to try and find the file in
    :returns: the full path to the file if found, otherwise the original `filename`
    """
    if not isinstance(filename, str | Path):
        return filename

    search_paths = list(search_paths) if search_paths is not None else []  # copy to avoid mutating the caller's list
    search_paths.append('.')

    save_file = Path(filename)
    need_to_search = True
    try:
        need_to_search = ((len(save_file.parts) == 1 and len(save_file.suffix) > 0) or
                          (len(save_file.parts) > 1 and not save_file.exists()))
    except Exception:
        need_to_search = False

    # Search for the file if it looks like a valid path but does not exist
    if need_to_search:
        name = save_file.name
        for path in search_paths:
            if (pth := Path(path) / name).exists():
                filename = pth.resolve().as_posix()
                break
        # If not found, just return the original filename and let the caller handle any error

    return filename
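A short runnable demonstration using a temporary directory (illustrative file names):

```python
import tempfile
from pathlib import Path
from amisc.utils import search_for_file

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / 'config.yml').write_text('key: value')
    print(search_for_file('config.yml', search_paths=[tmp]))   # absolute posix path to <tmp>/config.yml
    print(search_for_file('missing.yml', search_paths=[tmp]))  # 'missing.yml' (returned unchanged)
```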

def format_inputs(inputs: Dataset, var_shape: dict = None) -> tuple[Dataset, tuple[int, ...]]:
    """Broadcast and reshape all inputs to the same shape. The loop shape is inferred by broadcasting the leading
    dims of all input arrays. Input arrays are broadcast to this shape and then flattened.

    !!! Example
        ```python
        inputs = {'x': np.random.rand(10, 1, 5), 'y': np.random.rand(1, 1), 'z': np.random.rand(1, 20, 3)}
        fmt_inputs, loop_shape = format_inputs(inputs)
        # Output: {'x': np.ndarray(200, 5), 'y': np.ndarray(200,), 'z': np.ndarray(200, 3)}, (10, 20)
        ```

    :param inputs: `dict` of input arrays
    :param var_shape: `dict` of expected input variable shapes (i.e. for field quantities); assumes all inputs
                      are 1d (i.e. scalar) if None or not specified
    :returns: the reshaped inputs and the common loop shape
    """
    var_shape = var_shape or {}

    def _common_shape(shape1, shape2):
        """Find the common leading dimensions between two shapes (with np broadcasting rules)."""
        min_len = min(len(shape1), len(shape2))
        common_shape = []
        for i in range(min_len):
            if shape1[i] == shape2[i]:
                common_shape.append(shape1[i])
            elif shape1[i] == 1:
                common_shape.append(shape2[i])
            elif shape2[i] == 1:
                common_shape.append(shape1[i])
            else:
                break
        return tuple(common_shape)

    def _shorten_shape(name, array):
        """Remove extra variable dimensions from the end of the array shape (i.e. field quantity dimensions)."""
        shape = var_shape.get(name, None)
        if shape is not None and len(shape) > 0:
            if len(shape) > len(array.shape):
                raise ValueError(f"Variable '{name}' shape {shape} is longer than input array shape {array.shape}. "
                                 f"The input array for '{name}' should have at least {len(shape)} dimensions.")
            return array.shape[:-len(shape)]
        else:
            return array.shape

    # Get the common "loop" dimensions from all input arrays
    inputs = {name: np.atleast_1d(value) for name, value in inputs.items()}
    name, array = next(iter(inputs.items()))
    loop_shape = _shorten_shape(name, array)
    for name, array in inputs.items():
        array_shape = _shorten_shape(name, array)
        loop_shape = _common_shape(loop_shape, array_shape)
        if not loop_shape:
            break
    N = np.prod(loop_shape)
    common_dim_cnt = len(loop_shape)

    # Flatten and broadcast all inputs to the common shape
    ret_inputs = {}
    for var_id, array in inputs.items():
        if common_dim_cnt > 0:
            broadcast_shape = np.broadcast_shapes(loop_shape, array.shape[:common_dim_cnt])
            broadcast_shape += array.shape[common_dim_cnt:]
            ret_inputs[var_id] = np.broadcast_to(array, broadcast_shape).reshape((N, *array.shape[common_dim_cnt:]))
        else:
            ret_inputs[var_id] = array

    return ret_inputs, loop_shape
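The docstring example can be checked directly; `y`'s leading `(1, 1)` broadcasts against `x`'s `(10, 1)` and `z`'s `(1, 20)` to give the `(10, 20)` loop shape, which is then flattened to length 200:

```python
import numpy as np
from amisc.utils import format_inputs

inputs = {'x': np.random.rand(10, 1, 5), 'y': np.random.rand(1, 1), 'z': np.random.rand(1, 20, 3)}
fmt, loop_shape = format_inputs(inputs)
assert loop_shape == (10, 20)
assert fmt['x'].shape == (200, 5) and fmt['y'].shape == (200,) and fmt['z'].shape == (200, 3)
```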

def format_outputs(outputs: Dataset, loop_shape: tuple[int, ...]) -> Dataset:
    """Reshape all outputs to the common loop shape. The loop shape is as obtained from a call to `format_inputs`.
    Assumes that all outputs are the same along the first dimension. This first dimension gets reshaped back into
    the `loop_shape`. Singleton outputs are squeezed along the last dimension. A singleton loop shape is squeezed
    along the first dimension.

    !!! Example
        ```python
        outputs = {'x': np.random.rand(10, 1, 5), 'y': np.random.rand(10, 1), 'z': np.random.rand(10, 20, 3)}
        loop_shape = (2, 5)
        fmt_outputs = format_outputs(outputs, loop_shape)
        # Output: {'x': np.ndarray(2, 5, 1, 5), 'y': np.ndarray(2, 5), 'z': np.ndarray(2, 5, 20, 3)}
        ```

    :param outputs: `dict` of output arrays
    :param loop_shape: the common leading dimensions to reshape the output arrays to
    :returns: the reshaped outputs
    """
    output_dict = {}
    for key, val in outputs.items():
        val = np.atleast_1d(val)
        output_shape = val.shape[1:]  # Assumes (N, ...) output shape to start with
        val = val.reshape(loop_shape + output_shape)
        if output_shape == (1,):
            val = np.atleast_1d(np.squeeze(val, axis=-1))  # Squeeze singleton outputs
        if loop_shape == (1,):
            val = np.atleast_1d(np.squeeze(val, axis=0))  # Squeeze singleton loop dimensions
        output_dict[key] = val
    return output_dict
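These two helpers are designed to be used as a pair around a model evaluation: flatten the inputs, evaluate the model on the flattened grid, then restore the loop shape. A minimal sketch with a toy model:

```python
import numpy as np
from amisc.utils import format_inputs, format_outputs

inputs = {'x': np.random.rand(4, 1), 'y': np.random.rand(1, 6)}
fmt, loop_shape = format_inputs(inputs)   # loop_shape == (4, 6); fmt arrays have shape (24,)
raw = {'f': fmt['x'] * fmt['y']}          # toy model evaluated on the flattened grid
out = format_outputs(raw, loop_shape)
assert out['f'].shape == (4, 6)
```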

def _tokenize(args_str: str) -> list[str]:
    """Helper function to extract tokens from a string of arguments while respecting nested structures.

    This function processes a string of arguments and splits it into individual tokens, ensuring that nested
    structures such as parentheses, brackets, and quotes are correctly handled.

    :param args_str: The string of arguments to tokenize
    :return: A list of tokens extracted from the input string

    !!! Example
        ```python
        args_str = "func(1, 2), {'key': 'value'}, [1, 2, 3]"
        _tokenize(args_str)
        # Output: ['func(1, 2)', "{'key': 'value'}", '[1, 2, 3]']
        ```
    """
    if args_str is None or len(args_str) == 0:
        return []
    tokens = []
    current_token = []
    brace_depth = 0
    in_string = False

    i = 0
    while i < len(args_str):
        char = args_str[i]
        if char in ('"', "'") and (i == 0 or args_str[i - 1] != '\\'):  # Toggle string state
            in_string = not in_string
            current_token.append(char)
        elif in_string:
            current_token.append(char)
        elif char in '([{':
            brace_depth += 1
            current_token.append(char)
        elif char in ')]}':
            brace_depth -= 1
            current_token.append(char)
        elif char == ',' and brace_depth == 0:
            if current_token:
                tokens.append(''.join(current_token).strip())
                current_token = []
        else:
            current_token.append(char)
        i += 1

    # Add the last token
    if current_token:
        tokens.append(''.join(current_token).strip())

    return tokens

def parse_function_string(call_string: str) -> tuple[str, list, dict]:
    """Convert a function signature like `func(a, b, key=value)` to name, args, kwargs.

    :param call_string: a function-like string to parse
    :returns: the function name, positional arguments, and keyword arguments
    """
    # Regex pattern to match the function name and arguments section
    pattern = r"(\w+)(?:\((.*)\))?"
    match = re.match(pattern, call_string.strip())

    if not match:
        raise ValueError(f"Function string '{call_string}' is not valid.")

    # Extract the name and arguments section
    name = match.group(1)
    args_str = match.group(2)

    # Split the arguments while respecting parentheses, brackets, and quotes
    pieces = _tokenize(args_str)

    args = []
    kwargs = {}
    keyword_only = False

    for piece in pieces:
        if piece == '/':
            continue
        elif piece == '*':
            keyword_only = True
        elif '=' in piece and (piece.index('=') < piece.find('{') or piece.find('{') == -1):
            key, val = piece.split('=', 1)
            kwargs[key.strip()] = ast.literal_eval(val.strip())
            keyword_only = True
        else:
            if keyword_only:
                raise ValueError("Positional arguments cannot follow keyword arguments.")
            args.append(ast.literal_eval(piece))

    return name, args, kwargs
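A quick runnable example (illustrative input string); argument values are parsed with `ast.literal_eval`, so they must be Python literals:

```python
from amisc.utils import parse_function_string

name, args, kwargs = parse_function_string("Uniform(0, 1, scale='log')")
print(name, args, kwargs)   # Uniform [0, 1] {'scale': 'log'}
```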

def relative_error(pred, targ, axis=None, skip_nan=False):
    """Compute the relative L2 error between two vectors along the given axis.

    :param pred: the predicted values
    :param targ: the target values
    :param axis: the axis along which to compute the error
    :param skip_nan: whether to skip NaN values in the error calculation
    :returns: the relative L2 error
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        sum_func = np.nansum if skip_nan else np.sum
        err = np.sqrt(sum_func((pred - targ)**2, axis=axis) / sum_func(targ**2, axis=axis))
    return np.nan_to_num(err, nan=np.nan, posinf=np.nan, neginf=np.nan)
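For instance, with `pred = [1.1, 1.9, 3.2]` and `targ = [1, 2, 3]`, the error is $\sqrt{(0.01 + 0.01 + 0.04) / 14} \approx 0.065$:

```python
import numpy as np
from amisc.utils import relative_error

pred = np.array([1.1, 1.9, 3.2])
targ = np.array([1.0, 2.0, 3.0])
print(relative_error(pred, targ))   # ~0.0655
```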

def get_logger(name: str, stdout: bool = True, log_file: str | Path = None,
               level: int = logging.INFO) -> logging.Logger:
    """Return a file/stdout logger with the given name.

    :param name: the name of the logger to return
    :param stdout: whether to add a stdout stream handler to the logger
    :param log_file: add file logging to this file (optional)
    :param level: the logging level to set
    :returns: the logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.handlers.clear()
    if stdout:
        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(LOG_FORMATTER)
        logger.addHandler(std_handler)
    if log_file is not None:
        f_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        f_handler.setLevel(level)
        f_handler.setFormatter(LOG_FORMATTER)
        logger.addHandler(f_handler)

    return logger
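A usage sketch (hypothetical logger name and file); records go to stdout and are appended to the file, formatted by the `LOG_FORMATTER` defined at the top of this module:

```python
from amisc.utils import get_logger

logger = get_logger('amisc.demo', log_file='demo.log')
logger.info('surrogate training started')
# stdout and demo.log, e.g.: 2025-01-24 04:51:00,123 — [INFO] — amisc.demo      — surrogate training started
```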