diff --git a/src/psimpy/inference/bayes_inference.py b/src/psimpy/inference/bayes_inference.py
index e56b25a94fc3f9065314e49659c666f2eacf75d4..31a3fc549b8b39bbd525d5bcf4c7bef0d438133f 100644
--- a/src/psimpy/inference/bayes_inference.py
+++ b/src/psimpy/inference/bayes_inference.py
@@ -1,16 +1,33 @@
 import numpy as np
 import sys
-from psimpy.sampler.metropolis_hastings import MetropolisHastings
 from abc import ABC
 from abc import abstractmethod
+from psimpy.sampler.metropolis_hastings import MetropolisHastings
+from typing import Union
+from beartype.typing import Callable
+from beartype import beartype
+
+_min_float = 10**(sys.float_info.min_10_exp)
+
 
 class BayesInferenceBase(ABC):
 
+    @beartype
     def __init__(
-        self, ndim, prior=None, likelihood=None, ln_prior=None,
-        ln_likelihood=None, ln_pxl=None, bounds=None, args_prior=None,
-        kwgs_prior=None, args_likelihood=None, kwgs_likelihood=None,
-        args_ln_pxl=None, kwgs_ln_pxl=None):
+        self,
+        ndim: int,
+        bounds: Union[np.ndarray, None] = None,
+        prior: Union[Callable, None] = None,
+        likelihood: Union[Callable, None] = None,
+        ln_prior: Union[Callable, None] = None,
+        ln_likelihood: Union[Callable, None] = None,
+        ln_pxl: Union[Callable, None] = None,
+        args_prior: Union[list, None] = None,
+        kwgs_prior: Union[dict, None] = None,
+        args_likelihood: Union[list, None] = None,
+        kwgs_likelihood: Union[dict, None] = None,
+        args_ln_pxl: Union[list, None] = None,
+        kwgs_ln_pxl: Union[dict, None] = None) -> None:
         """
         A base class to set up basic input for Bayesian inference.
 
@@ -18,30 +35,33 @@ class BayesInferenceBase(ABC):
         ----------
         ndim : int
             Parameter dimension.
-        prior : callable
+        bounds : numpy array
+            Upper and lower boundaries of each parameter. Shape (ndim, 2).
+            bounds[:, 0] - lower boundaries of each parameter.
+            bounds[:, 1] - upper boundaries of each parameter.
+        prior : Callable
             Prior probability density function.
             Call with `prior(x, *args_prior, **kwgs_prior)`.
-            Return the prior probability density value at `x`.
-        likelihood : callable
+            Return the prior probability density value at x, where x is a
+            1d numpy array of shape (ndim,).
+        likelihood : Callable
             Likelihood function.
             Call with `likelihood(x, *args_likelihood, **kwgs_likelihood)`.
-            Return the likelihood value at `x`.
-        ln_prior : callable
+            Return the likelihood value at x.
+        ln_prior : Callable
             Natural logarithm of prior probability density function.
             Call with `ln_prior(x, *args_prior, **kwgs_prior)`.
             Return the natural logarithm value of prior probability density
-            value at `x`.
-        ln_likelihood : callable
+            value at x.
+        ln_likelihood : Callable
             Natural logarithm of likelihood function.
             Call with `ln_likelihood(x, *args_likelihood, **kwgs_likelihood)`.
-            Return the natural logarithm value of likelihood value at `x`.
-        ln_pxl : callable
+            Return the natural logarithm of the likelihood value at x.
+        ln_pxl : Callable
             Natural logarithm of the product of prior times likelihood.
-        bounds : numpy array
-            Upper and lower boundaries of each parameter.
-            Shape `(ndim, 2)`.
-            `bounds[:, 0]` - lower boundaries of each parameter.
-            `bounds[:, 1]` - upper boundaries of each parameter.
+            Call with `ln_pxl(x, *args_ln_pxl, **kwgs_ln_pxl)`.
+            Return the natural logarithm of the product of prior and
+            likelihood at x.
         args_prior : list, optional
             Positional arguments for `prior` or `ln_prior`.
         kwgs_prior: dict, optional
@@ -54,34 +74,40 @@ class BayesInferenceBase(ABC):
             Positional arguments for `ln_pxl`.
         kwgs_ln_pxl : dict, optional
             Keyword arguments for `ln_pxl`.
         """
         self.ndim = ndim
+
+        if bounds is not None:
+            if bounds.ndim != 2:
+                raise ValueError("bounds must be a 2d numpy array")
+            elif bounds.shape[0] != ndim or bounds.shape[1] != 2:
+                raise ValueError("bounds must be of shape (ndim, 2)")
         self.bounds = bounds
 
         self.args_prior = () if args_prior is None else args_prior
-        self.args_likelihood = () if args_likelihood is None else args_likelihood
-        self.args_ln_pxl = () if args_ln_pxl is None else args_ln_pxl
         self.kwgs_prior = {} if kwgs_prior is None else kwgs_prior
+
+        self.args_likelihood = () if args_likelihood is None else args_likelihood
         self.kwgs_likelihood = {} if kwgs_likelihood is None else kwgs_likelihood
+
+        self.args_ln_pxl = () if args_ln_pxl is None else args_ln_pxl
         self.kwgs_ln_pxl = {} if kwgs_ln_pxl is None else kwgs_ln_pxl
 
         if ln_pxl is not None:
             self.ln_pxl = ln_pxl
-            self.args_ln_pxl = args_ln_pxl
-            self.kwgs_ln_pxl = kwgs_ln_pxl
 
         elif (ln_prior is not None) and (ln_likelihood is not None):
             self.ln_prior = ln_prior
             self.ln_likelihood = ln_likelihood
-            self.ln_pxl = self._ln_pxl1
+            self.ln_pxl = self._ln_pxl_1
             self.args_ln_pxl = ()
             self.kwgs_ln_pxl = {}
 
         elif (prior is not None) and (likelihood is not None):
             self.prior = prior
             self.likelihood = likelihood
-            self.ln_pxl = self._ln_pxl2
+            self.ln_pxl = self._ln_pxl_2
             self.args_ln_pxl = ()
             self.kwgs_ln_pxl = {}
 
         else:
@@ -91,28 +117,27 @@ class BayesInferenceBase(ABC):
                 " (2) `ln_prior` and `ln_likelihood`"
                 " (3) `prior` and `likelihood`"))
 
-    def _ln_pxl1(self, x):
+    def _ln_pxl_1(self, x: np.ndarray) -> float:
         """
         Construct natural logarithm of the product of prior and likelihood,
         given natural logarithm of prior function and natural logarithm of
         likelihood function.
         """
-        ln_pxl = \
-            self.ln_prior(x, *self.args_ln_prior, **self.kwgs_ln_prior) + \
-            self.ln_likelihood(x, *self.args_ln_likelihood,
-                **self.kwgs_ln_likelihood)
-        return ln_pxl
+        ln_pxl = self.ln_prior(x, *self.args_prior, **self.kwgs_prior) + \
+            self.ln_likelihood(x, *self.args_likelihood, **self.kwgs_likelihood)
+
+        return float(ln_pxl)
 
-    def _ln_pxl2(self, x):
+    def _ln_pxl_2(self, x: np.ndarray) -> float:
         """
         Construct natural logarithm of the product of prior and likelihood,
         given prior function and likelihood function.
         """
         pxl = self.prior(x, *self.args_prior, **self.kwgs_prior) * \
             self.likelihood(x, *self.args_likelihood, **self.kwgs_likelihood)
-        min_10_exp = sys.float_info.min_10_exp
+        ln_pxl = np.log(np.maximum(pxl, _min_float))
 
-        return np.log(np.maximum(pxl, 10**(min_10_exp)))
+        return float(ln_pxl)
 
     @abstractmethod
     def run(self, *args, **kwgs):
@@ -121,48 +146,56 @@ class GridEstimation(BayesInferenceBase):
 
 class GridEstimation(BayesInferenceBase):
 
-    def run(self, nbins):
+    @beartype
+    def run(self, nbins: Union[int, list[int]]
+            ) -> tuple[np.ndarray, list[np.ndarray]]:
         """
         Use Grid approximation to estimate the posterior.
 
         Parameters
         ----------
-        nbins : list of ints
+        nbins : int or list of ints
             Number of bins for each parameter.
-            Contain `ndim` elements, one for each parameter.
+            If an int, the same value is used for each parameter.
+            If a list of ints, it must be of length `ndim`.
 
         Returns
         -------
         posterior : numpy array
             Estimated posterior probability density values at grid points.
-            Shape `(nbins[0], nbins[1], ..., nbins[ndim])`.
+            Shape of (nbins[0], nbins[1], ..., nbins[ndim-1]).
         x_ndim : list of numpy array
-            Contain `ndim` 1d numpy arrays `x1`, `x2`,... Each `xi`
+            Contain `ndim` 1d numpy arrays x1, x2, ...
             Each xi is a 1d array of length `nbins[i]`, representing the
             1d coordinate array along the i-th axis.
         """
+        if isinstance(nbins, int):
+            nbins = [nbins] * self.ndim
+        elif len(nbins) != self.ndim:
+            raise ValueError(
+                "nbins must be an integer or a list of ndim integers")
+
         if self.bounds is None:
-            raise ValueError("`bounds` must be provided for grid estimation")
-
-        if len(nbins) != self.ndim:
-            raise ValueError("`nbins` must contain `ndim` integers")
+            raise ValueError("bounds must be provided for grid estimation")
 
         steps = (self.bounds[:,1] - self.bounds[:,0]) / np.array(nbins)
         starts = steps/2 + self.bounds[:,0]
         stops = self.bounds[:,1]
         x_ndim = [np.arange(starts[i], stops[i], steps[i])
-            for i in range(self.ndim)]
+                  for i in range(self.ndim)]
 
         xx_matrices = np.meshgrid(*x_ndim, indexing='ij')
-        xx_coords = np.vstack((np.ravel(matrix) for matrix in xx_matrices))
-        xx_coords = np.transpose(xx_coords)
-
-        ln_unnorm_posterior = np.zeros(len(xx_coords))
-        for i in range(len(xx_coords)):
-            coord_i = xx_coords[i,:]
+        grid_point_coords = np.hstack(
+            tuple(matrix.reshape((-1,1)) for matrix in xx_matrices)
+        )
+
+        n = len(grid_point_coords)
+        ln_unnorm_posterior = np.zeros(n)
+        for i in range(n):
+            point_i = grid_point_coords[i,:]
             ln_unnorm_posterior[i] = \
-                self.ln_pxl(coord_i, *self.args_ln_pxl, **self.kwgs_ln_pxl)
+                self.ln_pxl(point_i, *self.args_ln_pxl, **self.kwgs_ln_pxl)
 
         ln_unnorm_posterior = ln_unnorm_posterior.reshape(xx_matrices[0].shape)
         unnorm_posterior = np.exp(ln_unnorm_posterior)
@@ -173,7 +206,8 @@ class GridEstimation(BayesInferenceBase):
 
 class MetropolisHastingsEstimation(BayesInferenceBase):
 
-    def run(self, nsamples, mh_sampler):
+    def run(self, nsamples: int, mh_sampler: MetropolisHastings) -> tuple[
+            np.ndarray, np.ndarray]:
         """
         Use metropolis hastings sampling to draw samples from the posterior.
 
@@ -186,18 +220,27 @@ class MetropolisHastingsEstimation(BayesInferenceBase):
         Returns
         -------
         mh_samples : numpy array
-            Samples drawn from the posterior. Shape `(nsamples, ndim)`.
+            Samples drawn from the posterior. Shape of (nsamples, ndim).
         mh_accept : numpy array
-            An array of shape `nsamples`. Each element indicates whether the
-            corresponding sample is a proposed new state (1) or the old state
-            (0). `np.sum(mh_accept)/len(mh_accept)` thus gives the overall
+            An array of shape (nsamples,). Each element indicates whether the
+            corresponding sample is the proposed new state (value 1) or the old
+            state (value 0). `np.mean(mh_accept)` thus gives the overall
             acceptance rate.
         """
-        if mh_sampler.ndim != self.ndim or mh_sampler.bounds != self.bounds:
-            raise ValueError(
-                "The Metropolis Hastings sampler is not correctly set to"
-                " the Bayes inference problem")
+        if mh_sampler.ndim != self.ndim:
+            raise RuntimeError(
+                "ndim of mh_sampler and ndim defined in this class must be"
+                " consistent")
+
+        if type(self.bounds) != type(mh_sampler.bounds):
+            raise RuntimeError(
+                "bounds of mh_sampler and bounds defined in this class must be"
+                " consistent")
+        elif isinstance(self.bounds, np.ndarray) and \
+                not np.all(np.equal(self.bounds, mh_sampler.bounds)):
+            raise RuntimeError(
+                "bounds of mh_sampler and bounds defined in this class must be"
+                " consistent")
 
         mh_sampler.ln_target = self.ln_pxl
         mh_sampler.args_target = self.args_ln_pxl
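
For review context, here is a minimal usage sketch of the `GridEstimation` API introduced above. It is not part of the diff: the uniform prior and Gaussian likelihood are illustrative assumptions, while the constructor and `run` signatures follow the hunks above.

```python
# Minimal sketch of GridEstimation usage (assumed example, not psimpy code).
import numpy as np
from psimpy.inference.bayes_inference import GridEstimation

ndim = 2
bounds = np.array([[-5.0, 5.0], [-5.0, 5.0]])  # shape (ndim, 2)

def ln_prior(x):
    # Uniform prior over `bounds`: constant, up to normalization.
    return 0.0

def ln_likelihood(x):
    # Standard bivariate Gaussian log density, up to an additive constant.
    return -0.5 * float(np.sum(x**2))

grid_estimation = GridEstimation(
    ndim, bounds, ln_prior=ln_prior, ln_likelihood=ln_likelihood)

# posterior has shape (50, 40); x_ndim holds the two 1d coordinate arrays.
posterior, x_ndim = grid_estimation.run(nbins=[50, 40])
```

`MetropolisHastingsEstimation.run` additionally requires a configured `MetropolisHastings` sampler whose `ndim` and `bounds` match those passed to the estimation class, per the consistency checks in the last hunk.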