Coverage for src/amisc/compression.py: 93%

196 statements  


"""Module for compression methods.

Especially useful for field quantities with high dimensions.

Includes:

- `Compression` — an interface for specifying a compression method for field quantities.
- `SVD` — a Singular Value Decomposition (SVD) compression method.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np
from scipy.interpolate import RBFInterpolator

from amisc.serialize import PickleSerializable
from amisc.utils import relative_error

__all__ = ["Compression", "SVD"]


@dataclass
class Compression(PickleSerializable, ABC):

    """Base class for compression methods. Compression methods should implement:

    - `compute_map` - compute the compression map from provided data
    - `compress` - compress data into a latent space
    - `reconstruct` - reconstruct the compressed data back into the full space
    - `latent_size` - return the size of the latent space
    - `estimate_latent_ranges` - estimate the range of the latent space coefficients

    !!! Note "Specifying fields"
        The `fields` attribute is a list of strings that specify the field quantities to compress. For example, for
        3D velocity data, the fields might be `['ux', 'uy', 'uz']`. The length of the `fields` attribute determines
        the number of quantities of interest at each grid point in `coords`. Note that interpolation to/from the
        compression grid assumes a shape of `(num_pts, num_qoi)` for the states on the grid, where `num_qoi` is the
        length of `fields` and `num_pts` is the length of `coords`. Keep this shape in mind when passing data to
        `compute_map` to construct the compression map.

    In order to use a `Compression` object, you must first call `compute_map` to compute the compression map, which
    should set the private value `self._map_computed=True`. The `coords` of the compression grid must also be
    specified. The `coords` should have the shape `(num_pts, dim)`, where `num_pts` is the number of points in the
    compression grid and `dim` is the number of spatial dimensions. If `coords` is a 1d array, then `dim` is assumed
    to be 1. A minimal usage example is sketched below.

    :ivar fields: list of field quantities to compress
    :ivar method: the compression method to use (only `svd` is supported for now)
    :ivar coords: the coordinates of the compression grid
    :ivar interpolate_method: the interpolation method to use to interpolate to/from the compression grid
                              (only `rbf` (i.e. radial basis function) is supported for now)
    :ivar interpolate_opts: additional options to pass to the interpolation method
    :ivar _map_computed: whether the compression map has been computed
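
    !!! Example
        A minimal usage sketch (the field names and data are hypothetical, and the `SVD` subclass is used since
        `Compression` itself is abstract):

        ```python
        import numpy as np

        comp = SVD(fields=['ux', 'uy'], coords=np.linspace(0, 1, 100))         # 100 grid points, 2 QoIs -> dof=200
        data = {'ux': np.random.rand(50, 100), 'uy': np.random.rand(50, 100)}  # 50 samples of each field
        comp.compute_map(data)                                 # build the compression map from the samples

        states = comp.interpolate_to_grid(comp.coords, data)  # (50, 200) states on the compression grid
        latent = comp.compress(states)                         # (50, rank) latent coefficients
        recon = comp.reconstruct(latent)                       # (50, 200) reconstruction on the grid
        ```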

    """
    fields: list[str] = field(default_factory=list)
    method: str = 'svd'
    coords: np.ndarray = None  # (num_pts, dim)
    interpolate_method: str = 'rbf'
    interpolate_opts: dict = field(default_factory=dict)
    _map_computed: bool = False

    @property
    def map_exists(self):
        """All compression methods should have `coords` when their map has been constructed."""
        return self.coords is not None and self._map_computed

    @property
    def dim(self):
        """Number of physical grid coordinates for the field quantity, (i.e. x,y,z spatial dims)"""
        return self.coords.shape[1] if (self.coords is not None and len(self.coords.shape) > 1) else 1

    @property
    def num_pts(self):
        """Number of physical points in the compression grid."""
        return self.coords.shape[0] if self.coords is not None else None

    @property
    def num_qoi(self):
        """Number of quantities of interest at each grid point, (i.e. `ux, uy, uz` for 3d velocity data)."""
        return len(self.fields) if self.fields is not None else 1

    @property
    def dof(self):
        """Total degrees of freedom in the compression grid (i.e. `num_pts * num_qoi`)."""
        return self.num_pts * self.num_qoi if self.num_pts is not None else None

    def _correct_coords(self, coords):
        """Correct the coordinates to be in the correct shape for compression."""
        coords = np.atleast_1d(coords)
        if np.issubdtype(coords.dtype, np.object_):  # must be object array of np.arrays (for unique coords)
            for i, arr in np.ndenumerate(coords):
                if len(arr.shape) == 1:
                    coords[i] = arr[..., np.newaxis] if self.dim == 1 else arr[np.newaxis, ...]
        else:
            if len(coords.shape) == 1:
                coords = coords[..., np.newaxis] if self.dim == 1 else coords[np.newaxis, ...]
        return coords

    def interpolator(self):
        """The interpolator to use during compression and reconstruction. Interpolator expects to be used as:

        ```python
        xg = np.ndarray  # (num_pts, dim) grid coordinates
        yg = np.ndarray  # (num_pts, ...) scalar values on grid
        xp = np.ndarray  # (Q, dim) evaluation points

        interp = interpolate_method(xg, yg, **interpolate_opts)

        yp = interp(xp)  # (Q, ...) interpolated values
        ```
        """
        method = self.interpolate_method or 'rbf'
        match method.lower():
            case 'rbf':
                return RBFInterpolator
            case other:
                raise NotImplementedError(f"Interpolation method '{other}' is not implemented.")

    def interpolate_from_grid(self, states: np.ndarray, new_coords: np.ndarray):
        """Interpolate the states on the compression grid to new coordinates.

        :param states: `(*loop_shape, dof)` - the states on the compression grid
        :param new_coords: `(*coord_shape, dim)` - the new coordinates to interpolate to; if a 1d object array, then
                           each element is assumed to be a unique `(*coord_shape, dim)` array with the same length as
                           the `loop_shape` of the states -- each state will be interpolated to its corresponding
                           new coordinates
        :return: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the interpolated states; a single 1d object
                 array is returned for each qoi if `new_coords` is a 1d object array
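
        !!! Example
            A minimal sketch with hypothetical 1d field data, assuming the compression map has already been computed
            and that `'ux'` is one of the `fields`:

            ```python
            states = np.random.rand(10, comp.dof)  # 10 states on the compression grid
            new_coords = np.linspace(0, 1, 25)     # 25 new points to interpolate to
            interp = comp.interpolate_from_grid(states, new_coords)
            interp['ux'].shape                     # (10, 25)
            ```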

        """
        new_coords = self._correct_coords(new_coords)
        grid_coords = self._correct_coords(self.coords)

        coords_obj_array = np.issubdtype(new_coords.dtype, np.object_)

        # Iterate over one set of coords and states at a time
        def _iterate_coords_and_states():
            if coords_obj_array:
                for index, c in np.ndenumerate(new_coords):  # assumes same number of coords and states
                    yield index, c, states[index]
            else:
                yield (0,), new_coords, states  # assumes same coords for all states

        all_qois = np.empty(new_coords.shape if coords_obj_array else (1,), dtype=object)

        # Do interpolation for each set of unique coordinates (if multiple)
        for j, n_coords, state in _iterate_coords_and_states():
            skip_interp = (n_coords.shape == grid_coords.shape and np.allclose(n_coords, grid_coords))

            ret_dict = {}
            loop_shape = state.shape[:-1]
            coords_shape = n_coords.shape[:-1]
            state = state.reshape((*loop_shape, self.num_pts, self.num_qoi))
            n_coords = n_coords.reshape((-1, self.dim))
            for i, qoi in enumerate(self.fields):
                if skip_interp:
                    ret_dict[qoi] = state[..., i]
                else:
                    reshaped_states = state[..., i].reshape(-1, self.num_pts).T  # (num_pts, ...)
                    interp = self.interpolator()(grid_coords, reshaped_states, **self.interpolate_opts)
                    yp = interp(n_coords)
                    ret_dict[qoi] = yp.T.reshape(*loop_shape, *coords_shape)

            all_qois[j] = ret_dict

        if coords_obj_array:
            # Make an object array for each qoi, where each element is a unique `(*loop_shape, *coord_shape)` array
            _, _first_dict = next(np.ndenumerate(all_qois))
            ret = {qoi: np.empty(all_qois.shape, dtype=object) for qoi in _first_dict}
            for qoi in ret:
                for index, qoi_dict in np.ndenumerate(all_qois):
                    ret[qoi][index] = qoi_dict[qoi]
        else:
            # Otherwise, all loop dims used the same coords, so just return the single array for each qoi
            ret = all_qois[0]

        return ret

    def interpolate_to_grid(self, field_coords: np.ndarray, field_values):
        """Interpolate the field values at given coordinates to the compression grid.

        :param field_coords: `(*coord_shape, dim)` - the coordinates of the field values; if a 1d object array, then
                             each element is assumed to be a unique `(*coord_shape, dim)` array
        :param field_values: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the field values at the
                             coordinates; if each array is a 1d object array, then each element is assumed to be a
                             unique `(*loop_shape, *coord_shape)` array corresponding to the `field_coords`
        :return: `(*loop_shape, dof)` - the interpolated values on the compression grid
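
        !!! Example
            A minimal sketch with hypothetical 1d field data given at 25 off-grid points (a rough inverse of
            `interpolate_from_grid`):

            ```python
            field_coords = np.linspace(0, 1, 25)
            field_values = {'ux': np.random.rand(10, 25), 'uy': np.random.rand(10, 25)}
            states = comp.interpolate_to_grid(field_coords, field_values)  # (10, dof)
            ```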

        """
        field_coords = self._correct_coords(field_coords)
        grid_coords = self._correct_coords(self.coords)

        # Loop over each set of coordinates and field values (multiple if they are object arrays)
        # If only one set of coords, then they are assumed the same for each set of field values
        coords_obj_array = np.issubdtype(field_coords.dtype, np.object_)
        fields_obj_array = np.issubdtype(next(iter(field_values.values())).dtype, np.object_)

        def _iterate_coords_and_fields():
            if coords_obj_array:
                for index, c in np.ndenumerate(field_coords):  # assumes same number of coords and field values
                    yield index, c, {qoi: field_values[qoi][index] for qoi in field_values}
            elif fields_obj_array:
                for index in np.ndindex(next(iter(field_values.values())).shape):
                    yield index, field_coords, {qoi: field_values[qoi][index] for qoi in field_values}
            else:
                yield (0,), field_coords, field_values  # assumes same coords for all field values

        if coords_obj_array:
            shape = field_coords.shape
        elif fields_obj_array:
            shape = next(iter(field_values.values())).shape
        else:
            shape = (1,)

        always_skip_interp = not coords_obj_array and np.array_equal(field_coords, grid_coords)  # must be exact match

        all_states = np.empty(shape, dtype=object)  # one entry per unique set of coordinates

        for j, f_coords, f_values in _iterate_coords_and_fields():
            skip_interp = always_skip_interp or np.array_equal(f_coords, grid_coords)  # exact even for floats

            coords_shape = f_coords.shape[:-1]
            loop_shape = next(iter(f_values.values())).shape[:-len(coords_shape)]
            states = np.empty((*loop_shape, self.num_pts, self.num_qoi))
            f_coords = f_coords.reshape(-1, self.dim)
            for i, qoi in enumerate(self.fields):
                field_vals = f_values[qoi].reshape((*loop_shape, -1))  # (..., Q)
                if skip_interp:
                    states[..., i] = field_vals
                else:
                    field_vals = field_vals.reshape((-1, field_vals.shape[-1])).T  # (Q, ...)
                    interp = self.interpolator()(f_coords, field_vals, **self.interpolate_opts)
                    yg = interp(grid_coords)
                    states[..., i] = yg.T.reshape(*loop_shape, self.num_pts)
            all_states[j] = states.reshape((*loop_shape, self.dof))

        # All fields are now on the same dof grid, so stack them in the same array
        index = next(np.ndindex(all_states.shape))
        ret_states = np.empty(shape + all_states[index].shape)

        for index, arr in np.ndenumerate(all_states):
            ret_states[index] = arr

        if not (coords_obj_array or fields_obj_array):
            ret_states = np.squeeze(ret_states, axis=0)  # remove artificial leading dim for non-object arrays

        return ret_states

    @abstractmethod
    def compute_map(self, **kwargs):
        """Compute and store the compression map. Must set the value of `coords` and `_map_computed`. Should
        use the same normalization as the parent `Variable` object.

        !!! Note
            You should pass any required data to `compute_map` with the assumption that the data will be used in the
            shape `(num_pts, num_qoi)`, where `num_qoi` is the length of `fields` and `num_pts` is the length of
            `coords`. This is the shape that the compression map should be constructed in.
        """
        raise NotImplementedError


    @abstractmethod
    def compress(self, data: np.ndarray) -> np.ndarray:
        """Compress the data into a latent space.

        :param data: `(..., dof)` - the data to compress from full size of `dof`
        :return: `(..., rank)` - the compressed latent space data with size `rank`
        """
        raise NotImplementedError

    @abstractmethod
    def reconstruct(self, compressed: np.ndarray) -> np.ndarray:
        """Reconstruct the compressed data back into the full `dof` space.

        :param compressed: `(..., rank)` - the compressed data to reconstruct
        :return: `(..., dof)` - the reconstructed data with full `dof`
        """
        raise NotImplementedError

    @abstractmethod
    def latent_size(self) -> int:
        """Return the size of the latent space."""
        raise NotImplementedError

    @abstractmethod
    def estimate_latent_ranges(self) -> list[tuple[float, float]]:
        """Estimate the range of the latent space coefficients."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, spec: dict) -> Compression:
        """Construct a `Compression` object from a spec dictionary."""
        method = spec.pop('method', 'svd').lower()
        match method:
            case 'svd':
                return SVD(**spec)
            case other:
                raise NotImplementedError(f"Compression method '{other}' is not implemented.")


@dataclass
class SVD(Compression):
    """A Singular Value Decomposition (SVD) compression method. The SVD will be computed on initialization if the
    `data_matrix` is provided.

    :ivar data_matrix: `(dof, num_samples)` - the data matrix
    :ivar projection_matrix: `(dof, rank)` - the projection matrix
    :ivar rank: the rank of the SVD decomposition
    :ivar energy_tol: the energy tolerance of the SVD decomposition
    :ivar reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
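
    !!! Example
        A minimal sketch, assuming `A` is a hypothetical `(dof, num_samples)` snapshot matrix consistent with
        `fields` and `coords`:

        ```python
        svd = SVD(fields=['ux', 'uy'], coords=np.linspace(0, 1, 100), data_matrix=A, energy_tol=0.99)
        svd.rank                         # rank chosen to capture 99% of the squared-singular-value energy
        latent = svd.compress(A.T)       # (num_samples, rank)
        recon = svd.reconstruct(latent)  # (num_samples, dof)
        ```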

    """
    data_matrix: np.ndarray = None        # (dof, num_samples)
    projection_matrix: np.ndarray = None  # (dof, rank)
    rank: int = None
    energy_tol: float = None
    reconstruction_tol: float = None

    def __post_init__(self):
        """Compute the SVD if the data matrix is provided."""
        if (data_matrix := self.data_matrix) is not None:
            self.compute_map(data_matrix, rank=self.rank, energy_tol=self.energy_tol,
                             reconstruction_tol=self.reconstruction_tol)

    def compute_map(self, data_matrix: np.ndarray | dict, rank: int = None, energy_tol: float = None,
                    reconstruction_tol: float = None):
        """Compute the SVD compression map from the data matrix. Recall that `dof` is the total number of degrees of
        freedom, equal to the number of grid points `num_pts` times the number of quantities of interest `num_qoi`
        at each grid point.

        **Rank priority:** if `rank` is provided, then it will be used. Otherwise, if `reconstruction_tol` is
        provided, then the rank will be chosen to meet this reconstruction error level. Finally, if `energy_tol` is
        provided, then the rank will be chosen to meet this energy fraction level (sum of squared singular values).

        :param data_matrix: `(dof, num_samples)` - the data matrix. If passed in as a `dict`, then the data matrix
                            will be formed by concatenating the values of the `dict` along the last axis in the order
                            of the `fields` attribute and flattening the last two axes. This is useful for passing
                            in a dictionary of field values like `{field1: (num_samples, num_pts), field2: ...}`,
                            which ensures consistency of shape with the compression `coords`.
        :param rank: the rank of the SVD decomposition
        :param energy_tol: the energy tolerance of the SVD decomposition (defaults to 0.95)
        :param reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
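
        !!! Example
            A sketch of the three ways the rank can be selected (values are illustrative), assuming `A` is a
            `(dof, num_samples)` data matrix:

            ```python
            svd.compute_map(A, rank=5)                   # use exactly 5 modes
            svd.compute_map(A, reconstruction_tol=1e-3)  # smallest rank with relative reconstruction error <= 1e-3
            svd.compute_map(A, energy_tol=0.99)          # smallest rank capturing 99% of the squared singular values
            ```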

        """
        if isinstance(data_matrix, dict):
            data_matrix = np.concatenate([data_matrix[field][..., np.newaxis] for field in self.fields], axis=-1)
            data_matrix = data_matrix.reshape(*data_matrix.shape[:-2], -1).T  # (dof, num_samples)

        nan_idx = np.any(np.isnan(data_matrix), axis=0)
        data_matrix = data_matrix[:, ~nan_idx]
        u, s, vt = np.linalg.svd(data_matrix)
        energy_frac = np.cumsum(s ** 2 / np.sum(s ** 2))
        if rank := (rank or self.rank):
            energy_tol = energy_frac[rank - 1]
            reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)
        elif reconstruction_tol := (reconstruction_tol or self.reconstruction_tol):
            rank = u.shape[1]
            for r in range(1, u.shape[1] + 1):
                if relative_error(u[:, :r] @ u[:, :r].T @ data_matrix, data_matrix) <= reconstruction_tol:
                    rank = r
                    break
            energy_tol = energy_frac[rank - 1]
        else:
            energy_tol = energy_tol or self.energy_tol or 0.95
            idx = int(np.where(energy_frac >= energy_tol)[0][0])
            rank = idx + 1
            reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)

        self.data_matrix = data_matrix
        self.rank = rank
        self.energy_tol = energy_tol
        self.reconstruction_tol = reconstruction_tol
        self.projection_matrix = u[:, :rank]  # (dof, rank)
        self._map_computed = True

    def compress(self, data):
        return np.squeeze(self.projection_matrix.T @ data[..., np.newaxis], axis=-1)

    def reconstruct(self, compressed):
        return np.squeeze(self.projection_matrix @ compressed[..., np.newaxis], axis=-1)

    def latent_size(self):
        return self.rank

    def estimate_latent_ranges(self):
        if self.map_exists:
            latent_data = self.compress(self.data_matrix.T)  # (num_samples, rank)
            latent_min = np.min(latent_data, axis=0)
            latent_max = np.max(latent_data, axis=0)
            return [(lmin, lmax) for lmin, lmax in zip(latent_min, latent_max)]
        else:
            return None