Skip to content
Snippets Groups Projects
Commit 8f3e9eaa authored by Hu Zhao's avatar Hu Zhao
Browse files

feat: move check of target to sample method

parent 9b02dc5f
No related branches found
No related tags found
No related merge requests found
......@@ -98,9 +98,6 @@ class MetropolisHastings:
if (not symmetric) and (f_density is None):
raise ValueError("f_density must be provided if asymmetric")
if (target is None) and (ln_target is None):
raise ValueError("Either target or ln_target must be provided")
self.ndim = ndim
self.init_state = init_state
self.f_sample = f_sample
......@@ -142,6 +139,11 @@ class MetropolisHastings:
state (value 0). `np.sum(mh_accept)/len(mh_accept)` thus gives the
overall acceptance rate.
"""
if (self.target is None) and (self.ln_target is None):
raise ValueError(
"Either target or ln_target must be provided before call the"
" sample method")
if self.ln_target is None:
init_t = self.target(
self.init_state, *self.args_target, **self.kwgs_target)
......
......@@ -17,9 +17,7 @@ import shutil
(1, np.array([1.5]), norm.rvs, uniform.pdf, None, np.array([[0,1]]),
None, True),
(2, np.array([-1,1]), multivariate_normal.rvs, multivariate_normal.pdf,
None, None, None, False),
(2, np.array([-1,1]), multivariate_normal.rvs, None, None, None,
None, True)
None, None, None, False)
]
)
def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds,
......@@ -30,6 +28,13 @@ def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds,
bounds=bounds, f_density=f_density, symmetric=symmetric)
def test_sample_ValueError():
mh_sampler = MetropolisHastings(ndim=2, init_state=np.array([-1,1]),
f_sample=multivariate_normal.rvs, target=None, ln_target=None)
with pytest.raises(ValueError):
mh_sampler.sample(nsamples=10000)
def test_sample_uniform_target():
ndim = 1
init_state = np.array([0.5])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment