Skip to content
Snippets Groups Projects
Commit 1a7fcaff authored by Hu Zhao's avatar Hu Zhao
Browse files

refactor: use check_bounds

parent 746df515
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ import sys
from abc import ABC
from abc import abstractmethod
from psimpy.sampler.metropolis_hastings import MetropolisHastings
from psimpy.utility.util_funcs import check_bounds
from typing import Union
from beartype.typing import Callable
from beartype import beartype
......@@ -76,12 +77,6 @@ class BayesInferenceBase(ABC):
Keyword arguments for `ln_pxl`.
"""
self.ndim = ndim
if bounds is not None:
if bounds.ndim != 2:
raise ValueError("bounds must be a 2d numpy array")
elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
raise ValueError("bounds must be of shape (ndim, 2)")
self.bounds = bounds
self.args_prior = () if args_prior is None else args_prior
......@@ -175,6 +170,8 @@ class GridEstimation(BayesInferenceBase):
if self.bounds is None:
raise ValueError("bounds must be provided for grid estimation")
else:
check_bounds(self.ndim, self.bounds)
steps = (self.bounds[:,1] - self.bounds[:,0]) / np.array(nbins)
starts = steps/2 + self.bounds[:,0]
......
import numpy as np
from scipy.spatial.distance import pdist
from psimpy.utility.util_funcs import scale_samples
from psimpy.utility.util_funcs import check_bounds, scale_samples
from typing import Union
from beartype import beartype
......@@ -31,7 +31,11 @@ class LHS:
Number of iterations if `criterion='maxmin'`.
"""
self.ndim = ndim
if bounds is not None:
check_bounds(ndim, bounds)
self.bounds = bounds
self.rng = np.random.default_rng(seed)
self.seed = seed
......
import numpy as np
import sys
from psimpy.utility.util_funcs import check_bounds
from typing import Union
from beartype.typing import Callable
from beartype import beartype
......@@ -10,12 +11,17 @@ class MetropolisHastings:
@beartype
def __init__(
self, ndim: int, init_state: np.ndarray, f_sample: Callable,
self,
ndim: int,
init_state: np.ndarray,
f_sample: Callable,
target: Union[Callable, None] = None,
ln_target: Union[Callable, None] = None,
bounds: Union[np.ndarray, None] = None,
f_density: Union[Callable, None] = None,
symmetric: bool = True, nburn: int = 0, nthin: int = 1,
symmetric: bool = True,
nburn: int = 0,
nthin: int = 1,
seed: Union[int, None] = None,
args_target: Union[list, None] = None,
kwgs_target: Union[dict, None] = None,
......@@ -83,10 +89,7 @@ class MetropolisHastings:
raise ValueError(f"init_state must contain ndim={ndim} elements")
if bounds is not None:
if bounds.ndim != 2:
raise ValueError("bounds must be a 2d numpy array")
elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
raise ValueError("bounds must be of shape (ndim, 2)")
check_bounds(ndim, bounds)
if not all([
init_state[i] >= bounds[i,0] and
......
from SALib.sample import saltelli
import numpy as np
from psimpy.utility.util_funcs import check_bounds
from typing import Union
from beartype import beartype
......@@ -31,6 +32,8 @@ class Saltelli:
"""
if bounds is None:
bounds = np.array([[0, 1] for i in range(ndim)])
else:
check_bounds(ndim, bounds)
self.problem = {
'num_vars': ndim,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment