Coverage for src/amisc/training.py: 88%

265 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-01-24 04:51 +0000

1"""Classes for storing and managing training data for surrogate models. The `TrainingData` interface also 

2specifies how new training data should be sampled over the input space (i.e. experimental design). 

3 

4Includes: 

5 

6- `TrainingData` — an interface for storing surrogate training data. 

7- `SparseGrid` — a class for storing training data in a sparse grid format. 

8""" 

9from __future__ import annotations 

10 

11import copy 

12import itertools 

13from abc import ABC, abstractmethod 

14from dataclasses import dataclass, field 

15from typing import Any, ClassVar 

16 

17import numpy as np 

18from numpy.typing import ArrayLike 

19from scipy.optimize import direct 

20 

21from amisc.serialize import PickleSerializable, Serializable 

22from amisc.typing import LATENT_STR_ID, Dataset, MultiIndex 

23from amisc.utils import _RidgeRegression 

24 

__all__ = ["TrainingData", "SparseGrid"]  # public API of this module

26 

27 

28class TrainingData(Serializable, ABC): 

29 """Interface for storing and collecting surrogate training data. `TrainingData` objects should: 

30 

31 - `get` - retrieve the training data 

32 - `set` - store the training data 

33 - `refine` - generate new design points for the parent `Component` model 

34 - `clear` - clear all training data 

35 - `set_errors` - store error information (if desired) 

36 - `impute_missing_data` - fill in missing values in the training data (if desired) 

37 """ 

38 

39 @abstractmethod 

40 def get(self, alpha: MultiIndex, beta: MultiIndex, y_vars: list[str] = None, 

41 skip_nan: bool = False) -> tuple[Dataset, Dataset]: 

42 """Return the training data for a given multi-index pair. 

43 

44 :param alpha: the model fidelity indices 

45 :param beta: the surrogate fidelity indices 

46 :param y_vars: the keys of the outputs to return (if `None`, return all outputs) 

47 :param skip_nan: skip any data points with remaining `nan` values if `skip_nan=True` 

48 :returns: `dicts` of model inputs `x_train` and outputs `y_train` 

49 """ 

50 raise NotImplementedError 

51 

52 @abstractmethod 

53 def set(self, alpha: MultiIndex, beta: MultiIndex, coords: list[Any], yi_dict: Dataset): 

54 """Store training data for a given multi-index pair. 

55 

56 :param alpha: the model fidelity indices 

57 :param beta: the surrogate fidelity indices 

58 :param coords: locations for storing the `yi` values in the underlying data structure 

59 :param yi_dict: a `dict` of model output `yi` values, each entry should be the same length as `coords` 

60 """ 

61 raise NotImplementedError 

62 

63 @abstractmethod 

64 def set_errors(self, alpha: MultiIndex, beta: MultiIndex, coords: list[Any], errors: list[dict]): 

65 """Store error information for a given multi-index pair (just pass if you don't care). 

66 

67 :param alpha: the model fidelity indices 

68 :param beta: the surrogate fidelity indices 

69 :param coords: locations for storing the error information in the underlying data structure 

70 :param errors: a list of error dictionaries, should be the same length as `coords` 

71 """ 

72 raise NotImplementedError 

73 

74 @abstractmethod 

75 def impute_missing_data(self, alpha: MultiIndex, beta: MultiIndex): 

76 """Impute missing values in the training data for a given multi-index pair (just pass if you don't care). 

77 

78 :param alpha: the model fidelity indices 

79 :param beta: the surrogate fidelity indices 

80 """ 

81 raise NotImplementedError 

82 

83 @abstractmethod 

84 def refine(self, alpha: MultiIndex, beta: MultiIndex, input_domains: dict[str, tuple], 

85 weight_fcns: dict[str, callable] = None) -> tuple[list[Any], Dataset]: 

86 """Return new design/training points for a given multi-index pair and their coordinates/locations in the 

87 `TrainingData` storage structure. 

88 

89 !!! Example 

90 ```python 

91 domains = {'x1': (0, 1), 'x2': (0, 1)} 

92 alpha, beta = (0, 1), (1, 1) 

93 coords, x_train = training_data.refine(alpha, beta, domains) 

94 y_train = my_model(x_train) 

95 training_data.set(alpha, beta, coords, y_train) 

96 ``` 

97 

98 The returned data coordinates `coords` should be any object that can be used to locate the corresponding 

99 `x_train` training points in the `TrainingData` storage structure. These `coords` will be passed back to the 

100 `set` function to store the training data at a later time (i.e. after model evaluation). 

101 

102 :param alpha: the model fidelity indices 

103 :param beta: the surrogate fidelity indices 

104 :param input_domains: a `dict` specifying domain bounds for each input variable 

105 :param weight_fcns: a `dict` of weighting functions for each input variable 

106 :returns: a list of new data coordinates `coords` and the corresponding training points `x_train` 

107 """ 

108 raise NotImplementedError 

109 

110 @abstractmethod 

111 def clear(self): 

112 """Clear all training data.""" 

113 raise NotImplementedError 

114 

115 @classmethod 

116 def from_dict(cls, config: dict) -> TrainingData: 

117 """Create a `TrainingData` object from a `dict` configuration. Currently, only `method='sparse-grid'` is 

118 supported for the `SparseGrid` class. 

119 """ 

120 method = config.pop('method', 'sparse-grid').lower() 

121 match method: 

122 case 'sparse-grid': 

123 return SparseGrid(**config) 

124 case other: 

125 raise NotImplementedError(f"Unknown training data method: {other}") 

126 

127 

128@dataclass 

129class SparseGrid(TrainingData, PickleSerializable): 

130 """A class for storing training data in a sparse grid format. The `SparseGrid` class stores training points 

131 by their coordinate location in a larger tensor-product grid, and obtains new training data by refining 

132 a single 1d grid at a time. 

133 

134 !!! Note "MISC and sparse grids" 

135 MISC itself can be thought of as an extension to the well-known sparse grid technique, so this class 

136 readily integrates with the MISC implementation in `Component`. Sparse grids limit the curse 

137 of dimensionality up to about `dim = 10-15` for the input space (which would otherwise be infeasible with a 

138 normal full tensor-product grid of the same size). 

139 

140 !!! Info "About points in a sparse grid" 

141 A sparse grid approximates a full tensor-product grid $(N_1, N_2, ..., N_d)$, where $N_i$ is the number of grid 

142 points along dimension $i$, for a $d$-dimensional space. Each point is uniquely identified in the sparse grid 

143 by a list of indices $(j_1, j_2, ..., j_d)$, where $j_i = 0 ... N_i$. We refer to this unique identifier as a 

144 "grid coordinate". In the `SparseGrid` data structure, these coordinates are used along with the `alpha` 

145 fidelity index to uniquely locate the training data for a given multi-index pair. 

146 

147 :ivar collocation_rule: the collocation rule to use for generating new grid points (only 'leja' is supported) 

148 :ivar knots_per_level: the number of grid knots/points per level in the `beta` fidelity multi-index 

149 :ivar expand_latent_method: method for expanding latent grids, either 'round-robin' or 'tensor-product' 

150 :ivar opt_args: extra arguments for the global 1d `direct` optimizer 

151 :ivar betas: a set of all `beta` multi-indices that have been seen so far 

152 :ivar x_grids: a `dict` of grid points for each 1d input dimension 

153 :ivar yi_map: a `dict` of model outputs for each grid coordinate 

154 :ivar yi_nan_map: a `dict` of imputed model outputs for each grid coordinate where the model failed (or gave nan) 

155 :ivar error_map: a `dict` of error information for each grid coordinate where the model failed 

156 :ivar latent_size: the number of latent coefficients for each variable (0 if scalar) 

157 """ 

158 MAX_IMPUTE_SIZE: ClassVar[int] = 10 # don't try to impute large arrays 

159 

160 collocation_rule: str = 'leja' 

161 knots_per_level: int = 2 

162 expand_latent_method: str = 'round-robin' # or 'tensor-product', for converting beta to latent grid sizes 

163 opt_args: dict = field(default_factory=lambda: {'locally_biased': False, 'maxfun': 300}) # for leja optimizer 

164 

165 betas: set[MultiIndex] = field(default_factory=set) 

166 x_grids: dict[str, ArrayLike] = field(default_factory=dict) 

167 yi_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, ArrayLike]]] = field(default_factory=dict) 

168 yi_nan_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, ArrayLike]]] = field(default_factory=dict) 

169 error_map: dict[MultiIndex, dict[tuple[int, ...], dict[str, Any]]] = field(default_factory=dict) 

170 latent_size: dict[str, int] = field(default_factory=dict) # keep track of latent grid sizes for each variable 

171 

172 def clear(self): 

173 """Clear all training data.""" 

174 self.betas.clear() 

175 self.x_grids.clear() 

176 self.yi_map.clear() 

177 self.yi_nan_map.clear() 

178 self.error_map.clear() 

179 self.latent_size.clear() 

180 

181 def get_by_coord(self, alpha: MultiIndex, coords: list, y_vars: list = None, skip_nan: bool = False): 

182 """Get training data from the sparse grid for a given `alpha` and list of grid coordinates. Try to replace 

183 `nan` values with imputed values. Skip any data points with remaining `nan` values if `skip_nan=True`. 

184 

185 :param alpha: the model fidelity indices 

186 :param coords: a list of grid coordinates to locate the `yi` values in the sparse grid data structure 

187 :param y_vars: the keys of the outputs to return (if `None`, return all outputs) 

188 :param skip_nan: skip any data points with remaining `nan` values if `skip_nan=True` (only for numeric outputs) 

189 :returns: `dicts` of model inputs `xi_dict` and outputs `yi_dict` 

190 """ 

191 N = len(coords) 

192 is_numeric = {} 

193 is_singleton = {} 

194 xi_dict = self._extract_grid_points(coords) 

195 yi_dict = {} 

196 

197 first_yi = next(iter(self.yi_map[alpha].values())) 

198 if y_vars is None: 

199 y_vars = first_yi.keys() 

200 

201 for var in y_vars: 

202 yi = np.atleast_1d(first_yi[var]) 

203 is_numeric[var] = self._is_numeric(yi) 

204 is_singleton[var] = self._is_singleton(yi) 

205 yi_dict[var] = np.empty(N, dtype=np.float64 if is_numeric[var] and is_singleton[var] else object) 

206 

207 for i, coord in enumerate(coords): 

208 try: 

209 yi_curr = self.yi_map[alpha][coord] 

210 for var in y_vars: 

211 yi = arr if (arr := self.yi_nan_map[alpha].get(coord, {}).get(var)) is not None else yi_curr[var] 

212 yi_dict[var][i] = yi if is_singleton[var] else np.atleast_1d(yi) 

213 

214 except KeyError as e: 

215 raise KeyError(f"Can't access sparse grid data for alpha={alpha}, coord={coord}. " 

216 f"Make sure the data has been set first.") from e 

217 

218 # Delete nans if requested (only for numeric singleton outputs) 

219 if skip_nan: 

220 nan_idx = np.full(N, False) 

221 for var in y_vars: 

222 if is_numeric[var] and is_singleton[var]: 

223 nan_idx |= np.isnan(yi_dict[var]) 

224 

225 xi_dict = {k: v[~nan_idx] for k, v in xi_dict.items()} 

226 yi_dict = {k: v[~nan_idx] for k, v in yi_dict.items()} 

227 

228 return xi_dict, yi_dict # Both with elements of shape (N, ...) for N grid points 

229 

230 def get(self, alpha: MultiIndex, beta: MultiIndex, y_vars: list[str] = None, skip_nan: bool = False): 

231 """Get the training data from the sparse grid for a given `alpha` and `beta` pair.""" 

232 return self.get_by_coord(alpha, list(self._expand_grid_coords(beta)), y_vars=y_vars, skip_nan=skip_nan) 

233 

234 def set_errors(self, alpha: MultiIndex, beta: MultiIndex, coords: list, errors: list[dict]): 

235 """Store error information in the sparse-grid for a given multi-index pair.""" 

236 for coord, error in zip(coords, errors): 

237 self.error_map[alpha][coord] = copy.deepcopy(error) 

238 

239 def set(self, alpha: MultiIndex, beta: MultiIndex, coords: list, yi_dict: dict[str, ArrayLike]): 

240 """Store model output `yi_dict` values. 

241 

242 :param alpha: the model fidelity indices 

243 :param beta: the surrogate fidelity indices 

244 :param coords: a list of grid coordinates to locate the `yi` values in the sparse grid data structure 

245 :param yi_dict: a `dict` of model output `yi` values 

246 """ 

247 for i, coord in enumerate(coords): # First dim of yi is loop dim aligning with coords 

248 new_yi = {} 

249 for var, yi in yi_dict.items(): 

250 yi = np.atleast_1d(yi[i]) 

251 new_yi[var] = (float(yi[0]) if self._is_numeric(yi) else yi[0]) if self._is_singleton(yi) else yi.tolist() # noqa: E501 

252 self.yi_map[alpha][coord] = copy.deepcopy(new_yi) 

253 

254 def impute_missing_data(self, alpha: MultiIndex, beta: MultiIndex): 

255 """Impute missing values in the sparse grid for a given multi-index pair by linear regression imputation.""" 

256 imputer, xi_all, yi_all = None, None, None 

257 

258 # only impute (small-length) numeric quantities 

259 yi_dict = next(iter(self.yi_map[alpha].values())) 

260 output_vars = [var for var in self._numeric_outputs(yi_dict) 

261 if len(np.ravel(yi_dict[var])) <= self.MAX_IMPUTE_SIZE] 

262 

263 for coord, yi_dict in self.yi_map[alpha].items(): 

264 if any([np.any(np.isnan(yi_dict[var])) for var in output_vars]): 

265 if imputer is None: 

266 # Grab all 'good' interpolation points and train a simple linear regression fit 

267 xi_all, yi_all = self.get(alpha, beta, y_vars=output_vars, skip_nan=True) 

268 if len(xi_all) == 0 or len(next(iter(xi_all.values()))) == 0: 

269 continue # possible if no good data has been set yet 

270 

271 N = next(iter(xi_all.values())).shape[0] # number of grid points 

272 xi_mat = np.concatenate([xi_all[var][:, np.newaxis] if len(xi_all[var].shape) == 1 else 

273 xi_all[var] for var in xi_all.keys()], axis=-1) 

274 yi_mat = np.concatenate([yi_all[var][:, np.newaxis] if len(yi_all[var].shape) == 1 else 

275 yi_all[var].reshape((N, -1)) for var in output_vars], axis=-1) 

276 

277 imputer = _RidgeRegression(alpha=1.0) 

278 imputer.fit(xi_mat, yi_mat) 

279 

280 # Run the imputer for this coordinate 

281 x_interp = self._extract_grid_points(coord) 

282 x_interp = np.concatenate([x_interp[var][:, np.newaxis] if len(x_interp[var].shape) == 1 else 

283 x_interp[var] for var in x_interp.keys()], axis=-1) 

284 y_interp = imputer.predict(x_interp) 

285 

286 # Unpack the imputed value 

287 y_impute = {} 

288 start_idx = 0 

289 for var in output_vars: 

290 var_shape = yi_all[var].shape[1:] or (1,) 

291 end_idx = start_idx + int(np.prod(var_shape)) 

292 yi = np.atleast_1d(y_interp[0, start_idx:end_idx]).reshape(var_shape) 

293 nan_idx = np.isnan(np.atleast_1d(yi_dict[var])) 

294 yi[~nan_idx] = np.atleast_1d(yi_dict[var])[~nan_idx] # Only keep imputed values where yi is nan 

295 y_impute[var] = float(yi[0]) if self._is_singleton(yi) else yi.tolist() 

296 start_idx = end_idx 

297 

298 self.yi_nan_map[alpha][coord] = copy.deepcopy(y_impute) 

299 

300 def refine(self, alpha: MultiIndex, beta: MultiIndex, input_domains: dict, weight_fcns: dict = None): 

301 """Refine the sparse grid for a given `alpha` and `beta` pair and given collocation rules. Return any new 

302 grid points that do not have model evaluations saved yet. 

303 

304 !!! Note 

305 The `beta` multi-index is used to determine the number of collocation points in each input dimension. The 

306 length of `beta` should therefore match the number of variables in `x_vars`. 

307 """ 

308 weight_fcns = weight_fcns or {} 

309 

310 # Initialize a sparse grid for beta=(0, 0, ..., 0) 

311 if np.sum(beta) == 0: 

312 if len(self.x_grids) == 0: 

313 num_latent = {} 

314 for var in input_domains: 

315 if LATENT_STR_ID in var: 

316 base_id = var.split(LATENT_STR_ID)[0] 

317 num_latent[base_id] = 1 if base_id not in num_latent else num_latent[base_id] + 1 

318 else: 

319 num_latent[var] = 0 

320 self.latent_size = num_latent 

321 

322 new_pt = {} 

323 domains = iter(input_domains.items()) 

324 for grid_size in self.beta_to_knots(beta): 

325 if isinstance(grid_size, int): # scalars 

326 var, domain = next(domains) 

327 new_pt[var] = self.collocation_1d(grid_size, domain, method=self.collocation_rule, 

328 wt_fcn=weight_fcns.get(var, None), 

329 opt_args=self.opt_args).tolist() 

330 else: # latent coeffs 

331 for s in grid_size: 

332 var, domain = next(domains) 

333 new_pt[var] = self.collocation_1d(s, domain, method=self.collocation_rule, 

334 wt_fcn=weight_fcns.get(var, None), 

335 opt_args=self.opt_args).tolist() 

336 self.x_grids = new_pt 

337 self.betas.add(beta) 

338 self.yi_map.setdefault(alpha, dict()) 

339 self.yi_nan_map.setdefault(alpha, dict()) 

340 self.error_map.setdefault(alpha, dict()) 

341 new_coords = list(self._expand_grid_coords(beta)) 

342 return new_coords, self._extract_grid_points(new_coords) 

343 

344 # Otherwise, refine the sparse grid 

345 for beta_old in self.betas: 

346 # Get the first lower neighbor in the sparse grid and refine the 1d grid if necessary 

347 if self.is_one_level_refinement(beta_old, beta): 

348 new_grid_size = self.beta_to_knots(beta) 

349 inputs = zip(self.x_grids.keys(), self.x_grids.values(), input_domains.values()) 

350 

351 for new_size in new_grid_size: 

352 if isinstance(new_size, int): # scalar grid 

353 var, grid, domain = next(inputs) 

354 if len(grid) < new_size: 

355 num_new_pts = new_size - len(grid) 

356 self.x_grids[var] = self.collocation_1d(num_new_pts, domain, grid, opt_args=self.opt_args, 

357 wt_fcn=weight_fcns.get(var, None), 

358 method=self.collocation_rule).tolist() 

359 else: # latent grid 

360 for s_new in new_size: 

361 var, grid, domain = next(inputs) 

362 if len(grid) < s_new: 

363 num_new_pts = s_new - len(grid) 

364 self.x_grids[var] = self.collocation_1d(num_new_pts, domain, grid, 

365 opt_args=self.opt_args, 

366 wt_fcn=weight_fcns.get(var, None), 

367 method=self.collocation_rule).tolist() 

368 break 

369 

370 new_coords = [] 

371 for coord in self._expand_grid_coords(beta): 

372 if coord not in self.yi_map[alpha]: 

373 # If we have not computed this grid coordinate yet 

374 new_coords.append(coord) 

375 

376 new_pts = self._extract_grid_points(new_coords) 

377 

378 self.betas.add(beta) 

379 return new_coords, new_pts 

380 

381 def _extract_grid_points(self, coords: list[tuple] | tuple): 

382 """Extract the `x` grid points located at `coords` from `x_grids` and return as the `pts` dictionary.""" 

383 if not isinstance(coords, list): 

384 coords = [coords] 

385 pts = {var: np.empty(len(coords)) for var in self.x_grids} 

386 

387 for k, coord in enumerate(coords): 

388 grids = iter(self.x_grids.items()) 

389 for idx in coord: 

390 if isinstance(idx, int): # scalar grid point 

391 var, grid = next(grids) 

392 pts[var][k] = grid[idx] 

393 else: # latent coefficients 

394 for i in idx: 

395 var, grid = next(grids) 

396 pts[var][k] = grid[i] 

397 

398 return pts 

399 

400 def _expand_grid_coords(self, beta: MultiIndex): 

401 """Iterable over all grid coordinates for a given `beta`, accounting for scalars and latent coefficients.""" 

402 grid_size = self.beta_to_knots(beta) 

403 grid_coords = [] 

404 for s in grid_size: 

405 if isinstance(s, int): # scalars 

406 grid_coords.append(range(s)) 

407 else: # latent coefficients 

408 grid_coords.append(itertools.product(*[range(latent_size) for latent_size in s])) 

409 

410 yield from itertools.product(*grid_coords) 

411 

412 @staticmethod 

413 def _is_singleton(arr: np.ndarray): 

414 return len(arr.shape) == 1 and arr.shape[0] == 1 

415 

416 @staticmethod 

417 def _is_numeric(arr: np.ndarray): 

418 return np.issubdtype(arr.dtype, np.number) 

419 

420 @classmethod 

421 def _numeric_outputs(cls, yi_dict: dict[str, ArrayLike]) -> list[str]: 

422 """Return a list of the output variables that have numeric data.""" 

423 output_vars = [] 

424 for var in yi_dict.keys(): 

425 try: 

426 if cls._is_numeric(np.atleast_1d(yi_dict[var])): 

427 output_vars.append(var) 

428 except Exception: 

429 continue 

430 return output_vars 

431 

432 @staticmethod 

433 def is_one_level_refinement(beta_old: tuple, beta_new: tuple) -> bool: 

434 """Check if a new `beta` multi-index is a one-level refinement from a previous `beta`. 

435 

436 !!! Example 

437 Refining from `(0, 1, 2)` to the new multi-index `(1, 1, 2)` is a one-level refinement. But refining to 

438 either `(2, 1, 2)` or `(1, 2, 2)` are not, since more than one refinement occurs at the same time. 

439 

440 :param beta_old: the starting multi-index 

441 :param beta_new: the new refined multi-index 

442 :returns: whether `beta_new` is a one-level refinement from `beta_old` 

443 """ 

444 level_diff = np.array(beta_new, dtype=int) - np.array(beta_old, dtype=int) 

445 ind = np.nonzero(level_diff)[0] 

446 return ind.shape[0] == 1 and level_diff[ind] == 1 

447 

448 def beta_to_knots(self, beta: MultiIndex, knots_per_level: int = None, latent_size: dict = None, 

449 expand_latent_method: str = None) -> tuple: 

450 """Convert a `beta` multi-index to the number of knots per dimension in the sparse grid. 

451 

452 :param beta: refinement level indices 

453 :param knots_per_level: level-to-grid-size multiplier, i.e. number of new points (or knots) for each beta level 

454 :param latent_size: the number of latent coefficients for each variable (0 if scalar); number of variables and 

455 order should match the `beta` multi-index 

456 :param expand_latent_method: method for expanding latent grids, either 'round-robin' or 'tensor-product' 

457 :returns: the number of knots/points per dimension for the sparse grid 

458 """ 

459 knots_per_level = knots_per_level or self.knots_per_level 

460 latent_size = latent_size or self.latent_size 

461 expand_latent_method = expand_latent_method or self.expand_latent_method 

462 

463 grid_size = [] 

464 for i, (var, num_latent) in enumerate(latent_size.items()): 

465 if num_latent > 0: 

466 match expand_latent_method: 

467 case 'round-robin': 

468 if beta[i] == 0: 

469 grid_size.append((1,) * num_latent) # initializes all latent grids to 1 

470 else: 

471 latent_refine_idx = (beta[i] - 1) % num_latent 

472 latent_refine_num = ((beta[i] - 1) // num_latent) + 1 

473 latent_beta = tuple([latent_refine_num] * (latent_refine_idx + 1) + 

474 [latent_refine_num - 1] * (num_latent - latent_refine_idx - 1)) 

475 latent_grid = [knots_per_level * latent_beta[j] + 1 for j in range(num_latent)] 

476 grid_size.append(tuple(latent_grid)) 

477 case 'tensor-product': 

478 grid_size.append((knots_per_level * beta[i] + 1,) * num_latent) 

479 case other: 

480 raise NotImplementedError(f"Unknown method for expanding latent grids: {other}") 

481 else: 

482 grid_size.append(knots_per_level * beta[i] + 1) 

483 

484 return tuple(grid_size) 

485 

486 @staticmethod 

487 def collocation_1d(N: int, z_bds: tuple, z_pts: np.ndarray = None, 

488 wt_fcn: callable = None, method='leja', opt_args=None) -> np.ndarray: 

489 """Find the next `N` points in the 1d sequence of `z_pts` using the provided collocation method. 

490 

491 :param N: number of new points to add to the sequence 

492 :param z_bds: bounds on the 1d domain 

493 :param z_pts: current univariate sequence `(Nz,)`, start at middle of `z_bds` if `None` 

494 :param wt_fcn: weighting function, uses a constant weight if `None`, callable as `wt_fcn(z)` 

495 :param method: collocation method to use, currently only 'leja' is supported 

496 :param opt_args: extra arguments for the global 1d `direct` optimizer 

497 :returns: the univariate sequence `z_pts` augmented by `N` new points 

498 """ 

499 opt_args = opt_args or {} 

500 if wt_fcn is None: 

501 wt_fcn = lambda z: 1 

502 if z_pts is None: 

503 z_pts = (z_bds[1] + z_bds[0]) / 2 

504 N = N - 1 

505 z_pts = np.atleast_1d(z_pts) 

506 

507 match method: 

508 case 'leja': 

509 # Construct Leja sequence by maximizing the Leja objective sequentially 

510 for i in range(N): 

511 obj_fun = lambda z: -wt_fcn(np.array(z)) * np.prod(np.abs(z - z_pts)) 

512 res = direct(obj_fun, [z_bds], **opt_args) # Use global DIRECT optimization over 1d domain 

513 z_star = res.x 

514 z_pts = np.concatenate((z_pts, z_star)) 

515 case other: 

516 raise NotImplementedError(f"Unknown collocation method: {other}") 

517 

518 return z_pts