Coverage for src/uqtils/uq_types.py: 91%
11 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 03:45 +0000
1"""Custom types"""
2import numpy as np
4Array = list | float | np.ndarray
6__all__ = ['Array', 'format_input']
9def format_input(x: Array, ndim: int) -> tuple[bool, np.ndarray]:
10 """Helper function to make sure input `x` is an `ndarray` of shape `(..., ndim)`.
12 :param x: if 1d-like as `(n,)`, then converted to 2d as `(1, n) if n==ndim or (n, 1) if ndim==1`
13 :param ndim: the dimension of the inputs
14 :returns: `x` as at least a 2d array `(..., ndim)`, and whether `x` was originally 1d-like
15 """
16 x = np.atleast_1d(x)
17 is_1d = len(x.shape) == 1
18 if is_1d:
19 if x.shape[0] != ndim and ndim > 1:
20 raise ValueError(f'Input x shape {x.shape} is incompatible with ndim of {ndim}')
21 x = np.expand_dims(x, axis=0 if x.shape[0] == ndim else 1)
23 return is_1d, x