From 7bcd40ef56770d509a06de824e15ced4351a5b61 Mon Sep 17 00:00:00 2001
From: Hu Zhao <zhao@mbd.rwth-aachen.de>
Date: Wed, 26 Oct 2022 13:35:54 +0200
Subject: [PATCH] test: implement test_clear_temp_files and move all temp file
 clearning tasks there

---
 tests/test_clear_temp_files.py     | 24 ++++++++++++++++++++++++
 tests/test_metropolis_hastings.py  |  4 +---
 tests/test_ravaflow24.py           |  7 +------
 tests/test_run_mass_point_model.py |  3 ---
 tests/test_run_ravaflow24.py       |  6 +-----
 tests/test_run_simulator.py        | 12 ------------
 tests/test_saltelli.py             |  2 +-
 7 files changed, 28 insertions(+), 30 deletions(-)
 create mode 100644 tests/test_clear_temp_files.py

diff --git a/tests/test_clear_temp_files.py b/tests/test_clear_temp_files.py
new file mode 100644
index 0000000..7f21635
--- /dev/null
+++ b/tests/test_clear_temp_files.py
@@ -0,0 +1,24 @@
+import os
+import shutil
+import pytest
+
+@pytest.mark.order(-1)
+def test_clear_temp_files():
+
+    dir_test = os.path.abspath(os.path.join(__file__, '../'))
+
+    temp_dirs = [
+        os.path.join(dir_test, 'temp_add'), # test_run_simulator
+        os.path.join(dir_test, 'temp_run_add'), # test_run_simulator
+        os.path.join(dir_test, 'temp_run_mass_point_model'), # test_run_mass_point_model
+        os.path.join(dir_test, 'temp_ravaflow_1'), # test_ravaflow24
+        os.path.join(dir_test, 'temp_ravaflow_2'), # test_ravaflow24
+        os.path.join(dir_test, 'temp_run_ravaflow'), # test_run_ravaflow24
+        os.path.join(dir_test, 'temp_ravaflow_results'), # test_run_ravaflow24
+        os.path.join(dir_test, 'temp_metropolis_hastings'), # test_metropolis_hasting
+        os.path.join(dir_test, 'temp_bayes_inference_grid_approx_1d') # test_bayes_inference
+    ]
+
+    for temp_dir in temp_dirs:
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
\ No newline at end of file
diff --git a/tests/test_metropolis_hastings.py b/tests/test_metropolis_hastings.py
index 1e400e2..f8bc3a4 100644
--- a/tests/test_metropolis_hastings.py
+++ b/tests/test_metropolis_hastings.py
@@ -4,7 +4,6 @@ from scipy.stats import norm, multivariate_normal, uniform
 from psimpy.sampler.metropolis_hastings import MetropolisHastings
 import matplotlib.pyplot as plt
 import os
-import shutil
 
 @pytest.mark.parametrize(
     "ndim, init_state, f_sample, target, ln_target, bounds, f_density, symmetric",
@@ -23,7 +22,7 @@ import shutil
 def test_init_ValueError(ndim, init_state, f_sample, target, ln_target, bounds,
     f_density, symmetric):
     with pytest.raises(ValueError):
-        mh_sampler = MetropolisHastings(ndim=ndim, init_state=init_state,
+        _ = MetropolisHastings(ndim=ndim, init_state=init_state,
             f_sample=f_sample, target=target, ln_target=ln_target,
             bounds=bounds, f_density=f_density, symmetric=symmetric)
 
@@ -109,6 +108,5 @@ def test_sample_multivariate_norm_target():
         os.mkdir('temp_metropolis_hastings')
     dir_out = os.path.join(dir_test, 'temp_metropolis_hastings')
     plt.savefig(os.path.join(dir_out,'2d_norm_target.png'), bbox_inches='tight')
-    shutil.rmtree(dir_out)
 
 
diff --git a/tests/test_ravaflow24.py b/tests/test_ravaflow24.py
index 6d1e798..fa7c596 100644
--- a/tests/test_ravaflow24.py
+++ b/tests/test_ravaflow24.py
@@ -2,7 +2,6 @@ from psimpy.simulator.ravaflow24 import Ravaflow24Mixture
 import numpy as np
 import os
 import pytest
-import shutil
 
 @pytest.mark.parametrize(
     "dir_sim, conversion_control, curvature_control, surface_control, \
@@ -27,7 +26,7 @@ import shutil
 def test_ravaflow24_mixture_init_ValueError(dir_sim, conversion_control,
     curvature_control, surface_control, entrainment_control, stopping_control):
     with pytest.raises(ValueError):
-        rflow24_mixture = Ravaflow24Mixture(
+        _ = Ravaflow24Mixture(
             dir_sim=dir_sim,
             conversion_control=conversion_control,
             curvature_control=curvature_control,
@@ -67,8 +66,6 @@ def test_ravaflow24_mixture_preprocess_ValueError(prefix, elevation, hrelease):
     with pytest.raises(ValueError):
         rflow24_mixture.preprocess(
             prefix=prefix, elevation=elevation, hrelease=hrelease)
-    
-    shutil.rmtree(dir_sim)
 
 
 def test_ravaflow24_mixture_run_and_extract_output():
@@ -119,5 +116,3 @@ def test_ravaflow24_mixture_run_and_extract_output():
     assert isinstance(loc_max_energy, np.ndarray)
     assert loc_max_energy.ndim == 1
     assert len(loc_max_energy) == len(loc)
-
-    shutil.rmtree(dir_sim)
diff --git a/tests/test_run_mass_point_model.py b/tests/test_run_mass_point_model.py
index c9ec5f5..310fcce 100644
--- a/tests/test_run_mass_point_model.py
+++ b/tests/test_run_mass_point_model.py
@@ -4,7 +4,6 @@ import os
 import numpy as np
 import itertools
 import time
-import shutil
 
 def test_run_mass_point_model():
     mpm = MassPointModel()
@@ -51,5 +50,3 @@ def test_run_mass_point_model():
         assert np.array_equal(serial_output[i], parallel_output[i])
     
     assert serial_time > parallel_time
-        
-    shutil.rmtree(dir_out)
diff --git a/tests/test_run_ravaflow24.py b/tests/test_run_ravaflow24.py
index f192694..d2b3466 100644
--- a/tests/test_run_ravaflow24.py
+++ b/tests/test_run_ravaflow24.py
@@ -4,7 +4,6 @@ import numpy as np
 import itertools
 import time
 import os
-import shutil
 
 dir_test = os.path.abspath(os.path.join(__file__, '../'))
 
@@ -91,7 +90,4 @@ def test_run_ravaflow24():
     assert serial_time > parallel_time
 
     for i in range(len(var_samples)):
-        assert np.array_equal(serial_output[i], parallel_output[i])
-    
-    shutil.rmtree(dir_out)
-    shutil.rmtree(dir_sim)
\ No newline at end of file
+        assert np.array_equal(serial_output[i], parallel_output[i])
\ No newline at end of file
diff --git a/tests/test_run_simulator.py b/tests/test_run_simulator.py
index c69d0da..79ab5ba 100644
--- a/tests/test_run_simulator.py
+++ b/tests/test_run_simulator.py
@@ -2,7 +2,6 @@ from psimpy.simulator.run_simulator import RunSimulator
 import pytest
 import numpy as np
 import os
-import shutil
 from beartype.roar import BeartypeCallHintParamViolation
 
 def add(a, b, c , d=100, save=False, filename=None):
@@ -108,9 +107,6 @@ def test_RunSimulator_serial_parallel_run_with_o_parameter():
         'temp_add', 'parallel_run1.txt')
     )
 
-    shutil.rmtree(
-        os.path.join(os.path.abspath(os.path.join(__file__,'../temp_add')))
-    )
 
 def test_RunSimulator_serial_parallel_run_with_save_out():
     dir_out = os.path.join(os.path.abspath(os.path.join(__file__,'../')),
@@ -160,13 +156,5 @@ def test_RunSimulator_serial_parallel_run_with_save_out():
         'parallel_run1_output.npy')
     )
 
-    shutil.rmtree(
-        os.path.join(os.path.abspath(os.path.join(__file__,'../temp_add')))
-    )
-    
-    shutil.rmtree(
-        os.path.join(os.path.abspath(os.path.join(__file__,'../temp_run_add')))
-    )
-
 
 
diff --git a/tests/test_saltelli.py b/tests/test_saltelli.py
index a3c74ad..d9331c5 100644
--- a/tests/test_saltelli.py
+++ b/tests/test_saltelli.py
@@ -16,7 +16,7 @@ from beartype.roar import BeartypeCallHintParamViolation
 )
 def test_init_TypeError(ndim, bounds, calc_second_order, skip_values):
     with pytest.raises(BeartypeCallHintParamViolation):
-        saltelli_sampler = Saltelli(ndim, bounds, calc_second_order, skip_values)
+        _ = Saltelli(ndim, bounds, calc_second_order, skip_values)
 
 
 @pytest.mark.parametrize(
-- 
GitLab