diff --git a/src/psimpy/inference/bayes_inference.py b/src/psimpy/inference/bayes_inference.py index 31a3fc549b8b39bbd525d5bcf4c7bef0d438133f..99801540e09a458b764cb9f662b89bc537697975 100644 --- a/src/psimpy/inference/bayes_inference.py +++ b/src/psimpy/inference/bayes_inference.py @@ -3,6 +3,7 @@ import sys from abc import ABC from abc import abstractmethod from psimpy.sampler.metropolis_hastings import MetropolisHastings +from psimpy.utility.util_funcs import check_bounds from typing import Union from beartype.typing import Callable from beartype import beartype @@ -75,13 +76,7 @@ class BayesInferenceBase(ABC): kwgs_ln_pxl : dict, optional Keyword arguments for `ln_pxl`. """ - self.ndim = ndim - - if bounds is not None: - if bounds.ndim != 2: - raise ValueError("bounds must be a 2d numpy array") - elif bounds.shape[0] != ndim or bounds.shape[1] != 2: - raise ValueError("bounds must be of shape (ndim, 2)") + self.ndim = ndim self.bounds = bounds self.args_prior = () if args_prior is None else args_prior @@ -175,6 +170,8 @@ class GridEstimation(BayesInferenceBase): if self.bounds is None: raise ValueError("bounds must be provided for grid estimation") + else: + check_bounds(self.ndim, self.bounds) steps = (self.bounds[:,1] - self.bounds[:,0]) / np.array(nbins) starts = steps/2 + self.bounds[:,0] diff --git a/src/psimpy/sampler/latin.py b/src/psimpy/sampler/latin.py index 821f1c39bece60b5a21ce2a5871386923470d0b3..3804d80168e1eb275de39c8590913cf383b4a3fc 100644 --- a/src/psimpy/sampler/latin.py +++ b/src/psimpy/sampler/latin.py @@ -1,6 +1,6 @@ import numpy as np from scipy.spatial.distance import pdist -from psimpy.utility.util_funcs import scale_samples +from psimpy.utility.util_funcs import check_bounds, scale_samples from typing import Union from beartype import beartype @@ -31,7 +31,11 @@ class LHS: Number of iterations if `criterion='maxmin'`. """ self.ndim = ndim + + if bounds is not None: + check_bounds(ndim, bounds) self.bounds = bounds + self.rng = np.random.default_rng(seed) self.seed = seed diff --git a/src/psimpy/sampler/metropolis_hastings.py b/src/psimpy/sampler/metropolis_hastings.py index 5298c4e268037c4e5ba604affa39ef2b7939e666..dfb2d23fc406b867035c43c8c8f3b78d69377b6c 100644 --- a/src/psimpy/sampler/metropolis_hastings.py +++ b/src/psimpy/sampler/metropolis_hastings.py @@ -1,5 +1,6 @@ import numpy as np import sys +from psimpy.utility.util_funcs import check_bounds from typing import Union from beartype.typing import Callable from beartype import beartype @@ -10,12 +11,17 @@ class MetropolisHastings: @beartype def __init__( - self, ndim: int, init_state: np.ndarray, f_sample: Callable, + self, + ndim: int, + init_state: np.ndarray, + f_sample: Callable, target: Union[Callable, None] = None, ln_target: Union[Callable, None] = None, bounds: Union[np.ndarray, None] = None, f_density: Union[Callable, None] = None, - symmetric: bool = True, nburn: int = 0, nthin: int = 1, + symmetric: bool = True, + nburn: int = 0, + nthin: int = 1, seed: Union[int, None] = None, args_target: Union[list, None] = None, kwgs_target: Union[dict, None] = None, @@ -83,10 +89,7 @@ class MetropolisHastings: raise ValueError(f"init_state must contain ndim={ndim} elements") if bounds is not None: - if bounds.ndim != 2: - raise ValueError("bounds must be a 2d numpy array") - elif bounds.shape[0] != ndim or bounds.shape[1] != 2: - raise ValueError("bounds must be of shape (ndim, 2)") + check_bounds(ndim, bounds) if not all([ init_state[i] >= bounds[i,0] and diff --git a/src/psimpy/sampler/saltelli.py b/src/psimpy/sampler/saltelli.py index afb9f08a6b52f7e9d0d3a0b1d2fb1cfb97d0b858..f7e1de7cd1db33209b3e4b51d7ce5b9451b0f30b 100644 --- a/src/psimpy/sampler/saltelli.py +++ b/src/psimpy/sampler/saltelli.py @@ -1,5 +1,6 @@ from SALib.sample import saltelli import numpy as np +from psimpy.utility.util_funcs import check_bounds from typing import Union from beartype import beartype @@ -31,6 +32,8 @@ class Saltelli: """ if bounds is None: bounds = np.array([[0, 1] for i in range(ndim)]) + else: + check_bounds(ndim, bounds) self.problem = { 'num_vars': ndim,