Coverage for src/amisc/compression.py: 93%
196 statements
1"""Module for compression methods.
3Especially useful for field quantities with high dimensions.
5Includes:
7- `Compression` — an interface for specifying a compression method for field quantities.
8- `SVD` — a Singular Value Decomposition (SVD) compression method.
9"""
10from __future__ import annotations
12from abc import ABC, abstractmethod
13from dataclasses import dataclass, field
15import numpy as np
16from scipy.interpolate import RBFInterpolator
18from amisc.serialize import PickleSerializable
19from amisc.utils import relative_error
21__all__ = ["Compression", "SVD"]
@dataclass
class Compression(PickleSerializable, ABC):
    """Base class for compression methods. Compression methods should implement:

    - `compute_map` - compute the compression map from provided data
    - `compress` - compress data into a latent space
    - `reconstruct` - reconstruct the compressed data back into the full space
    - `latent_size` - return the size of the latent space
    - `estimate_latent_ranges` - estimate the range of the latent space coefficients

    !!! Note "Specifying fields"
        The `fields` attribute is a list of strings that specify the field quantities to compress. For example, for
        3D velocity data, the fields might be `['ux', 'uy', 'uz']`. The length of the `fields` attribute determines
        the number of quantities of interest at each grid point in `coords`. Note that interpolation to/from the
        compression grid assumes a shape of `(num_pts, num_qoi)` for the states on the grid, where `num_qoi` is the
        length of `fields` and `num_pts` is the length of `coords`. Keep this shape in mind when passing data to
        `compute_map`.

    To use a `Compression` object, you must first call `compute_map` to compute the compression map, which
    should set the private value `self._map_computed=True`. The `coords` of the compression grid must also be
    specified. The `coords` should have the shape `(num_pts, dim)`, where `num_pts` is the number of points in the
    compression grid and `dim` is the number of spatial dimensions. If `coords` is a 1d array, then `dim` is
    assumed to be 1. See the example below.

    :ivar fields: list of field quantities to compress
    :ivar method: the compression method to use (only `svd` is supported for now)
    :ivar coords: the coordinates of the compression grid
    :ivar interpolate_method: the interpolation method to use to interpolate to/from the compression grid
                              (only `rbf` (i.e. radial basis function) is supported for now)
    :ivar interpolate_opts: additional options to pass to the interpolation method
    :ivar _map_computed: whether the compression map has been computed
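
    !!! Example
        A minimal usage sketch (illustrative only, assuming a single field `p` on a 1d grid with random
        snapshot data; the variable names here are not part of the API):

        ```python
        import numpy as np

        grid = np.linspace(0, 1, 100)        # (num_pts,) 1d compression grid
        snapshots = np.random.rand(100, 50)  # (dof, num_samples) data matrix

        comp = SVD(fields=['p'], coords=grid, data_matrix=snapshots)  # map computed on init

        latent = comp.compress(snapshots.T)  # (num_samples, rank) latent coefficients
        recon = comp.reconstruct(latent)     # (num_samples, dof) back in the full space
        ```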
56 """
57 fields: list[str] = field(default_factory=list)
58 method: str = 'svd'
59 coords: np.ndarray = None # (num_pts, dim)
60 interpolate_method: str = 'rbf'
61 interpolate_opts: dict = field(default_factory=dict)
62 _map_computed: bool = False
    @property
    def map_exists(self):
        """All compression methods should have `coords` when their map has been constructed."""
        return self.coords is not None and self._map_computed

    @property
    def dim(self):
        """Number of physical grid coordinates for the field quantity (i.e. x, y, z spatial dims)."""
        return self.coords.shape[1] if (self.coords is not None and len(self.coords.shape) > 1) else 1

    @property
    def num_pts(self):
        """Number of physical points in the compression grid."""
        return self.coords.shape[0] if self.coords is not None else None

    @property
    def num_qoi(self):
        """Number of quantities of interest at each grid point (i.e. `ux, uy, uz` for 3d velocity data)."""
        return len(self.fields) if self.fields is not None else 1

    @property
    def dof(self):
        """Total degrees of freedom in the compression grid (i.e. `num_pts * num_qoi`)."""
        return self.num_pts * self.num_qoi if self.num_pts is not None else None

    def _correct_coords(self, coords):
        """Correct the coordinates to be in the correct shape for compression."""
        coords = np.atleast_1d(coords)
        if np.issubdtype(coords.dtype, np.object_):  # must be object array of np.arrays (for unique coords)
            for i, arr in np.ndenumerate(coords):
                if len(arr.shape) == 1:
                    coords[i] = arr[..., np.newaxis] if self.dim == 1 else arr[np.newaxis, ...]
        else:
            if len(coords.shape) == 1:
                coords = coords[..., np.newaxis] if self.dim == 1 else coords[np.newaxis, ...]
        return coords

    def interpolator(self):
        """The interpolator to use during compression and reconstruction. The interpolator is expected to be used as:

        ```python
        xg = np.ndarray  # (num_pts, dim) grid coordinates
        yg = np.ndarray  # (num_pts, ...) scalar values on grid
        xp = np.ndarray  # (Q, dim) evaluation points

        interp = interpolate_method(xg, yg, **interpolate_opts)

        yp = interp(xp)  # (Q, ...) interpolated values
        ```
        """
        method = self.interpolate_method or 'rbf'
        match method.lower():
            case 'rbf':
                return RBFInterpolator
            case other:
                raise NotImplementedError(f"Interpolation method '{other}' is not implemented.")

    def interpolate_from_grid(self, states: np.ndarray, new_coords: np.ndarray):
        """Interpolate the states on the compression grid to new coordinates.

        :param states: `(*loop_shape, dof)` - the states on the compression grid
        :param new_coords: `(*coord_shape, dim)` - the new coordinates to interpolate to; if a 1d object array, then
                           each element is assumed to be a unique `(*coord_shape, dim)` array with the same length
                           as the `loop_shape` of the states -- each state will be interpolated to the
                           corresponding new coordinates
        :return: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the interpolated states; a single 1d object
                 array is returned for each qoi if `new_coords` is a 1d object array
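
        !!! Example
            A minimal sketch (illustrative only; `comp`, `grid`, and `snapshots` are from the class-level
            example above, so `dof=100` with one qoi `p`):

            ```python
            new_pts = np.linspace(0, 1, 25)  # (25,) new 1d coordinates
            interp_vals = comp.interpolate_from_grid(snapshots.T, new_pts)
            interp_vals['p']                 # (50, 25) states interpolated to the new points
            ```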
131 """
132 new_coords = self._correct_coords(new_coords)
133 grid_coords = self._correct_coords(self.coords)
135 coords_obj_array = np.issubdtype(new_coords.dtype, np.object_)
137 # Iterate over one set of coords and states at a time
138 def _iterate_coords_and_states():
139 if coords_obj_array:
140 for index, c in np.ndenumerate(new_coords): # assumes same number of coords and states
141 yield index, c, states[index]
142 else:
143 yield (0,), new_coords, states # assumes same coords for all states
145 all_qois = np.empty(new_coords.shape if coords_obj_array else (1,), dtype=object)
147 # Do interpolation for each set of unique coordinates (if multiple)
148 for j, n_coords, state in _iterate_coords_and_states():
149 skip_interp = (n_coords.shape == grid_coords.shape and np.allclose(n_coords, grid_coords))
151 ret_dict = {}
152 loop_shape = state.shape[:-1]
153 coords_shape = n_coords.shape[:-1]
154 state = state.reshape((*loop_shape, self.num_pts, self.num_qoi))
155 n_coords = n_coords.reshape((-1, self.dim))
156 for i, qoi in enumerate(self.fields):
157 if skip_interp:
158 ret_dict[qoi] = state[..., i]
159 else:
160 reshaped_states = state[..., i].reshape(-1, self.num_pts).T # (num_pts, ...)
161 interp = self.interpolator()(grid_coords, reshaped_states, **self.interpolate_opts)
162 yp = interp(n_coords)
163 ret_dict[qoi] = yp.T.reshape(*loop_shape, *coords_shape)
165 all_qois[j] = ret_dict
167 if coords_obj_array:
168 # Make an object array for each qoi, where each element is a unique `(*loop_shape, *coord_shape)` array
169 _, _first_dict = next(np.ndenumerate(all_qois))
170 ret = {qoi: np.empty(all_qois.shape, dtype=object) for qoi in _first_dict}
171 for qoi in ret:
172 for index, qoi_dict in np.ndenumerate(all_qois):
173 ret[qoi][index] = qoi_dict[qoi]
174 else:
175 # Otherwise, all loop dims used the same coords, so just return the single array for each qoi
176 ret = all_qois[0]
178 return ret
    def interpolate_to_grid(self, field_coords: np.ndarray, field_values):
        """Interpolate the field values at given coordinates to the compression grid.

        :param field_coords: `(*coord_shape, dim)` - the coordinates of the field values; if a 1d object array, then
                             each element is assumed to be a unique `(*coord_shape, dim)` array
        :param field_values: `dict` of `(*loop_shape, *coord_shape)` for each qoi - the field values at the
                             coordinates; if each array is a 1d object array, then each element is assumed to be a
                             unique `(*loop_shape, *coord_shape)` array corresponding to the `field_coords`
        :return: `(*loop_shape, dof)` - the interpolated values on the compression grid
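
        !!! Example
            A minimal sketch of the reverse operation (illustrative only; `comp` is the object from the
            class-level example above):

            ```python
            new_pts = np.linspace(0, 1, 25)
            vals = {'p': np.random.rand(50, 25)}              # values of `p` at the 25 points, for 50 samples
            states = comp.interpolate_to_grid(new_pts, vals)  # (50, dof) states on the compression grid
            ```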
189 """
190 field_coords = self._correct_coords(field_coords)
191 grid_coords = self._correct_coords(self.coords)
193 # Loop over each set of coordinates and field values (multiple if they are object arrays)
194 # If only one set of coords, then they are assumed the same for each set of field values
195 coords_obj_array = np.issubdtype(field_coords.dtype, np.object_)
196 fields_obj_array = np.issubdtype(next(iter(field_values.values())).dtype, np.object_)
197 def _iterate_coords_and_fields():
198 if coords_obj_array:
199 for index, c in np.ndenumerate(field_coords): # assumes same number of coords and field values
200 yield index, c, {qoi: field_values[qoi][index] for qoi in field_values}
201 elif fields_obj_array:
202 for index in np.ndindex(next(iter(field_values.values())).shape):
203 yield index, field_coords, {qoi: field_values[qoi][index] for qoi in field_values}
204 else:
205 yield (0,), field_coords, field_values # assumes same coords for all field values
207 if coords_obj_array:
208 shape = field_coords.shape
209 elif fields_obj_array:
210 shape = next(iter(field_values.values())).shape
211 else:
212 shape = (1,)
214 always_skip_interp = not coords_obj_array and np.array_equal(field_coords, grid_coords) # must be exact match
216 all_states = np.empty(shape, dtype=object) # are you in good hands?
218 for j, f_coords, f_values in _iterate_coords_and_fields():
219 skip_interp = always_skip_interp or np.array_equal(f_coords, grid_coords) # exact even for floats
221 coords_shape = f_coords.shape[:-1]
222 loop_shape = next(iter(f_values.values())).shape[:-len(coords_shape)]
223 states = np.empty((*loop_shape, self.num_pts, self.num_qoi))
224 f_coords = f_coords.reshape(-1, self.dim)
225 for i, qoi in enumerate(self.fields):
226 field_vals = f_values[qoi].reshape((*loop_shape, -1)) # (..., Q)
227 if skip_interp:
228 states[..., i] = field_vals
229 else:
230 field_vals = field_vals.reshape((-1, field_vals.shape[-1])).T # (Q, ...)
231 interp = self.interpolator()(f_coords, field_vals, **self.interpolate_opts)
232 yg = interp(grid_coords)
233 states[..., i] = yg.T.reshape(*loop_shape, self.num_pts)
234 all_states[j] = states.reshape((*loop_shape, self.dof))
236 # All fields now on the same dof grid, so stack them in same array
237 index = next(np.ndindex(all_states.shape))
238 ret_states = np.empty(shape + all_states[index].shape)
240 for index, arr in np.ndenumerate(all_states):
241 ret_states[index] = arr
243 if not (coords_obj_array or fields_obj_array):
244 ret_states = np.squeeze(ret_states, axis=0) # artificial leading dim for non-object arrays
246 return ret_states
    @abstractmethod
    def compute_map(self, **kwargs):
        """Compute and store the compression map. Must set the values of `coords` and `_map_computed`. Should
        use the same normalization as the parent `Variable` object.

        !!! Note
            You should pass any required data to `compute_map` with the assumption that the data will be used in the
            shape `(num_pts, num_qoi)`, where `num_qoi` is the length of `fields` and `num_pts` is the length of
            `coords`. This is the shape that the compression map should be constructed in.
        """
        raise NotImplementedError

    @abstractmethod
    def compress(self, data: np.ndarray) -> np.ndarray:
        """Compress the data into a latent space.

        :param data: `(..., dof)` - the data to compress from the full size of `dof`
        :return: `(..., rank)` - the compressed latent space data with size `rank`
        """
        raise NotImplementedError

    @abstractmethod
    def reconstruct(self, compressed: np.ndarray) -> np.ndarray:
        """Reconstruct the compressed data back into the full `dof` space.

        :param compressed: `(..., rank)` - the compressed data to reconstruct
        :return: `(..., dof)` - the reconstructed data with full `dof`
        """
        raise NotImplementedError

    @abstractmethod
    def latent_size(self) -> int:
        """Return the size of the latent space."""
        raise NotImplementedError

    @abstractmethod
    def estimate_latent_ranges(self) -> list[tuple[float, float]]:
        """Estimate the range of the latent space coefficients."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, spec: dict) -> Compression:
        """Construct a `Compression` object from a spec dictionary.
        """
        method = spec.pop('method', 'svd').lower()
        match method:
            case 'svd':
                return SVD(**spec)
            case other:
                raise NotImplementedError(f"Compression method '{other}' is not implemented.")


@dataclass
class SVD(Compression):
    """A Singular Value Decomposition (SVD) compression method. The SVD will be computed on initialization if the
    `data_matrix` is provided.

    :ivar data_matrix: `(dof, num_samples)` - the data matrix
    :ivar projection_matrix: `(dof, rank)` - the projection matrix
    :ivar rank: the rank of the SVD decomposition
    :ivar energy_tol: the energy tolerance of the SVD decomposition
    :ivar reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
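
    !!! Example
        A minimal sketch with a fixed rank (illustrative data only):

        ```python
        import numpy as np

        data = np.random.rand(100, 50)  # (dof, num_samples)
        svd = SVD(fields=['p'], coords=np.linspace(0, 1, 100), data_matrix=data, rank=10)

        svd.projection_matrix.shape     # (100, 10), i.e. (dof, rank)
        ```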
309 """
310 data_matrix: np.ndarray = None # (dof, num_samples)
311 projection_matrix: np.ndarray = None # (dof, rank)
312 rank: int = None
313 energy_tol: float = None
314 reconstruction_tol: float = None
316 def __post_init__(self):
317 """Compute the SVD if the data matrix is provided."""
318 if (data_matrix := self.data_matrix) is not None:
319 self.compute_map(data_matrix, rank=self.rank, energy_tol=self.energy_tol,
320 reconstruction_tol=self.reconstruction_tol)
    def compute_map(self, data_matrix: np.ndarray | dict, rank: int = None, energy_tol: float = None,
                    reconstruction_tol: float = None):
        """Compute the SVD compression map from the data matrix. Recall that `dof` is the total number of degrees of
        freedom, equal to the number of grid points `num_pts` times the number of quantities of interest `num_qoi`
        at each grid point.

        **Rank priority:** if `rank` is provided, then it will be used. Otherwise, if `reconstruction_tol` is
        provided, then the rank will be chosen to meet this reconstruction error level. Finally, if `energy_tol` is
        provided, then the rank will be chosen to meet this energy fraction level (sum of squared singular values).

        :param data_matrix: `(dof, num_samples)` - the data matrix. If passed in as a `dict`, then the data matrix
                            will be formed by concatenating the values of the `dict` along the last axis in the order
                            of the `fields` attribute and flattening the last two axes. This is useful for passing
                            in a dictionary of field values like `{field1: (num_samples, num_pts), field2: ...}`,
                            which ensures consistency of shape with the compression `coords`.
        :param rank: the rank of the SVD decomposition
        :param energy_tol: the energy tolerance of the SVD decomposition (defaults to 0.95)
        :param reconstruction_tol: the reconstruction error tolerance of the SVD decomposition
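
        !!! Example
            A minimal sketch of the `dict` input form (illustrative data only); with `reconstruction_tol` given
            and `rank` omitted, the smallest rank meeting the error level is chosen:

            ```python
            import numpy as np

            svd = SVD(fields=['ux', 'uy'], coords=np.linspace(0, 1, 50))
            values = {'ux': np.random.rand(30, 50), 'uy': np.random.rand(30, 50)}  # (num_samples, num_pts) each
            svd.compute_map(values, reconstruction_tol=1e-3)
            svd.rank, svd.reconstruction_tol
            ```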
340 """
341 if isinstance(data_matrix, dict):
342 data_matrix = np.concatenate([data_matrix[field][..., np.newaxis] for field in self.fields], axis=-1)
343 data_matrix = data_matrix.reshape(*data_matrix.shape[:-2], -1).T # (dof, num_samples)
345 nan_idx = np.any(np.isnan(data_matrix), axis=0)
346 data_matrix = data_matrix[:, ~nan_idx]
347 u, s, vt = np.linalg.svd(data_matrix)
348 energy_frac = np.cumsum(s ** 2 / np.sum(s ** 2))
349 if rank := (rank or self.rank):
350 energy_tol = energy_frac[rank - 1]
351 reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)
352 elif reconstruction_tol := (reconstruction_tol or self.reconstruction_tol):
353 rank = u.shape[1]
354 for r in range(1, u.shape[1] + 1):
355 if relative_error(u[:, :r] @ u[:, :r].T @ data_matrix, data_matrix) <= reconstruction_tol:
356 rank = r
357 break
358 energy_tol = energy_frac[rank - 1]
359 else:
360 energy_tol = energy_tol or self.energy_tol or 0.95
361 idx = int(np.where(energy_frac >= energy_tol)[0][0])
362 rank = idx + 1
363 reconstruction_tol = relative_error(u[:, :rank] @ u[:, :rank].T @ data_matrix, data_matrix)
365 self.data_matrix = data_matrix
366 self.rank = rank
367 self.energy_tol = energy_tol
368 self.reconstruction_tol = reconstruction_tol
369 self.projection_matrix = u[:, :rank] # (dof, rank)
370 self._map_computed = True
    def compress(self, data):
        return np.squeeze(self.projection_matrix.T @ data[..., np.newaxis], axis=-1)

    def reconstruct(self, compressed):
        return np.squeeze(self.projection_matrix @ compressed[..., np.newaxis], axis=-1)

    def latent_size(self):
        return self.rank

    def estimate_latent_ranges(self):
        if self.map_exists:
            latent_data = self.compress(self.data_matrix.T)  # (num_samples, rank)
            latent_min = np.min(latent_data, axis=0)
            latent_max = np.max(latent_data, axis=0)
            return [(lmin, lmax) for lmin, lmax in zip(latent_min, latent_max)]
        else:
            return None