From 1a7fcaffafbc80e64c2b266303e6fe1152946d6a Mon Sep 17 00:00:00 2001
From: Hu Zhao <zhao@mbd.rwth-aachen.de>
Date: Thu, 27 Oct 2022 16:52:15 +0200
Subject: [PATCH] refactor: use check_bounds

---
 src/psimpy/inference/bayes_inference.py   | 11 ++++-------
 src/psimpy/sampler/latin.py               |  6 +++++-
 src/psimpy/sampler/metropolis_hastings.py | 15 +++++++++------
 src/psimpy/sampler/saltelli.py            |  3 +++
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/psimpy/inference/bayes_inference.py b/src/psimpy/inference/bayes_inference.py
index 31a3fc5..9980154 100644
--- a/src/psimpy/inference/bayes_inference.py
+++ b/src/psimpy/inference/bayes_inference.py
@@ -3,6 +3,7 @@ import sys
 from abc import ABC
 from abc import abstractmethod
 from psimpy.sampler.metropolis_hastings import MetropolisHastings
+from psimpy.utility.util_funcs import check_bounds
 from typing import Union
 from beartype.typing import Callable
 from beartype import beartype
@@ -75,13 +76,7 @@ class BayesInferenceBase(ABC):
         kwgs_ln_pxl : dict, optional
             Keyword arguments for `ln_pxl`.
         """
-        self.ndim = ndim
-
-        if bounds is not None:
-            if bounds.ndim != 2:
-                raise ValueError("bounds must be a 2d numpy array")
-            elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
-                raise ValueError("bounds must be of shape (ndim, 2)")
+        self.ndim = ndim            
         self.bounds = bounds
 
         self.args_prior = () if args_prior is None else args_prior
@@ -175,6 +170,8 @@ class GridEstimation(BayesInferenceBase):
 
         if self.bounds is None:
             raise ValueError("bounds must be provided for grid estimation")
+        else:
+            check_bounds(self.ndim, self.bounds)
         
         steps = (self.bounds[:,1] - self.bounds[:,0]) / np.array(nbins) 
         starts = steps/2 + self.bounds[:,0]
diff --git a/src/psimpy/sampler/latin.py b/src/psimpy/sampler/latin.py
index 821f1c3..3804d80 100644
--- a/src/psimpy/sampler/latin.py
+++ b/src/psimpy/sampler/latin.py
@@ -1,6 +1,6 @@
 import numpy as np
 from scipy.spatial.distance import pdist
-from psimpy.utility.util_funcs import scale_samples
+from psimpy.utility.util_funcs import check_bounds, scale_samples
 from typing import Union
 from beartype import beartype
 
@@ -31,7 +31,11 @@ class LHS:
             Number of iterations if `criterion='maxmin'`.
         """
         self.ndim = ndim
+
+        if bounds is not None:
+            check_bounds(ndim, bounds)
         self.bounds = bounds
+        
         self.rng = np.random.default_rng(seed)
         self.seed = seed
         
diff --git a/src/psimpy/sampler/metropolis_hastings.py b/src/psimpy/sampler/metropolis_hastings.py
index 5298c4e..dfb2d23 100644
--- a/src/psimpy/sampler/metropolis_hastings.py
+++ b/src/psimpy/sampler/metropolis_hastings.py
@@ -1,5 +1,6 @@
 import numpy as np
 import sys
+from psimpy.utility.util_funcs import check_bounds
 from typing import Union
 from beartype.typing import Callable
 from beartype import beartype
@@ -10,12 +11,17 @@ class MetropolisHastings:
     
     @beartype
     def __init__(
-        self, ndim: int, init_state: np.ndarray, f_sample: Callable,
+        self,
+        ndim: int,
+        init_state: np.ndarray,
+        f_sample: Callable,
         target: Union[Callable, None] = None,
         ln_target: Union[Callable, None] = None,
         bounds: Union[np.ndarray, None] = None,
         f_density: Union[Callable, None] = None,
-        symmetric: bool = True, nburn: int = 0, nthin: int = 1,
+        symmetric: bool = True,
+        nburn: int = 0,
+        nthin: int = 1,
         seed: Union[int, None] = None,
         args_target: Union[list, None] = None,
         kwgs_target: Union[dict, None] = None,
@@ -83,10 +89,7 @@ class MetropolisHastings:
             raise ValueError(f"init_state must contain ndim={ndim} elements")
         
         if bounds is not None:
-            if bounds.ndim != 2:
-                raise ValueError("bounds must be a 2d numpy array")
-            elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
-                raise ValueError("bounds must be of shape (ndim, 2)")
+            check_bounds(ndim, bounds)
 
             if not all([
                 init_state[i] >= bounds[i,0] and
diff --git a/src/psimpy/sampler/saltelli.py b/src/psimpy/sampler/saltelli.py
index afb9f08..f7e1de7 100644
--- a/src/psimpy/sampler/saltelli.py
+++ b/src/psimpy/sampler/saltelli.py
@@ -1,5 +1,6 @@
 from SALib.sample import saltelli
 import numpy as np
+from psimpy.utility.util_funcs import check_bounds
 from typing import Union
 from beartype import beartype
 
@@ -31,6 +32,8 @@ class Saltelli:
         """
         if bounds is None:
             bounds = np.array([[0, 1] for i in range(ndim)])
+        else:
+            check_bounds(ndim, bounds)
 
         self.problem = {
             'num_vars': ndim,
-- 
GitLab