Skip to content
Snippets Groups Projects
Commit 746df515 authored by Hu Zhao's avatar Hu Zhao
Browse files

feat: add check_bounds function

parent c999fe7d
Branches
Tags
No related merge requests found
import numpy as np import numpy as np
from beartype import beartype from beartype import beartype
@beartype
def check_bounds(ndim: int, bounds: np.ndarray) -> None:
"""Check if bounds are valid.
Parameters
----------
ndim : int
Parameter dimension.
bounds: numpy array
Bounds of the `ndim` parameters, where bounds[:, 0] and bounds[:, 1]
correspond to lower and upper bounds, respectively. Shape (ndim, 2).
"""
if bounds.ndim != 2:
raise ValueError("bounds must be a 2d numpy array")
elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
raise ValueError("bounds must be of shape (ndim, 2)")
lower_bounds = bounds[:,0]
upper_bounds = bounds[:,1]
if np.any(lower_bounds >= upper_bounds):
raise ValueError(
"Lower bounds must be smaller than corresponding upper bounds")
@beartype @beartype
def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray:
"""Scale samples from a unit hypercube to arbitrary `bounds`. """Scale samples from a unit hypercube to arbitrary `bounds`.
...@@ -23,22 +46,13 @@ def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray: ...@@ -23,22 +46,13 @@ def scale_samples(samples: np.ndarray, bounds: np.ndarray) -> np.ndarray:
if not samples.ndim == 2: if not samples.ndim == 2:
raise ValueError("samples must be a 2D array") raise ValueError("samples must be a 2D array")
if not bounds.ndim == 2: ndim = samples.shape[1]
raise ValueError("bounds must be a 2D array") check_bounds(ndim, bounds)
if not samples.shape[1] == bounds.shape[0]:
raise ValueError("The dimension of parameters in samples must match "
"that in bounds")
lower_bounds = bounds[:,0] lower_bounds = bounds[:,0]
upper_bounds = bounds[:,1] upper_bounds = bounds[:,1]
if np.any(lower_bounds >= upper_bounds):
raise ValueError("All lower_bounds (bounds[:,0]) must be smaller "
"than upper_bounds (bounds[:,1])")
scaled_samples = samples*(upper_bounds-lower_bounds) + lower_bounds scaled_samples = samples*(upper_bounds-lower_bounds) + lower_bounds
return scaled_samples return scaled_samples
\ No newline at end of file
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment