Skip to content
Snippets Groups Projects
Commit 746df515 authored by Hu Zhao's avatar Hu Zhao
Browse files

feat: add check_bounds function

parent c999fe7d
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from beartype import beartype from beartype import beartype
@beartype
def check_bounds(ndim: int, bounds: np.ndarray) -> None:
"""Check if bounds are valid.
Parameters
----------
ndim : int
Parameter dimension.
bounds: numpy array
Bounds of the `ndim` parameters, where bounds[:, 0] and bounds[:, 1]
correspond to lower and upper bounds, respectively. Shape (ndim, 2).
"""
if bounds.ndim != 2:
raise ValueError("bounds must be a 2d numpy array")
elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
raise ValueError("bounds must be of shape (ndim, 2)")
lower_bounds = bounds[:,0]
upper_bounds = bounds[:,1]
if np.any(lower_bounds >= upper_bounds):
raise ValueError(
"Lower bounds must be smaller than corresponding upper bounds")
@beartype @beartype
def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray:
"""Scale samples from a unit hypercube to arbitrary `bounds`. """Scale samples from a unit hypercube to arbitrary `bounds`.
...@@ -23,22 +46,13 @@ def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: ...@@ -23,22 +46,13 @@ def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray:
if not samples.ndim == 2: if not samples.ndim == 2:
raise ValueError("samples must be a 2D array") raise ValueError("samples must be a 2D array")
if not bounds.ndim == 2: ndim = samples.shape[1]
raise ValueError("bounds must be a 2D array") check_bounds(ndim, bounds)
if not samples.shape[1] == bounds.shape[0]:
raise ValueError("The dimension of parameters in samples must match "
"that in bounds")
lower_bounds = bounds[:,0] lower_bounds = bounds[:,0]
upper_bounds = bounds[:,1] upper_bounds = bounds[:,1]
if np.any(lower_bounds >= upper_bounds):
raise ValueError("All lower_bounds (bounds[:,0]) must be smaller "
"than upper_bounds (bounds[:,1])")
scaled_samples = samples*(upper_bounds-lower_bounds) + lower_bounds scaled_samples = samples*(upper_bounds-lower_bounds) + lower_bounds
return scaled_samples return scaled_samples
\ No newline at end of file
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment