Coverage for src/amisc/compression.py: 93%

209 statements

1"""Module for compression methods. 

2 

3Especially useful for field quantities with high dimensions. 

4 

5Includes: 

6 

7- `Compression` — an interface for specifying a compression method for field quantities. 

8- `SVD` — a Singular Value Decomposition (SVD) compression method. 

9""" 

10from __future__ import annotations 

11 

12from abc import ABC, abstractmethod 

13from dataclasses import dataclass, field 

14 

15import numpy as np 

16from scipy.interpolate import RBFInterpolator 

17 

18from amisc.serialize import PickleSerializable 

19from amisc.utils import relative_error 

20 

21__all__ = ["Compression", "SVD"] 


@dataclass
class Compression(PickleSerializable, ABC):
    """Base class for compression methods. Compression methods should implement:

    - `compute_map` - compute the compression map from provided data
    - `compress` - compress data into a latent space
    - `reconstruct` - reconstruct the compressed data back into the full space
    - `latent_size` - return the size of the latent space
    - `estimate_latent_ranges` - estimate the range of the latent space coefficients

    !!! Note "Specifying fields"
        The `fields` attribute is a list of strings that specify the field quantities to compress. For example,
        for 3D velocity data, the fields might be `['ux', 'uy', 'uz']`. The length of the `fields` attribute
        determines the number of quantities of interest at each grid point in `coords`. Note that interpolation
        to/from the compression grid assumes a shape of `(num_pts, num_qoi)` for the states on the grid, where
        `num_qoi` is the length of `fields` and `num_pts` is the length of `coords`. Keep this in mind when
        passing data to `compute_map`.

    To use a `Compression` object, first call `compute_map` to compute the compression map, which should set the
    private value `self._map_computed=True`. The `coords` of the compression grid must also be specified. The
    `coords` should have shape `(num_pts, dim)`, where `num_pts` is the number of points in the compression grid
    and `dim` is the number of spatial dimensions. If `coords` is a 1d array, then `dim` is assumed to be 1.
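
    A minimal sketch of the intended workflow, using the `SVD` subclass below (the grid and data here are
    hypothetical):

    ```python
    import numpy as np

    comp = SVD(fields=['ux', 'uy'], coords=np.linspace(0, 1, 100)[:, np.newaxis])
    snapshots = np.random.rand(comp.dof, 50)  # (dof, num_samples) data matrix
    comp.compute_map(snapshots, rank=10)      # sets `self._map_computed=True`
    latent = comp.compress(snapshots.T)       # (num_samples, rank)
    states = comp.reconstruct(latent)         # (num_samples, dof)
    ```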

    :ivar fields: list of field quantities to compress
    :ivar method: the compression method to use (only `svd` is supported for now)
    :ivar coords: the coordinates of the compression grid
    :ivar interpolate_method: the interpolation method to use to interpolate to/from the compression grid
                              (only `rbf`, i.e. radial basis function, is supported for now)
    :ivar interpolate_opts: additional options to pass to the interpolation method
    :ivar _map_computed: whether the compression map has been computed
    """
    fields: list[str] = field(default_factory=list)
    method: str = 'svd'
    coords: np.ndarray = None  # (num_pts, dim)
    interpolate_method: str = 'rbf'
    interpolate_opts: dict = field(default_factory=dict)
    _map_computed: bool = False

    @property
    def map_exists(self):
        """All compression methods should have `coords` when their map has been constructed."""
        return self.coords is not None and self._map_computed

    @property
    def dim(self):
        """Number of physical grid coordinates for the field quantity (i.e. `x, y, z` spatial dims)."""
        return self.coords.shape[1] if (self.coords is not None and len(self.coords.shape) > 1) else 1

    @property
    def num_pts(self):
        """Number of physical points in the compression grid."""
        return self.coords.shape[0] if self.coords is not None else None

    @property
    def num_qoi(self):
        """Number of quantities of interest at each grid point (i.e. `ux, uy, uz` for 3d velocity data)."""
        return len(self.fields) if self.fields is not None else 1

    @property
    def dof(self):
        """Total degrees of freedom in the compression grid (i.e. `num_pts * num_qoi`)."""
        return self.num_pts * self.num_qoi if self.num_pts is not None else None

    def _correct_coords(self, coords):
        """Reshape the coordinates into the correct shape for compression."""
        coords = np.atleast_1d(coords)
        if np.issubdtype(coords.dtype, np.object_):  # must be an object array of np.arrays (for unique coords)
            for i, arr in np.ndenumerate(coords):
                if arr is None:
                    continue  # skip empty values
                if len(arr.shape) == 1:
                    coords[i] = arr[..., np.newaxis] if self.dim == 1 else arr[np.newaxis, ...]
        else:
            if len(coords.shape) == 1:
                coords = coords[..., np.newaxis] if self.dim == 1 else coords[np.newaxis, ...]
        return coords

    def interpolator(self):
        """The interpolator to use during compression and reconstruction. The interpolator is expected to be
        used as:

        ```python
        xg = np.ndarray  # (num_pts, dim) grid coordinates
        yg = np.ndarray  # (num_pts, ...) scalar values on grid
        xp = np.ndarray  # (Q, dim) evaluation points

        interp = interpolate_method(xg, yg, **interpolate_opts)

        yp = interp(xp)  # (Q, ...) interpolated values
        ```
        """
        method = self.interpolate_method or 'rbf'
        match method.lower():
            case 'rbf':
                return RBFInterpolator
            case other:
                raise NotImplementedError(f"Interpolation method '{other}' is not implemented.")

    def interpolate_from_grid(self, states: np.ndarray, new_coords: np.ndarray):
        """Interpolate the states on the compression grid to new coordinates.

        :param states: `(*loop_shape, dof)` - the states on the compression grid
        :param new_coords: `(*coord_shape, dim)` - the new coordinates to interpolate to; if a 1d object array,
                           then each element is assumed to be a unique `(*coord_shape, dim)` array with the same
                           length as the `loop_shape` of the states -- each state will be interpolated to its
                           corresponding new coordinates
        :return: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the interpolated states; a 1d object
                 array is returned for each qoi if `new_coords` is a 1d object array
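
        For example (a sketch assuming `comp` is the fitted compression object from the class docstring, with
        `dof=200` on a 100-point 1d grid; the numbers are hypothetical):

        ```python
        states = np.random.rand(2, comp.dof)            # two states on the compression grid
        new_pts = np.linspace(0, 1, 20)[:, np.newaxis]  # (20, dim=1) new coordinates
        out = comp.interpolate_from_grid(states, new_pts)
        out['ux'].shape  # (2, 20)
        ```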

133 """ 

134 new_coords = self._correct_coords(new_coords) 

135 grid_coords = self._correct_coords(self.coords) 

136 

137 coords_obj_array = np.issubdtype(new_coords.dtype, np.object_) 

138 

139 # Iterate over one set of coords and states at a time 

140 def _iterate_coords_and_states(): 

141 if coords_obj_array: 

142 for index, c in np.ndenumerate(new_coords): # assumes same number of coords and states 

143 yield index, c, states[index] 

144 else: 

145 yield (0,), new_coords, states # assumes same coords for all states 

146 

147 all_qois = np.empty(new_coords.shape if coords_obj_array else (1,), dtype=object) 

148 

149 # Do interpolation for each set of unique coordinates (if multiple) 

150 for j, n_coords, state in _iterate_coords_and_states(): 

151 if n_coords is None: # Skip empty coords 

152 continue 

153 

154 skip_interp = (n_coords.shape == grid_coords.shape and np.allclose(n_coords, grid_coords)) 

155 

156 ret_dict = {} 

157 loop_shape = state.shape[:-1] 

158 coords_shape = n_coords.shape[:-1] 

159 state = state.reshape((*loop_shape, self.num_pts, self.num_qoi)) 

160 n_coords = n_coords.reshape((-1, self.dim)) 

161 for i, qoi in enumerate(self.fields): 

162 if skip_interp: 

163 ret_dict[qoi] = state[..., i] 

164 else: 

165 reshaped_states = state[..., i].reshape(-1, self.num_pts).T # (num_pts, ...) 

166 interp = self.interpolator()(grid_coords, reshaped_states, **self.interpolate_opts) 

167 yp = interp(n_coords) 

168 ret_dict[qoi] = yp.T.reshape(*loop_shape, *coords_shape) 

169 

170 all_qois[j] = ret_dict 

171 

172 if coords_obj_array: 

173 # Make an object array for each qoi, where each element is a unique `(*loop_shape, *coord_shape)` array 

174 for _, _first_dict in np.ndenumerate(all_qois): 

175 if _first_dict is not None: 

176 break 

177 ret = {qoi: np.empty(all_qois.shape, dtype=object) for qoi in _first_dict} 

178 for qoi in ret: 

179 for index, qoi_dict in np.ndenumerate(all_qois): 

180 if qoi_dict is not None: 

181 ret[qoi][index] = qoi_dict[qoi] 

182 else: 

183 # Otherwise, all loop dims used the same coords, so just return the single array for each qoi 

184 ret = all_qois[0] 

185 

186 return ret 

    def interpolate_to_grid(self, field_coords: np.ndarray, field_values):
        """Interpolate the field values at the given coordinates to the compression grid. An array of `nan` is
        returned for any coordinates or field values that are empty or `None`.

        :param field_coords: `(*coord_shape, dim)` - the coordinates of the field values; if an object array,
                             then each element is assumed to be a unique `(*coord_shape, dim)` array
        :param field_values: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the field values at the
                             coordinates; if each array is an object array, then each element is assumed to be a
                             unique `(*loop_shape, *coord_shape)` array corresponding to the `field_coords`
        :return: `(*loop_shape, dof)` - the interpolated values on the compression grid
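
        For example (the inverse of the `interpolate_from_grid` sketch above; `comp` and the shapes are
        hypothetical):

        ```python
        pts = np.linspace(0, 1, 20)[:, np.newaxis]      # (20, dim=1) field coordinates
        values = {'ux': np.random.rand(2, 20), 'uy': np.random.rand(2, 20)}
        states = comp.interpolate_to_grid(pts, values)  # (2, dof) on the compression grid
        ```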

198 """ 

199 field_coords = self._correct_coords(field_coords) 

200 grid_coords = self._correct_coords(self.coords) 

201 

202 # Loop over each set of coordinates and field values (multiple if they are object arrays) 

203 # If only one set of coords, then they are assumed the same for each set of field values 

204 coords_obj_array = np.issubdtype(field_coords.dtype, np.object_) 

205 fields_obj_array = np.issubdtype(next(iter(field_values.values())).dtype, np.object_) 

206 def _iterate_coords_and_fields(): 

207 if coords_obj_array: 

208 for index, c in np.ndenumerate(field_coords): # assumes same number of coords and field values 

209 yield index, c, {qoi: field_values[qoi][index] for qoi in field_values} 

210 elif fields_obj_array: 

211 for index in np.ndindex(next(iter(field_values.values())).shape): 

212 yield index, field_coords, {qoi: field_values[qoi][index] for qoi in field_values} 

213 else: 

214 yield (0,), field_coords, field_values # assumes same coords for all field values 

215 

216 if coords_obj_array: 

217 shape = field_coords.shape 

218 elif fields_obj_array: 

219 shape = next(iter(field_values.values())).shape 

220 else: 

221 shape = (1,) 

222 

223 always_skip_interp = not coords_obj_array and np.array_equal(field_coords, grid_coords) # must be exact match 

224 

225 all_states = np.empty(shape, dtype=object) # are you in good hands? 

226 

227 for j, f_coords, f_values in _iterate_coords_and_fields(): 

228 if f_coords is None or any([val is None for val in f_values.values()]): # Skip empty samples 

229 continue 

230 

231 skip_interp = always_skip_interp or np.array_equal(f_coords, grid_coords) # exact even for floats 

232 

233 coords_shape = f_coords.shape[:-1] 

234 loop_shape = next(iter(f_values.values())).shape[:-len(coords_shape)] 

235 states = np.empty((*loop_shape, self.num_pts, self.num_qoi)) 

236 f_coords = f_coords.reshape(-1, self.dim) 

237 for i, qoi in enumerate(self.fields): 

238 field_vals = f_values[qoi].reshape((*loop_shape, -1)) # (..., Q) 

239 if skip_interp: 

240 states[..., i] = field_vals 

241 else: 

242 field_vals = field_vals.reshape((-1, field_vals.shape[-1])).T # (Q, ...) 

243 interp = self.interpolator()(f_coords, field_vals, **self.interpolate_opts) 

244 yg = interp(grid_coords) 

245 states[..., i] = yg.T.reshape(*loop_shape, self.num_pts) 

246 all_states[j] = states.reshape((*loop_shape, self.dof)) 

247 

248 # All fields now on the same dof grid, so stack them in same array 

249 state_shape = () 

250 for index in np.ndindex(all_states.shape): 

251 if all_states[index] is not None: 

252 state_shape = all_states[index].shape 

253 break 

254 ret_states = np.empty(shape + state_shape) 

255 

256 for index, arr in np.ndenumerate(all_states): 

257 ret_states[index] = arr if arr is not None else np.nan 

258 

259 if not (coords_obj_array or fields_obj_array): 

260 ret_states = np.squeeze(ret_states, axis=0) # artificial leading dim for non-object arrays 

261 

262 return ret_states 

    @abstractmethod
    def compute_map(self, **kwargs):
        """Compute and store the compression map. Must set the values of `coords` and `_map_computed`. Should
        use the same normalization as the parent `Variable` object.

        !!! Note
            You should pass any required data to `compute_map` with the assumption that the data will be used in
            the shape `(num_pts, num_qoi)`, where `num_qoi` is the length of `fields` and `num_pts` is the length
            of `coords`. This is the shape that the compression map should be constructed in.
        """
        raise NotImplementedError

    @abstractmethod
    def compress(self, data: np.ndarray) -> np.ndarray:
        """Compress the data into a latent space.

        :param data: `(..., dof)` - the data to compress from the full size of `dof`
        :return: `(..., rank)` - the compressed latent space data with size `rank`
        """
        raise NotImplementedError

    @abstractmethod
    def reconstruct(self, compressed: np.ndarray) -> np.ndarray:
        """Reconstruct the compressed data back into the full `dof` space.

        :param compressed: `(..., rank)` - the compressed data to reconstruct
        :return: `(..., dof)` - the reconstructed data with the full `dof`
        """
        raise NotImplementedError

    @abstractmethod
    def latent_size(self) -> int:
        """Return the size of the latent space."""
        raise NotImplementedError

    @abstractmethod
    def estimate_latent_ranges(self) -> list[tuple[float, float]]:
        """Estimate the range of the latent space coefficients."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, spec: dict) -> Compression:
        """Construct a `Compression` object from a spec dictionary.

        """
        method = spec.pop('method', 'svd').lower()
        match method:
            case 'svd':
                return SVD(**spec)
            case other:
                raise NotImplementedError(f"Compression method '{other}' is not implemented.")


@dataclass
class SVD(Compression):
    """A Singular Value Decomposition (SVD) compression method. The SVD will be computed on initialization if the
    `data_matrix` is provided.

    :ivar data_matrix: `(dof, num_samples)` - the data matrix
    :ivar projection_matrix: `(dof, rank)` - the projection matrix
    :ivar rank: the rank of the SVD decomposition
    :ivar energy_tol: the energy tolerance of the SVD decomposition
    :ivar reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
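
    A minimal sketch (a hypothetical 1d grid with random data, for illustration only):

    ```python
    coords = np.linspace(0, 1, 100)[:, np.newaxis]
    data = np.random.rand(100, 50)  # (dof, num_samples) for a single field
    svd = SVD(fields=['p'], coords=coords, data_matrix=data, rank=10)
    svd.map_exists  # True -- the map was computed on initialization
    ```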

325 """ 

326 data_matrix: np.ndarray = None # (dof, num_samples) 

327 projection_matrix: np.ndarray = None # (dof, rank) 

328 rank: int = None 

329 energy_tol: float = None 

330 reconstruction_tol: float = None 

331 

332 def __post_init__(self): 

333 """Compute the SVD if the data matrix is provided.""" 

334 if (data_matrix := self.data_matrix) is not None: 

335 self.compute_map(data_matrix, rank=self.rank, energy_tol=self.energy_tol, 

336 reconstruction_tol=self.reconstruction_tol) 

337 

338 def compute_map(self, data_matrix: np.ndarray | dict, rank: int = None, energy_tol: float = None, 

339 reconstruction_tol: float = None): 

340 """Compute the SVD compression map from the data matrix. Recall that `dof` is the total number of degrees of 

341 freedom, equal to the number of grid points `num_pts` times the number of quantities of interest `num_qoi` 

342 at each grid point. 

343 

344 **Rank priority:** if `rank` is provided, then it will be used. Otherwise, if `reconstruction_tol` is provided, 

345 then the rank will be chosen to meet this reconstruction error level. Finally, if `energy_tol` is provided, 

346 then the rank will be chosen to meet this energy fraction level (sum of squared singular values). 
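
        A sketch of the energy-fraction rule, mirroring the computation below (`s` are the singular values):

        ```python
        energy_frac = np.cumsum(s ** 2 / np.sum(s ** 2))
        rank = int(np.where(energy_frac >= energy_tol)[0][0]) + 1  # smallest rank meeting the tolerance
        ```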

        :param data_matrix: `(dof, num_samples)` - the data matrix. If passed in as a `dict`, then the data matrix
                            will be formed by concatenating the values of the `dict` along the last axis in the
                            order of the `fields` attribute and flattening the last two axes. This is useful for
                            passing in a dictionary of field values like `{field1: (num_samples, num_pts),
                            field2: ...}`, which ensures consistency of shape with the compression `coords`.
        :param rank: the rank of the SVD decomposition
        :param energy_tol: the energy tolerance of the SVD decomposition (defaults to 0.95)
        :param reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
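
        For example, passing a `dict` for a compression with `fields=['ux', 'uy']` (hypothetical shapes):

        ```python
        field_data = {'ux': np.random.rand(50, 100), 'uy': np.random.rand(50, 100)}  # (num_samples, num_pts)
        svd.compute_map(field_data, rank=10)  # internally stacked into a (200, 50) data matrix
        ```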

356 """ 

357 if isinstance(data_matrix, dict): 

358 data_matrix = np.concatenate([data_matrix[field][..., np.newaxis] for field in self.fields], axis=-1) 

359 data_matrix = data_matrix.reshape(*data_matrix.shape[:-2], -1).T # (dof, num_samples) 

360 

361 nan_idx = np.any(np.isnan(data_matrix), axis=0) 

362 data_matrix = data_matrix[:, ~nan_idx] 

363 u, s, vt = np.linalg.svd(data_matrix) 

364 energy_frac = np.cumsum(s ** 2 / np.sum(s ** 2)) 

365 if rank := (rank or self.rank): 

366 energy_tol = energy_frac[rank - 1] 

367 reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix) 

368 elif reconstruction_tol := (reconstruction_tol or self.reconstruction_tol): 

369 rank = u.shape[1] 

370 for r in range(1, u.shape[1] + 1): 

371 if relative_error(u[:, :r] @ u[:, :r].T @ data_matrix, data_matrix) <= reconstruction_tol: 

372 rank = r 

373 break 

374 energy_tol = energy_frac[rank - 1] 

375 else: 

376 energy_tol = energy_tol or self.energy_tol or 0.95 

377 idx = int(np.where(energy_frac >= energy_tol)[0][0]) 

378 rank = idx + 1 

379 reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix) 

380 

381 self.data_matrix = data_matrix 

382 self.rank = rank 

383 self.energy_tol = energy_tol 

384 self.reconstruction_tol = reconstruction_tol 

385 self.projection_matrix = u[:, :rank] # (dof, rank) 

386 self._map_computed = True 

387 

    def compress(self, data):
        return np.squeeze(self.projection_matrix.T @ data[..., np.newaxis], axis=-1)

    def reconstruct(self, compressed):
        return np.squeeze(self.projection_matrix @ compressed[..., np.newaxis], axis=-1)

    def latent_size(self):
        return self.rank

    def estimate_latent_ranges(self):
        if self.map_exists:
            latent_data = self.compress(self.data_matrix.T)  # (num_samples, rank)
            latent_min = np.min(latent_data, axis=0)
            latent_max = np.max(latent_data, axis=0)
            return [(lmin, lmax) for lmin, lmax in zip(latent_min, latent_max)]
        else:
            return None