Coverage for src/uqtils/uq_types.py: 91%

11 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 03:45 +0000

1"""Custom types""" 

2import numpy as np 

3 

# Type alias for the inputs accepted as "array-like": a list, a float, or a NumPy ndarray.
4Array = list | float | np.ndarray 

5 

# Public names exported by this module via `from ... import *`.
6__all__ = ['Array', 'format_input'] 

7 

8 

def format_input(x: Array, ndim: int) -> tuple[bool, np.ndarray]:
    """Helper function to make sure input `x` is an `ndarray` of shape `(..., ndim)`.

    :param x: if 1d-like as `(n,)`, then converted to 2d as `(1, n) if n==ndim or (n, 1) if ndim==1`;
              inputs that are already 2d or higher are returned unchanged (as an `ndarray`)
    :param ndim: the dimension of the inputs
    :returns: the tuple `(is_1d, x)`: whether `x` was originally 1d-like, followed by `x` as at least
              a 2d array
    :raises ValueError: if `x` is 1d-like with length `n != ndim` while `ndim > 1`
    """
    x = np.atleast_1d(x)  # scalars become shape (1,); lists become arrays
    is_1d = x.ndim == 1
    if is_1d:
        n = x.shape[0]
        # A flat vector is only ambiguous-free when it is a single point (n == ndim)
        # or a batch of scalar inputs (ndim == 1); anything else cannot be reshaped.
        if n != ndim and ndim > 1:
            raise ValueError(f'Input x shape {x.shape} is incompatible with ndim of {ndim}')
        # Single point -> (1, n); batch of scalars -> (n, 1)
        x = np.expand_dims(x, axis=0 if n == ndim else 1)

    return is_1d, x