From 9bd057a46de1f62e6d2a6f22b03659d2c402b145 Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Thu, 14 Nov 2024 14:17:43 +0100
Subject: [PATCH] Fix tests

---
 .../post_processing/post_processing.py        | 23 ++++-----
 tests/test_PostProcessing.py                  | 50 +++----------------
 2 files changed, 18 insertions(+), 55 deletions(-)

diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py
index c3b833b42..a4a5f55b5 100644
--- a/src/bayesvalidrox/post_processing/post_processing.py
+++ b/src/bayesvalidrox/post_processing/post_processing.py
@@ -38,6 +38,9 @@ class PostProcessing:
     """
 
     def __init__(self, engine, name="calib", out_dir=""):
+        # PostProcessing only available for trained engines
+        if not engine.trained:
+            raise AttributeError('PostProcessing can only be performed on trained engines.')
         self.engine = engine
         self.name = name
         self.par_names = self.engine.ExpDesign.par_names
@@ -663,16 +666,6 @@ class PostProcessing:
                         comments="",
                     )
 
-        # Check if the x_values match the number of metamodel outputs
-        # TODO: How relevant is this check?
-        if np.array(x_values_orig).shape[0] != total_sobol_all[outputs[0]].shape[1]:
-            print(
-                "The number of MetaModel outputs does not match the x_values"
-                " specified in ExpDesign. Images are created with "
-                "equidistant numbers on the x-axis"
-            )
-            x_values_orig = np.arange(0, 1, total_sobol_all[outputs[0]].shape[0])
-
         # Plot Sobol' indices
         self.plot_type = plot_type
         for i_order in range(1, max_order + 1):
@@ -686,7 +679,7 @@ class PostProcessing:
 
         return sobol_all, total_sobol_all
 
-    def plot_sobol(self, outputs, par_names, sobol_type="sobol", i_order=0):
+    def plot_sobol(self, par_names, outputs, sobol_type="sobol", i_order=0):
         """
         Generate plots for each output in the given set of Sobol' indices.
 
@@ -697,20 +690,22 @@ class PostProcessing:
         if sobol_type == "totalsobol":
             sobol = self.totalsobol
 
-        fig = plt.figure()
         for _, output in enumerate(outputs):
             x = (
                 self.x_values[output]
                 if isinstance(self.x_values, dict)
                 else self.x_values
             )
-            sobol_ = sobol[output][0]
+            sobol_ = sobol[output]
+            if sobol_type == 'sobol':
+                sobol_ = sobol_[0]
 
             # Compute quantiles
             q_5 = np.quantile(sobol[output], q=0.05, axis=0)
             q_97_5 = np.quantile(sobol[output], q=0.975, axis=0)
 
             if self.plot_type == "bar":
+                fig = plt.figure()
                 ax = fig.add_axes([0, 0, 1, 1])
                 dict1 = {self.xlabel: x}
                 dict2 = dict(zip(par_names, sobol_))
@@ -731,6 +726,8 @@ class PostProcessing:
                     ax.set_ylabel("Total Sobol indices, $S^T$")
 
             else:
+                fig = plt.figure()
+                ax = fig.add_axes([0, 0, 1, 1])
                 for i, sobol_indices in enumerate(sobol_):
                     plt.plot(
                         x,
diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py
index eb3d19466..3edddf686 100644
--- a/tests/test_PostProcessing.py
+++ b/tests/test_PostProcessing.py
@@ -49,6 +49,7 @@ def basic_engine():
 def pce_engine():
     inp = Input()
     inp.add_marginals()
+    inp.Marginals[0].name = 'x'
     inp.Marginals[0].dist_type = 'normal'
     inp.Marginals[0].parameters = [0, 1]
     expdes = ExpDesigns(inp)
@@ -63,6 +64,7 @@ def pce_engine():
     engine = Engine(mm, mod, expdes)
     engine.out_names = ['Z']
     engine.emulator = True
+    engine.trained = True
     return engine
 
 #%% Test PostProcessing init
@@ -72,6 +74,12 @@ def test_postprocessing_noengine():
 
 def test_postprocessing(basic_engine) -> None:
     engine = basic_engine
+    with pytest.raises(AttributeError) as excinfo:
+        PostProcessing(engine)
+    assert str(excinfo.value) == 'PostProcessing can only be performed on trained engines.'
+
+def test_postprocessing_pce(pce_engine) -> None:
+    engine = pce_engine
     PostProcessing(engine)
     
     
@@ -127,16 +135,6 @@ def test_check_accuracy_pce(pce_engine) -> None:
 
 #%% sobol_indices
 
-def test_sobol_indices_nopce(basic_engine) -> None:
-    """
-    Calculate sobol indices for non-PCE metamodel
-    """
-    engine = basic_engine
-    post = PostProcessing(engine)
-    with pytest.raises(AttributeError) as excinfo:
-        post.sobol_indices()
-    assert str(excinfo.value) == 'Sobol indices only support PCE-type models!'
-    
 def test_sobol_indices_pce(pce_engine) -> None:
     """
     Calculate sobol indices for PCE metamodel
@@ -167,41 +165,9 @@ def test_check_reg_quality_pce(pce_engine) -> None:
 
 #%% eplot_metamodel_3d
 
-def test_plot_metamodel_3d_nopce(basic_engine) -> None:
-    """
-    3d eval of non-PCE metamodel
-    """
-    engine = basic_engine
-    post = PostProcessing(engine)
-    with pytest.raises(AttributeError) as excinfo:
-        post.plot_metamodel_3d()
-    assert str(excinfo.value) == 'This function is only applicable if the MetaModel input dimension is 2.'
-    
-
 #%% _get_sample
 #%% _eval_model
 #%% _plot_validation
 
-def test_plot_validation_nopce(basic_engine) -> None:
-    """
-    Plot validation of non-PCE metamodel
-    """
-    engine = basic_engine
-    samples = engine.ExpDesign.generate_samples(10,'random')
-    post = PostProcessing(engine)
-    with pytest.raises(AttributeError) as excinfo:
-        post._plot_validation(samples)
-    assert str(excinfo.value) == 'This evaluation only support PCE-type models!'
-    
 #%% _plot_validation_multi
-    
-def test_plot_validation_multi_nopce(basic_engine) -> None:
-    """
-    Plot multi-validation of non-PCE metamodel
-    """
-    engine = basic_engine
-    post = PostProcessing(engine)
-    with pytest.raises(AttributeError) as excinfo:
-        post._plot_validation_multi()
-    assert str(excinfo.value) == 'This evaluation only support PCE-type models!'
     
\ No newline at end of file
-- 
GitLab