From 8f3e9eaa69c5aaf3d4d068072fc5afb8b8f7b3af Mon Sep 17 00:00:00 2001
From: Hu Zhao <zhao@mbd.rwth-aachen.de>
Date: Tue, 25 Oct 2022 23:29:44 +0200
Subject: [PATCH] feat: move check of target to sample method

---
 src/psimpy/sampler/metropolis_hastings.py | 10 ++++++----
 tests/test_metropolis_hastings.py         | 11 ++++++++---
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/psimpy/sampler/metropolis_hastings.py b/src/psimpy/sampler/metropolis_hastings.py
index 8b88c28..4c513b7 100644
--- a/src/psimpy/sampler/metropolis_hastings.py
+++ b/src/psimpy/sampler/metropolis_hastings.py
@@ -98,9 +98,6 @@ class MetropolisHastings:
         if (not symmetric) and (f_density is None):
             raise ValueError("f_density must be provided if asymmetric")
 
-        if (target is None) and (ln_target is None):
-            raise ValueError("Either target or ln_target must be provided") 
-
         self.ndim = ndim
         self.init_state = init_state
         self.f_sample = f_sample
@@ -141,7 +138,12 @@ class MetropolisHastings:
             corresponding sample is the proposed new state (value 1) or the old
             state (value 0). `np.sum(mh_accept)/len(mh_accept)` thus gives the
             overall acceptance rate.
-        """  
+        """ 
+        if (self.target is None) and (self.ln_target is None):
+            raise ValueError(
+                "Either target or ln_target must be provided before call the"
+                " sample method")
+                 
         if self.ln_target is None:
             init_t = self.target(
                 self.init_state, *self.args_target, **self.kwgs_target)
diff --git a/tests/test_metropolis_hastings.py b/tests/test_metropolis_hastings.py
index 8543c42..1e400e2 100644
--- a/tests/test_metropolis_hastings.py
+++ b/tests/test_metropolis_hastings.py
@@ -17,9 +17,7 @@ import shutil
         (1, np.array([1.5]), norm.rvs, uniform.pdf, None, np.array([[0,1]]),
         None, True),
         (2, np.array([-1,1]), multivariate_normal.rvs, multivariate_normal.pdf,
-        None, None, None, False),
-        (2, np.array([-1,1]), multivariate_normal.rvs, None, None, None,
-        None, True)
+        None, None, None, False)
     ]
 )
 def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds,
@@ -30,6 +28,13 @@ def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds,
             bounds=bounds, f_density=f_density, symmetric=symmetric)
 
 
+def test_sample_ValueError():
+    mh_sampler = MetropolisHastings(ndim=2, init_state=np.array([-1,1]),
+        f_sample=multivariate_normal.rvs, target=None, ln_target=None)
+    with pytest.raises(ValueError):
+        mh_sampler.sample(nsamples=10000)
+
+
 def test_sample_uniform_target():
     ndim = 1
     init_state = np.array([0.5])
-- 
GitLab