diff --git a/src/psimpy/inference/bayes_inference.py b/src/psimpy/inference/bayes_inference.py
index e56b25a94fc3f9065314e49659c666f2eacf75d4..31a3fc549b8b39bbd525d5bcf4c7bef0d438133f 100644
--- a/src/psimpy/inference/bayes_inference.py
+++ b/src/psimpy/inference/bayes_inference.py
@@ -1,16 +1,33 @@
 import numpy as np
 import sys
-from psimpy.sampler.metropolis_hastings import MetropolisHastings
 from abc import ABC
 from abc import abstractmethod
+from psimpy.sampler.metropolis_hastings import MetropolisHastings
+from typing import Union
+from beartype.typing import Callable
+from beartype import beartype
+
+_min_float = 10**(sys.float_info.min_10_exp)
+
 
 class BayesInferenceBase(ABC):
 
+    @beartype
     def __init__(
-        self, ndim, prior=None, likelihood=None, ln_prior=None,
-        ln_likelihood=None, ln_pxl=None, bounds=None, args_prior=None,
-        kwgs_prior=None, args_likelihood=None, kwgs_likelihood=None,
-        args_ln_pxl=None, kwgs_ln_pxl=None):
+        self,
+        ndim: int,
+        bounds: Union[np.ndarray, None] = None,
+        prior: Union[Callable, None] = None,
+        likelihood: Union[Callable, None] = None,
+        ln_prior: Union[Callable, None] = None,
+        ln_likelihood: Union[Callable, None] = None,
+        ln_pxl: Union[Callable, None] = None,
+        args_prior: Union[list, None] = None,
+        kwgs_prior: Union[dict, None] = None,
+        args_likelihood: Union[list, None] = None,
+        kwgs_likelihood: Union[dict, None]=None,
+        args_ln_pxl: Union[list, None] = None,
+        kwgs_ln_pxl: Union[dict, None]=None) -> None:
         """
         A base class to set up basic input for Bayesian inference.
     
@@ -18,30 +35,33 @@ class BayesInferenceBase(ABC):
         ----------
         ndim : int
             Parameter dimension.
-        prior : callable
+        bounds : numpy array
+            Upper and lower boundaries of each parameter. Shape (ndim, 2).
+            bounds[:, 0] - lower boundaries of each parameter.
+            bounds[:, 1] - upper boundaries of each parameter.
+        prior : Callable
             Prior probability density function.
             Call with `prior(x, *args_prior, **kwgs_prior)`.
-            Return the prior probability density value at `x`.
-        likelihood : callable
+            Return the prior probability density value at x, where x is a one
+            dimension numpy array of shape (ndim,).
+        likelihood : Callable
             Likelihood function.
             Call with `likelihood(x, *args_likelihood, **kwgs_likelihood)`.
-            Return the likelihood value at `x`.
-        ln_prior : callable
+            Return the likelihood value at x.
+        ln_prior : Callable
             Natural logarithm of prior probability density function.
             Call with `ln_prior(x, *args_prior, **kwgs_prior)`.
             Return the natural logarithm value of prior probability density
-            value at `x`.
-        ln_likelihood : callable
+            value at x.
+        ln_likelihood : Callable
             Natural logarithm of likelihood function.
             Call with `ln_likelihood(x, *args_likelihood, **kwgs_likelihood)`.
-            Return the natural logarithm value of likelihood value at `x`.
-        ln_pxl : callable
+            Return the natural logarithm value of likelihood value at x.
+        ln_pxl : Callable
             Natural logarithm of the product of prior times likelihood.
-        bounds : numpy array
-            Upper and lower boundaries of each parameter.
-            Shape `(ndim, 2)`.
-            `bounds[:, 0]` - lower boundaries of each parameter.
-            `bounds[:, 1]` - upper boundaries of each parameter.   
+            Call with `ln_pxl(x, *args_ln_pxl, **kwgs_ln_pxl)`.
+            Return the natural logarithm value of prior times the natural
+            logarithm value of likelihood at x.    
         args_prior : list, optional
             Positional arguments for `prior` or `ln_prior`.
         kwgs_prior: dict, optional
@@ -54,34 +74,37 @@ class BayesInferenceBase(ABC):
             Positional arguments for `ln_pxl`.
         kwgs_ln_pxl : dict, optional
             Keyword arguments for `ln_pxl`.
-        
         """
-        
         self.ndim = ndim
+
+        if bounds is not None:
+            if bounds.ndim != 2:
+                raise ValueError("bounds must be a 2d numpy array")
+            elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
+                raise ValueError("bounds must be of shape (ndim, 2)")
         self.bounds = bounds
 
         self.args_prior = () if args_prior is None else args_prior
-        self.args_likelihood = () if args_likelihood is None else args_likelihood
-        self.args_ln_pxl = () if args_ln_pxl is None else args_ln_pxl
-
         self.kwgs_prior = {} if kwgs_prior is None else kwgs_prior
+
+        self.args_likelihood = () if args_likelihood is None else args_likelihood
         self.kwgs_likelihood = {} if kwgs_likelihood is None else kwgs_likelihood
+
+        self.args_ln_pxl = () if args_ln_pxl is None else args_ln_pxl
         self.kwgs_ln_pxl = {} if kwgs_ln_pxl is None else kwgs_ln_pxl
 
         if ln_pxl is not None:
             self.ln_pxl = ln_pxl
-            self.args_ln_pxl = args_ln_pxl
-            self.kwgs_ln_pxl = kwgs_ln_pxl
         elif (ln_prior is not None) and (ln_likelihood is not None):
             self.ln_prior = ln_prior
             self.ln_likelihood = ln_likelihood
-            self.ln_pxl = self._ln_pxl1
+            self.ln_pxl = self._ln_pxl_1
             self.args_ln_pxl = ()
             self.kwgs_ln_pxl = {}
         elif (prior is not None) and (likelihood is not None):
             self.prior = prior
             self.likelihood = likelihood
-            self.ln_pxl = self._ln_pxl2
+            self.ln_pxl = self._ln_pxl_2
             self.args_ln_pxl = ()
             self.kwgs_ln_pxl = {}
         else:
@@ -91,28 +114,27 @@ class BayesInferenceBase(ABC):
                 " (2) `ln_prior` and `ln_likelihood`"
                 " (3) `prior` and `likelihood`"))
     
-    def _ln_pxl1(self, x):
+    def _ln_pxl_1(self, x: np.ndarray) -> float:
         """
         Construct natural logarithm of the product of prior and likelihood,
         given natural logarithm of prior function and natural logarithm of
         likelihood function.
         """
-        ln_pxl = \
-            self.ln_prior(x, *self.args_ln_prior, **self.kwgs_ln_prior) + \
-            self.ln_likelihood(x, *self.args_ln_likelihood, 
-                               **self.kwgs_ln_likelihood)
-        return ln_pxl
+        ln_pxl = self.ln_prior(x, *self.args_prior, **self.kwgs_prior) + \
+            self.ln_likelihood(x, *self.args_likelihood, **self.kwgs_likelihood)
+
+        return float(ln_pxl)
     
-    def _ln_pxl2(self, x):
+    def _ln_pxl_2(self, x: np.ndarray) -> float:
         """
         Construct natural logarithm of the product of prior and likelihood,
         given prior function and likelihood function.
         """
         pxl = self.prior(x, *self.args_prior, **self.kwgs_prior) * \
             self.likelihood(x, *self.args_likelihood, **self.kwgs_likelihood)
-        min_10_exp = sys.float_info.min_10_exp
+        ln_pxl = np.log(np.maximum(pxl, _min_float))
         
-        return np.log(np.maximum(pxl, 10**(min_10_exp)))
+        return float(ln_pxl)
 
     @abstractmethod
     def run(self, *args, **kwgs):
@@ -121,48 +143,56 @@ class BayesInferenceBase(ABC):
 
 class GridEstimation(BayesInferenceBase):
     
-    def run(self, nbins):
+    @beartype
+    def run(self, nbins: Union[int, list[int]]
+        ) -> tuple[np.ndarray, list[np.ndarray]]:
         """
         Use Grid approximation to estimate the posterior.
 
         Parameters
         ----------
-        nbins : list of ints
+        nbins : int or list of ints
             Number of bins for each parameter.
-            Contain `ndim` elements, one for each parameter.
+            If int, the same value is used for each parameter.
+            If list of int, it should be of length `ndim`.
         
         Returns
         -------
         posterior : numpy array
             Estimated posterior probability density values at grid points.
-            Shape `(nbins[0], nbins[1], ..., nbins[ndim])`.
+            Shape of (nbins[0], nbins[1], ..., nbins[ndim]).
         x_ndim : list of numpy array
-            Contain `ndim` 1d numpy arrays `x1`, `x2`,... Each `xi`
+            Contain `ndim` 1d numpy arrays x1, x2, ... Each xi
             is a 1d array of length `nbins[i]`, representing the 1d coordinate
             array along the i-th axis.
 
         """
+        if isinstance(nbins, int):
+            nbins = [nbins] * self.ndim
+        elif len(nbins) != self.ndim:
+            raise ValueError(
+                "nbins must be an integer or a list of ndim integers")
+
         if self.bounds is None:
-            raise ValueError("`bounds` must be provided for grid estimation")
-        
-        if len(nbins) != self.ndim:
-            raise ValueError("`nbins` must contain `ndim` integers")
+            raise ValueError("bounds must be provided for grid estimation")
         
         steps = (self.bounds[:,1] - self.bounds[:,0]) / np.array(nbins) 
         starts = steps/2 + self.bounds[:,0]
         stops = self.bounds[:,1]  
         
         x_ndim = [np.arange(starts[i], stops[i], steps[i])
-                  for i in range(self.ndim)]
+            for i in range(self.ndim)]
         xx_matrices = np.meshgrid(*x_ndim, indexing='ij')
-        xx_coords = np.vstack((np.ravel(matrix) for matrix in xx_matrices))
-        xx_coords = np.transpose(xx_coords) 
-
-        ln_unnorm_posterior = np.zeros(len(xx_coords))
-        for i in range(len(xx_coords)):
-            coord_i = xx_coords[i,:]
+        grid_point_coords = np.hstack(
+            tuple(matrix.reshape((-1,1)) for matrix in xx_matrices)
+            )
+        
+        n = len(grid_point_coords)
+        ln_unnorm_posterior = np.zeros(n)
+        for i in range(n):
+            point_i = grid_point_coords[i,:]
             ln_unnorm_posterior[i] = \
-                self.ln_pxl(coord_i, *self.args_ln_pxl, **self.kwgs_ln_pxl)
+                self.ln_pxl(point_i, *self.args_ln_pxl, **self.kwgs_ln_pxl)
         
         ln_unnorm_posterior = ln_unnorm_posterior.reshape(xx_matrices[0].shape)
         unnorm_posterior = np.exp(ln_unnorm_posterior)
@@ -173,7 +203,8 @@ class GridEstimation(BayesInferenceBase):
 
 class MetropolisHastingsEstimation(BayesInferenceBase):
 
-    def run(self, nsamples, mh_sampler):
+    def run(self, nsamples: int, mh_sampler: MetropolisHastings) -> tuple[
+        np.ndarray, np.ndarray]:
         """
         Use metropolis hastings sampling to draw samples from the posterior.
 
@@ -186,18 +217,27 @@ class MetropolisHastingsEstimation(BayesInferenceBase):
         Returns
         -------
         mh_samples : numpy array
-            Samples drawn from the posterior. Shape `(nsamples, ndim)`.
+            Samples drawn from the posterior. Shape of (nsamples, ndim).
         mh_accept : numpy array
-            An array of shape `nsamples`. Each element indicates whether the
-            corresponding sample is a proposed new state (1) or the old state
-            (0). `np.sum(mh_accept)/len(mh_accept)` thus gives the overall
+            An array of shape (nsamples,). Each element indicates whether the
+            corresponding sample is the proposed new state (value 1) or the old
+            state (value 0). `np.mean(mh_accept)` thus gives the overall
             acceptance rate.
-
         """
-        if mh_sampler.ndim != self.ndim or mh_sampler.bounds != self.bounds:
-            raise ValueError(
-                "The Metropolis Hastings sampler is not correctly set to"
-                " the Bayes inference problem")
+        if mh_sampler.ndim != self.ndim:
+            raise RuntimeError(
+                "ndim of mh_sampler and ndim defined in this class must be"
+                " consistent")
+        
+        if type(self.bounds) != type(mh_sampler.bounds):
+            raise RuntimeError(
+                "bounds of mh_sampler and bounds defined in this class must be"
+                " consistent")
+        elif isinstance(self.bounds, np.ndarray) and \
+            not np.all(np.equal(self.bounds, mh_sampler.bounds)):
+            raise RuntimeError(
+                "bounds of mh_sampler and bounds defined in this class must be"
+                " consistent")
 
         mh_sampler.ln_target = self.ln_pxl
         mh_sampler.args_target = self.args_ln_pxl