Coverage for src/amisc/compression.py: 93% (209 statements)
1"""Module for compression methods.
3Especially useful for field quantities with high dimensions.
5Includes:
7- `Compression` — an interface for specifying a compression method for field quantities.
8- `SVD` — a Singular Value Decomposition (SVD) compression method.
9"""
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np
from scipy.interpolate import RBFInterpolator

from amisc.serialize import PickleSerializable
from amisc.utils import relative_error

__all__ = ["Compression", "SVD"]


@dataclass
class Compression(PickleSerializable, ABC):
    """Base class for compression methods. Compression methods should implement:

    - `compute_map` - compute the compression map from provided data
    - `compress` - compress data into a latent space
    - `reconstruct` - reconstruct the compressed data back into the full space
    - `latent_size` - return the size of the latent space
    - `estimate_latent_ranges` - estimate the range of the latent space coefficients

    !!! Note "Specifying fields"
        The `fields` attribute is a list of strings that specify the field quantities to compress. For example, for
        3D velocity data, the fields might be `['ux', 'uy', 'uz']`. The length of the `fields` attribute determines
        the number of quantities of interest at each grid point in `coords`. Note that interpolation to/from the
        compression grid assumes a shape of `(num_pts, num_qoi)` for the states on the grid, where `num_qoi` is
        the length of `fields` and `num_pts` is the length of `coords`. Keep this in mind when passing data to
        `compute_map`.

    To use a `Compression` object, you must first call `compute_map` to compute the compression map, which should
    set the private value `self._map_computed=True`. The `coords` of the compression grid must also be specified.
    The `coords` should have the shape `(num_pts, dim)`, where `num_pts` is the number of points in the compression
    grid and `dim` is the number of spatial dimensions. If `coords` is a 1d array, then `dim` is assumed to be 1.
    See the example below.
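
    !!! Example "Typical usage"
        A minimal sketch using the `SVD` subclass with synthetic data; the array names and shapes here are
        illustrative, not part of the API.

        ```python
        import numpy as np
        from amisc.compression import SVD

        coords = np.linspace(0, 1, 100)[:, np.newaxis]  # (num_pts, dim) compression grid
        data = np.random.rand(100, 50)                  # (dof, num_samples), dof = num_pts * num_qoi

        comp = SVD(fields=['p'], coords=coords, rank=4)
        comp.compute_map(data)                          # build the projection map; sets map_exists

        latent = comp.compress(data.T)                  # (num_samples, rank)
        recon = comp.reconstruct(latent)                # (num_samples, dof)
        ```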

    :ivar fields: list of field quantities to compress
    :ivar method: the compression method to use (only `svd` is supported for now)
    :ivar coords: the coordinates of the compression grid
    :ivar interpolate_method: the interpolation method to use to interpolate to/from the compression grid
                              (only `rbf`, i.e. radial basis function, is supported for now)
    :ivar interpolate_opts: additional options to pass to the interpolation method
    :ivar _map_computed: whether the compression map has been computed
    """
    fields: list[str] = field(default_factory=list)
    method: str = 'svd'
    coords: np.ndarray = None  # (num_pts, dim)
    interpolate_method: str = 'rbf'
    interpolate_opts: dict = field(default_factory=dict)
    _map_computed: bool = False

    @property
    def map_exists(self):
        """All compression methods should have `coords` when their map has been constructed."""
        return self.coords is not None and self._map_computed

    @property
    def dim(self):
        """Number of physical grid coordinates for the field quantity (i.e. the x, y, z spatial dimensions)."""
        return self.coords.shape[1] if (self.coords is not None and len(self.coords.shape) > 1) else 1

    @property
    def num_pts(self):
        """Number of physical points in the compression grid."""
        return self.coords.shape[0] if self.coords is not None else None

    @property
    def num_qoi(self):
        """Number of quantities of interest at each grid point (i.e. `ux, uy, uz` for 3d velocity data)."""
        return len(self.fields) if self.fields is not None else 1

    @property
    def dof(self):
        """Total degrees of freedom in the compression grid (i.e. `num_pts * num_qoi`)."""
        return self.num_pts * self.num_qoi if self.num_pts is not None else None

    def _correct_coords(self, coords):
        """Correct the coordinates to be in the correct shape for compression."""
        coords = np.atleast_1d(coords)
        if np.issubdtype(coords.dtype, np.object_):  # must be an object array of np.arrays (for unique coords)
            for i, arr in np.ndenumerate(coords):
                if arr is None:
                    continue  # skip empty values
                if len(arr.shape) == 1:
                    coords[i] = arr[..., np.newaxis] if self.dim == 1 else arr[np.newaxis, ...]
        else:
            if len(coords.shape) == 1:
                coords = coords[..., np.newaxis] if self.dim == 1 else coords[np.newaxis, ...]
        return coords

    def interpolator(self):
        """The interpolator to use during compression and reconstruction. The interpolator is expected to be
        used as:

        ```python
        xg = np.ndarray  # (num_pts, dim) grid coordinates
        yg = np.ndarray  # (num_pts, ...) scalar values on the grid
        xp = np.ndarray  # (Q, dim) evaluation points

        interp = interpolate_method(xg, yg, **interpolate_opts)

        yp = interp(xp)  # (Q, ...) interpolated values
        ```
        """
        method = self.interpolate_method or 'rbf'
        match method.lower():
            case 'rbf':
                return RBFInterpolator
            case other:
                raise NotImplementedError(f"Interpolation method '{other}' is not implemented.")

    def interpolate_from_grid(self, states: np.ndarray, new_coords: np.ndarray):
        """Interpolate the states on the compression grid to new coordinates.

        :param states: `(*loop_shape, dof)` - the states on the compression grid
        :param new_coords: `(*coord_shape, dim)` - the new coordinates to interpolate to; if a 1d object array, then
                           each element is assumed to be a unique `(*coord_shape, dim)` array with the same length
                           as the `loop_shape` of the states -- each state will be interpolated to its
                           corresponding new coordinates
        :return: `dict` of `(*loop_shape, *coord_shape)` arrays for each qoi - the interpolated states; a single
                 1d object array is returned for each qoi if `new_coords` is a 1d object array
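
        !!! Example
            A minimal sketch, assuming `comp` was constructed on a 1d grid of 100 points with `fields=['p']`;
            the names and shapes are illustrative only.

            ```python
            states = np.random.rand(10, comp.dof)           # 10 states on the compression grid
            new_pts = np.linspace(0, 1, 25)[:, np.newaxis]  # (25, dim) new coordinates
            out = comp.interpolate_from_grid(states, new_pts)
            out['p'].shape                                  # -> (10, 25)
            ```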
133 """
        new_coords = self._correct_coords(new_coords)
        grid_coords = self._correct_coords(self.coords)

        coords_obj_array = np.issubdtype(new_coords.dtype, np.object_)

        # Iterate over one set of coords and states at a time
        def _iterate_coords_and_states():
            if coords_obj_array:
                for index, c in np.ndenumerate(new_coords):  # assumes same number of coords and states
                    yield index, c, states[index]
            else:
                yield (0,), new_coords, states  # assumes same coords for all states

        all_qois = np.empty(new_coords.shape if coords_obj_array else (1,), dtype=object)

        # Do interpolation for each set of unique coordinates (if multiple)
        for j, n_coords, state in _iterate_coords_and_states():
            if n_coords is None:  # skip empty coords
                continue

            skip_interp = (n_coords.shape == grid_coords.shape and np.allclose(n_coords, grid_coords))

            ret_dict = {}
            loop_shape = state.shape[:-1]
            coords_shape = n_coords.shape[:-1]
            state = state.reshape((*loop_shape, self.num_pts, self.num_qoi))
            n_coords = n_coords.reshape((-1, self.dim))
            for i, qoi in enumerate(self.fields):
                if skip_interp:
                    ret_dict[qoi] = state[..., i]
                else:
                    reshaped_states = state[..., i].reshape(-1, self.num_pts).T  # (num_pts, ...)
                    interp = self.interpolator()(grid_coords, reshaped_states, **self.interpolate_opts)
                    yp = interp(n_coords)
                    ret_dict[qoi] = yp.T.reshape(*loop_shape, *coords_shape)

            all_qois[j] = ret_dict

        if coords_obj_array:
            # Make an object array for each qoi, where each element is a unique `(*loop_shape, *coord_shape)` array
            for _, _first_dict in np.ndenumerate(all_qois):
                if _first_dict is not None:
                    break
            ret = {qoi: np.empty(all_qois.shape, dtype=object) for qoi in _first_dict}
            for qoi in ret:
                for index, qoi_dict in np.ndenumerate(all_qois):
                    if qoi_dict is not None:
                        ret[qoi][index] = qoi_dict[qoi]
        else:
            # Otherwise, all loop dims used the same coords, so just return the single array for each qoi
            ret = all_qois[0]

        return ret

    def interpolate_to_grid(self, field_coords: np.ndarray, field_values):
        """Interpolate the field values at the given coordinates to the compression grid. An array of `nan` is
        returned for any coordinates or field values that are empty or `None`.

        :param field_coords: `(*coord_shape, dim)` - the coordinates of the field values; if an object array, then
                             each element is assumed to be a unique `(*coord_shape, dim)` array
        :param field_values: `dict` of `(*loop_shape, *coord_shape)` arrays for each qoi - the field values at the
                             coordinates; if the arrays are object arrays, then each element is assumed to be a
                             unique `(*loop_shape, *coord_shape)` array corresponding to the `field_coords`
        :return: `(*loop_shape, dof)` - the interpolated values on the compression grid
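
        !!! Example
            A minimal sketch (same illustrative `comp` as above): values sampled at 25 points are interpolated
            back onto the 100-point compression grid.

            ```python
            pts = np.linspace(0, 1, 25)[:, np.newaxis]    # (25, dim) field coordinates
            vals = {'p': np.random.rand(10, 25)}          # (*loop_shape, *coord_shape) per qoi
            states = comp.interpolate_to_grid(pts, vals)  # -> (10, dof) states on the grid
            ```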
198 """
        field_coords = self._correct_coords(field_coords)
        grid_coords = self._correct_coords(self.coords)

        # Loop over each set of coordinates and field values (multiple if they are object arrays)
        # If only one set of coords, then they are assumed the same for each set of field values
        coords_obj_array = np.issubdtype(field_coords.dtype, np.object_)
        fields_obj_array = np.issubdtype(next(iter(field_values.values())).dtype, np.object_)

        def _iterate_coords_and_fields():
            if coords_obj_array:
                for index, c in np.ndenumerate(field_coords):  # assumes same number of coords and field values
                    yield index, c, {qoi: field_values[qoi][index] for qoi in field_values}
            elif fields_obj_array:
                for index in np.ndindex(next(iter(field_values.values())).shape):
                    yield index, field_coords, {qoi: field_values[qoi][index] for qoi in field_values}
            else:
                yield (0,), field_coords, field_values  # assumes same coords for all field values

        if coords_obj_array:
            shape = field_coords.shape
        elif fields_obj_array:
            shape = next(iter(field_values.values())).shape
        else:
            shape = (1,)

        always_skip_interp = not coords_obj_array and np.array_equal(field_coords, grid_coords)  # must be exact match

        all_states = np.empty(shape, dtype=object)  # one state array per set of coords/field values

        for j, f_coords, f_values in _iterate_coords_and_fields():
            if f_coords is None or any(val is None for val in f_values.values()):  # skip empty samples
                continue

            skip_interp = always_skip_interp or np.array_equal(f_coords, grid_coords)  # exact even for floats

            coords_shape = f_coords.shape[:-1]
            loop_shape = next(iter(f_values.values())).shape[:-len(coords_shape)]
            states = np.empty((*loop_shape, self.num_pts, self.num_qoi))
            f_coords = f_coords.reshape(-1, self.dim)
            for i, qoi in enumerate(self.fields):
                field_vals = f_values[qoi].reshape((*loop_shape, -1))  # (..., Q)
                if skip_interp:
                    states[..., i] = field_vals
                else:
                    field_vals = field_vals.reshape((-1, field_vals.shape[-1])).T  # (Q, ...)
                    interp = self.interpolator()(f_coords, field_vals, **self.interpolate_opts)
                    yg = interp(grid_coords)
                    states[..., i] = yg.T.reshape(*loop_shape, self.num_pts)
            all_states[j] = states.reshape((*loop_shape, self.dof))

        # All fields are now on the same dof grid, so stack them in the same array
        state_shape = ()
        for index in np.ndindex(all_states.shape):
            if all_states[index] is not None:
                state_shape = all_states[index].shape
                break
        ret_states = np.empty(shape + state_shape)

        for index, arr in np.ndenumerate(all_states):
            ret_states[index] = arr if arr is not None else np.nan

        if not (coords_obj_array or fields_obj_array):
            ret_states = np.squeeze(ret_states, axis=0)  # remove artificial leading dim for non-object arrays

        return ret_states

    @abstractmethod
    def compute_map(self, **kwargs):
        """Compute and store the compression map. Must set the values of `coords` and `_map_computed`. Should
        use the same normalization as the parent `Variable` object.

        !!! Note
            You should pass any required data to `compute_map` with the assumption that the data will be used in
            the shape `(num_pts, num_qoi)`, where `num_qoi` is the length of `fields` and `num_pts` is the length
            of `coords`. This is the shape in which the compression map should be constructed.
        """
        raise NotImplementedError

    @abstractmethod
    def compress(self, data: np.ndarray) -> np.ndarray:
        """Compress the data into a latent space.

        :param data: `(..., dof)` - the data to compress from the full size of `dof`
        :return: `(..., rank)` - the compressed latent space data with size `rank`
        """
        raise NotImplementedError

    @abstractmethod
    def reconstruct(self, compressed: np.ndarray) -> np.ndarray:
        """Reconstruct the compressed data back into the full `dof` space.

        :param compressed: `(..., rank)` - the compressed data to reconstruct
        :return: `(..., dof)` - the reconstructed data with full `dof`
        """
        raise NotImplementedError

    @abstractmethod
    def latent_size(self) -> int:
        """Return the size of the latent space."""
        raise NotImplementedError

    @abstractmethod
    def estimate_latent_ranges(self) -> list[tuple[float, float]]:
        """Estimate the range of the latent space coefficients."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, spec: dict) -> Compression:
        """Construct a `Compression` object from a spec dictionary."""
        method = spec.pop('method', 'svd').lower()
        match method:
            case 'svd':
                return SVD(**spec)
            case other:
                raise NotImplementedError(f"Compression method '{other}' is not implemented.")


@dataclass
class SVD(Compression):
    """A Singular Value Decomposition (SVD) compression method. The SVD will be computed on initialization if the
    `data_matrix` is provided.

    :ivar data_matrix: `(dof, num_samples)` - the data matrix
    :ivar projection_matrix: `(dof, rank)` - the projection matrix
    :ivar rank: the rank of the SVD decomposition
    :ivar energy_tol: the energy tolerance of the SVD decomposition
    :ivar reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
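
    !!! Example
        A minimal sketch with synthetic data (names and shapes are illustrative only); here the rank is chosen
        automatically to capture 99% of the energy:

        ```python
        data = np.random.rand(100, 50)  # (dof, num_samples)
        svd = SVD(fields=['p'], coords=np.linspace(0, 1, 100), data_matrix=data, energy_tol=0.99)
        svd.rank, svd.reconstruction_tol  # chosen rank and the achieved relative error
        ```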
325 """
    data_matrix: np.ndarray = None  # (dof, num_samples)
    projection_matrix: np.ndarray = None  # (dof, rank)
    rank: int = None
    energy_tol: float = None
    reconstruction_tol: float = None

    def __post_init__(self):
        """Compute the SVD if the data matrix is provided."""
        if (data_matrix := self.data_matrix) is not None:
            self.compute_map(data_matrix, rank=self.rank, energy_tol=self.energy_tol,
                             reconstruction_tol=self.reconstruction_tol)

    def compute_map(self, data_matrix: np.ndarray | dict, rank: int = None, energy_tol: float = None,
                    reconstruction_tol: float = None):
        """Compute the SVD compression map from the data matrix. Recall that `dof` is the total number of degrees
        of freedom, equal to the number of grid points `num_pts` times the number of quantities of interest
        `num_qoi` at each grid point.

        **Rank priority:** if `rank` is provided, then it will be used. Otherwise, if `reconstruction_tol` is
        provided, then the rank will be chosen to meet this reconstruction error level. Finally, if `energy_tol`
        is provided, then the rank will be chosen to meet this energy fraction level (sum of squared singular
        values), as illustrated below.
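
        !!! Example
            A sketch of the three options (assuming an `svd` instance and arrays `data`, `ux`, `uy` of compatible
            shapes; these names are illustrative only):

            ```python
            svd.compute_map(data, rank=8)                   # use rank 8 directly
            svd.compute_map(data, reconstruction_tol=1e-3)  # smallest rank with relative error <= 1e-3
            svd.compute_map(data, energy_tol=0.99)          # smallest rank capturing 99% of the energy
            svd.compute_map({'ux': ux, 'uy': uy}, rank=8)   # dict input is stacked into (dof, num_samples)
            ```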

        :param data_matrix: `(dof, num_samples)` - the data matrix. If passed in as a `dict`, then the data matrix
                            will be formed by concatenating the values of the `dict` along the last axis in the
                            order of the `fields` attribute and flattening the last two axes. This is useful for
                            passing in a dictionary of field values like `{field1: (num_samples, num_pts),
                            field2: ...}`, which ensures consistency of shape with the compression `coords`.
        :param rank: the rank of the SVD decomposition
        :param energy_tol: the energy tolerance of the SVD decomposition (defaults to 0.95)
        :param reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
356 """
        if isinstance(data_matrix, dict):
            data_matrix = np.concatenate([data_matrix[f][..., np.newaxis] for f in self.fields], axis=-1)
            data_matrix = data_matrix.reshape(*data_matrix.shape[:-2], -1).T  # (dof, num_samples)

        nan_idx = np.any(np.isnan(data_matrix), axis=0)
        data_matrix = data_matrix[:, ~nan_idx]  # drop samples containing nan values
        u, s, vt = np.linalg.svd(data_matrix)
        energy_frac = np.cumsum(s ** 2 / np.sum(s ** 2))

        if rank := (rank or self.rank):
            energy_tol = energy_frac[rank - 1]
            reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)
        elif reconstruction_tol := (reconstruction_tol or self.reconstruction_tol):
            rank = u.shape[1]
            for r in range(1, u.shape[1] + 1):
                if relative_error(u[:, :r] @ u[:, :r].T @ data_matrix, data_matrix) <= reconstruction_tol:
                    rank = r
                    break
            energy_tol = energy_frac[rank - 1]
        else:
            energy_tol = energy_tol or self.energy_tol or 0.95
            idx = int(np.where(energy_frac >= energy_tol)[0][0])
            rank = idx + 1
            reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)

        self.data_matrix = data_matrix
        self.rank = rank
        self.energy_tol = energy_tol
        self.reconstruction_tol = reconstruction_tol
        self.projection_matrix = u[:, :rank]  # (dof, rank)
        self._map_computed = True

    def compress(self, data):
        # Project onto the rank-r basis: (..., dof) -> (..., rank)
        return np.squeeze(self.projection_matrix.T @ data[..., np.newaxis], axis=-1)

    def reconstruct(self, compressed):
        # Lift back to the full space: (..., rank) -> (..., dof)
        return np.squeeze(self.projection_matrix @ compressed[..., np.newaxis], axis=-1)

    def latent_size(self):
        return self.rank

    def estimate_latent_ranges(self):
        if self.map_exists:
            latent_data = self.compress(self.data_matrix.T)  # (num_samples, rank)
            latent_min = np.min(latent_data, axis=0)  # min/max over samples for each latent coefficient
            latent_max = np.max(latent_data, axis=0)
            return [(lmin, lmax) for lmin, lmax in zip(latent_min, latent_max)]
        else:
            return None