diff --git a/src/psimpy/sampler/metropolis_hastings.py b/src/psimpy/sampler/metropolis_hastings.py index 8b88c28ecb919f890e3fe61a5cbde263c4d5bc14..4c513b705d477e48d63612d37fd527c570f17137 100644 --- a/src/psimpy/sampler/metropolis_hastings.py +++ b/src/psimpy/sampler/metropolis_hastings.py @@ -98,9 +98,6 @@ class MetropolisHastings: if (not symmetric) and (f_density is None): raise ValueError("f_density must be provided if asymmetric") - if (target is None) and (ln_target is None): - raise ValueError("Either target or ln_target must be provided") - self.ndim = ndim self.init_state = init_state self.f_sample = f_sample @@ -141,7 +138,12 @@ class MetropolisHastings: corresponding sample is the proposed new state (value 1) or the old state (value 0). `np.sum(mh_accept)/len(mh_accept)` thus gives the overall acceptance rate. - """ + """ + if (self.target is None) and (self.ln_target is None): + raise ValueError( + "Either target or ln_target must be provided before call the" + " sample method") + if self.ln_target is None: init_t = self.target( self.init_state, *self.args_target, **self.kwgs_target) diff --git a/tests/test_metropolis_hastings.py b/tests/test_metropolis_hastings.py index 8543c423c6813535e9facc86e9253ff38890ceba..1e400e28f43201e79a1d09f8150e6098d1571b6e 100644 --- a/tests/test_metropolis_hastings.py +++ b/tests/test_metropolis_hastings.py @@ -17,9 +17,7 @@ import shutil (1, np.array([1.5]), norm.rvs, uniform.pdf, None, np.array([[0,1]]), None, True), (2, np.array([-1,1]), multivariate_normal.rvs, multivariate_normal.pdf, - None, None, None, False), - (2, np.array([-1,1]), multivariate_normal.rvs, None, None, None, - None, True) + None, None, None, False) ] ) def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds, @@ -30,6 +28,13 @@ def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds, bounds=bounds, f_density=f_density, symmetric=symmetric) +def test_sample_ValueError(): + mh_sampler = MetropolisHastings(ndim=2, init_state=np.array([-1,1]), + f_sample=multivariate_normal.rvs, target=None, ln_target=None) + with pytest.raises(ValueError): + mh_sampler.sample(nsamples=10000) + + def test_sample_uniform_target(): ndim = 1 init_state = np.array([0.5])