From 746df51514bf409a708ba4ed2b82d1de8e724cf2 Mon Sep 17 00:00:00 2001 From: Hu Zhao <zhao@mbd.rwth-aachen.de> Date: Thu, 27 Oct 2022 16:50:54 +0200 Subject: [PATCH] feat: add check_bounds function --- src/psimpy/utility/util_funcs.py | 40 +++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/psimpy/utility/util_funcs.py b/src/psimpy/utility/util_funcs.py index 7fad221..305eb81 100644 --- a/src/psimpy/utility/util_funcs.py +++ b/src/psimpy/utility/util_funcs.py @@ -1,6 +1,29 @@ import numpy as np from beartype import beartype +@beartype +def check_bounds(ndim: int, bounds: np.ndarray) -> None: + """Check if bounds are valid. + + Parameters + ---------- + ndim : int + Parameter dimension. + bounds: numpy array + Bounds of the `ndim` parameters, where bounds[:, 0] and bounds[:, 1] + correspond to lower and upper bounds, respectively. Shape (ndim, 2). + """ + if bounds.ndim != 2: + raise ValueError("bounds must be a 2d numpy array") + elif bounds.shape[0] != ndim or bounds.shape[1] != 2: + raise ValueError("bounds must be of shape (ndim, 2)") + + lower_bounds = bounds[:,0] + upper_bounds = bounds[:,1] + if np.any(lower_bounds >= upper_bounds): + raise ValueError( + "Lower bounds must be smaller than corresponding upper bounds") + @beartype def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: """Scale samples from a unit hypercube to arbitrary `bounds`. @@ -19,26 +42,17 @@ def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: ------- scaled_samples : np.ndarray (nsamples, ndim) - """ + """ if not samples.ndim == 2: raise ValueError("samples must be a 2D array") - - if not bounds.ndim == 2: - raise ValueError("bounds must be a 2D array") - if not samples.shape[1] == bounds.shape[0]: - raise ValueError("The dimension of parameters in samples must match " - "that in bounds") + ndim = samples.shape[1] + check_bounds(ndim, bounds) lower_bounds = bounds[:,0] - upper_bounds = bounds[:,1] - if np.any(lower_bounds >= upper_bounds): - raise ValueError("All lower_bounds (bounds[:,0]) must be smaller " - "than upper_bounds (bounds[:,1])") - + upper_bounds = bounds[:,1] scaled_samples = samples*(upper_bounds-lower_bounds) + lower_bounds return scaled_samples - \ No newline at end of file -- GitLab