From 17427e4b0f3027aee7a21adb317fbe5c31966569 Mon Sep 17 00:00:00 2001
From: Alina Lacheim <a.lacheim@outlook.de>
Date: Thu, 5 Dec 2024 11:28:42 +0100
Subject: [PATCH] test for PostProcessing

---
 tests/test_PostProcessing.py | 269 +++++++++++++++++++++++++++++++++--
 1 file changed, 258 insertions(+), 11 deletions(-)

diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py
index 95afa3e9e..efcffc481 100644
--- a/tests/test_PostProcessing.py
+++ b/tests/test_PostProcessing.py
@@ -7,19 +7,19 @@ Class PostProcessing:
     plot_moments
     valid_metamodel
     check_accuracy
-    plot_seq_design
+    plot_seq_design_diagnostics
     sobol_indices
+    plot_sobol 
     check_req_quality
-    eval_pce_model_3d
-    _get_sample
-    _eval_model
-    _plot_validation
+    plot_metamodel_3d
     _plot_validation_multi
 """
+
 import sys
 sys.path.append("../src/")
 import numpy as np
 import pytest
+import os 
 
 from bayesvalidrox.post_processing.post_processing import PostProcessing
 from bayesvalidrox.surrogate_models.inputs import Input
@@ -46,6 +46,14 @@ def basic_engine():
     engine.emulator = True
     return engine
 
+@pytest.fixture
+def basic_engine_trained():
+    # Setup a basic engine fixture
+    engine = type('Engine', (object,), {})()
+    engine.trained = True
+    engine.ExpDesign = type('ExpDesign', (object,), {'X': [[0.1], [0.2], [0.3]], 'Y': [[1], [2], [3]]})
+    return engine
+
 @pytest.fixture
 def pce_engine():
     inp = Input()
@@ -59,6 +67,36 @@ def pce_engine():
     expdes.Y = {'Z': [[0.4], [0.5], [0.45]]}
     expdes.x_values = [0]
 
+    mm = PCE(inp)
+    mm.fit(expdes.X, expdes.Y)
+    mod = PL()
+    engine = Engine(mm, mod, expdes)
+    engine.out_names = ['Z']
+    engine.emulator = True
+    engine.trained = True
+    return engine
+
+@pytest.fixture
+def pce_engine_3d_plot():
+    inp = Input()
+    inp.add_marginals()
+    
+    inp.Marginals[0].name = 'x1'
+    inp.Marginals[0].dist_type = 'normal'
+    inp.Marginals[0].parameters = [0, 1]
+    
+    inp.add_marginals()
+    inp.Marginals[1].name = 'x2'
+    inp.Marginals[1].dist_type = 'normal'
+    inp.Marginals[1].parameters = [0, 1]
+    
+    expdes = ExpDesigns(inp)
+    expdes.init_param_space(max_deg=1)
+    expdes.X = np.array([[0, 0], [1, 1]])  # Zwei Eingabedimensionen
+    expdes.Y = {'Z': [[0.4], [0.5]]}
+    expdes.x_values = [0, 1]
+
+
     mm = PCE(inp)
     mm.fit(expdes.X, expdes.Y)
     mod = PL()
@@ -71,10 +109,12 @@ def pce_engine():
 @pytest.fixture
 def gpe_engine(): 
     inp = Input()
+    
     inp.add_marginals()
     inp.Marginals[0].name = 'x'
     inp.Marginals[0].dist_type = 'normal'
     inp.Marginals[0].parameters = [0, 1]
+    
     expdes = ExpDesigns(inp)
     expdes.init_param_space(max_deg=1)
     expdes.X = np.array([[0], [1], [0.5]])
@@ -90,12 +130,42 @@ def gpe_engine():
     engine.trained = True
     return engine
 
+@pytest.fixture
+def gpe_engine_3d_plot(): 
+    inp = Input()
+    # Füge zwei Prior-Verteilungen hinzu
+    inp.add_marginals()
+    inp.Marginals[0].name = 'x1'
+    inp.Marginals[0].dist_type = 'normal'
+    inp.Marginals[0].parameters = [0, 1]
+    
+    inp.add_marginals()
+    inp.Marginals[1].name = 'x2'
+    inp.Marginals[1].dist_type = 'normal'
+    inp.Marginals[1].parameters = [0, 1]
+    
+    expdes = ExpDesigns(inp)
+    expdes.init_param_space(max_deg=1)
+    # Erstelle Stichproben mit zwei Eingabedimensionen
+    expdes.X = np.array([[0, 0], [1, 1]])  # 2D-Array: (Anzahl der Stichproben, Anzahl der Priors)
+    expdes.Y = {'Z': [[0.4], [0.5]]}       # Zielwerte
+    expdes.x_values = [0, 1]               # Eingabewerte für beide Dimensionen
+    
+    mm = GPESkl(inp)
+    mm.fit(expdes.X, expdes.Y)
+    mod = PL()
+    engine = Engine(mm, mod, expdes)
+    engine.out_names = ['Z']
+    engine.emulator = True
+    engine.trained = True
+    return engine
+
 #%% Test PostProcessing init
 
 def test_postprocessing_noengine():
     None
 
-def test_postprocessing(basic_engine) -> None:
+def test_postprocessing_untrained_engine(basic_engine) -> None:
     engine = basic_engine
     with pytest.raises(AttributeError) as excinfo:
         PostProcessing(engine)
@@ -109,7 +179,6 @@ def test_postprocessing_gpe(gpe_engine) -> None:
     engine = gpe_engine
     PostProcessing(engine)
 #%% plot_moments
-
 def test_plot_moments_pce(pce_engine) -> None:
     """
     Plot moments for PCE metamodel
@@ -173,7 +242,49 @@ def test_plot_moments_gpebar(gpe_engine) -> None:
     assert list(stdev.keys()) == ['Z']
     assert stdev['Z'].shape == (1,)
     assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
+
+def test_plot_moments_with_invalid_model_type() -> None:
+    """
+    Plot moments with invalid model type
+    """
+    engine = type('Engine', (object,), {})()
+    engine.model_type = 'INVALID'
+    engine.trained = True
+    post = PostProcessing(engine)
+    with pytest.raises(ValueError) as excinfo:
+        post.plot_moments()
+    assert "Invalid model type" in str(excinfo.value)
+    
 #%% valid_metamodel
+def test_plot_validation_multi_pce(pce_engine):
+    engine = pce_engine
+    post = PostProcessing(engine)
+    out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    out_std = {'Z': np.array([[0.1], [0.1], [0.1]])}
+    post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    post._plot_validation_multi(out_mean, out_std)
+
+def test_plot_validation_multi_gpe(gpe_engine):
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    out_std = {'Z': np.array([[0.1], [0.1], [0.1]])}
+    post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    post._plot_validation_multi(out_mean, out_std)
+
+def test_valid_metamodel_pce(pce_engine):
+    engine = pce_engine
+    post = PostProcessing(engine)
+    samples = np.array([[0], [1], [0.5]])
+    model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    post.valid_metamodel(samples=samples, model_out_dict=model_out_dict)
+
+def test_valid_metamodel_gpe(gpe_engine):
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    samples = np.array([[0], [1], [0.5]])
+    model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
+    post.valid_metamodel(samples=samples, model_out_dict=model_out_dict)
 
 #%% check_accuracy
 
@@ -192,8 +303,41 @@ def test_check_accuracy_gpe(gpe_engine) -> None:
     engine = gpe_engine
     post = PostProcessing(engine)
     post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y)
-#%% plot_seq_design
+#%% plot_seq_design_diagnoxtics
+def test_plot_seq_design_diagnostics(basic_engine_trained):
+    """
+    Test the plot_seq_design_diagnostics method
+    """
+    engine = basic_engine_trained
+    post = PostProcessing(engine)
+    post.plot_seq_design_diagnostics()
+    # Check if the plot was created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}")
 
+def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained):
+    """
+    Test the plot_seq_design_diagnostics method with custom values
+    """
+    engine = basic_engine_trained
+    engine.ExpDesign.X = [[0.1], [0.3], [0.5], [0.7], [0.9]]
+    engine.ExpDesign.Y = [[2], [4], [6], [8], [10]]
+    post = PostProcessing(engine)
+    post.plot_seq_design_diagnostics()
+    # Check if the plot was created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}")
+
+def test_plot_seq_design_diagnostics_with_empty_values(basic_engine_trained):
+    """
+    Test the plot_seq_design_diagnostics method with empty values
+    """
+    engine = basic_engine_trained
+    engine.ExpDesign.X = []
+    engine.ExpDesign.Y = []
+    post = PostProcessing(engine)
+    with pytest.raises(ValueError) as excinfo:
+        post.plot_seq_design_diagnostics()
+    assert "ExpDesign.X and ExpDesign.Y cannot be empty" in str(excinfo.value)
+    
 #%% sobol_indices
 
 def test_sobol_indices_pce(pce_engine) -> None:
@@ -214,6 +358,17 @@ def test_sobol_indices_pce(pce_engine) -> None:
     assert sobol[1]['Z'].shape == (1,1,1)
     assert sobol[1]['Z'][0,0] == 1
 
+def test_sobol_indices_with_invalid_model_type(basic_engine_trained) -> None:
+    """
+    Calculate sobol indices with invalid model type
+    """
+    engine = basic_engine_trained
+    post = PostProcessing(engine)
+    post.model_type = 'INVALID'
+    with pytest.raises(ValueError) as excinfo:
+        post.sobol_indices()
+    assert "Invalid model type" in str(excinfo.value)
+
 #%% check_reg_quality
 
 def test_check_reg_quality_pce(pce_engine) -> None:
@@ -224,11 +379,103 @@ def test_check_reg_quality_pce(pce_engine) -> None:
     post = PostProcessing(engine)
     post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)
 
-#%% eplot_metamodel_3d
+def test_check_reg_quality_gpe(gpe_engine) -> None:
+    """
+    Check the regression quality for GPE metamodel
+    """
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)
+    # Add assertions to check the quality metrics if available
+
+def test_check_reg_quality_with_invalid_samples(pce_engine) -> None:
+    """
+    Check the regression quality with invalid samples
+    """
+    engine = pce_engine
+    post = PostProcessing(engine)
+    with pytest.raises(AttributeError) as excinfo:
+        post.check_reg_quality(outputs=engine.ExpDesign.Y)
+    assert "Samples cannot be empty" in str(excinfo.value)
+
+def test_check_reg_quality_with_invalid_outputs(pce_engine) -> None:
+    """
+    Check the regression quality with invalid outputs
+    """
+    engine = pce_engine
+    post = PostProcessing(engine)
+    with pytest.raises(ValueError) as excinfo:
+        post.check_reg_quality(samples=engine.ExpDesign.X, outputs=[])
+    assert "Please provide the outputs of the model!" in str(excinfo.value)
+
+#%% plot_metamodel_3d
+def test_plot_metamodel_3d_pce(pce_engine_3d_plot) -> None:
+    """
+    Test the plot_metamodel_3d method for PCE metamodel
+    """
+    engine = pce_engine_3d_plot
+    post = PostProcessing(engine)
+    post.plot_metamodel_3d()
+    # Check if the plot was created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}")
+
+def test_plot_metamodel_3d_gpe(gpe_engine_3d_plot) -> None:
+    """
+    Test the plot_metamodel_3d method for GPE metamodel
+    """
+    engine = gpe_engine_3d_plot
+    post = PostProcessing(engine)
+    post.plot_metamodel_3d()
+    # Check if the plot was created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}")
 
-#%% _get_sample
-#%% _eval_model
+def test_plot_metamodel_3d_with_invalid_data(pce_engine_3d_plot) -> None:
+    """
+    Test the plot_metamodel_3d method with invalid data
+    """
+    engine = pce_engine_3d_plot
+    engine.ExpDesign.X = []
+    post = PostProcessing(engine)
+    with pytest.raises(ValueError) as excinfo:
+        post.plot_metamodel_3d()
+    assert "Input data cannot be empty" in str(excinfo.value)
 
 
 #%% _plot_validation_multi
+def test_plot_validation_multi(basic_engine_trained):
+    """
+    Test the _plot_validation_multi method
+    """
+    engine = basic_engine_trained
+    post = PostProcessing(engine)
+    y_val = {'key1': [1, 2, 3, 4, 5]}
+    y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5]}
+    post._plot_validation_multi(y_val, y_val_std)
+    # Check if the plot was created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}")
+
+def test_plot_validation_multi_with_multiple_keys(basic_engine_trained):
+    """
+    Test the _plot_validation_multi method with multiple keys
+    """
+    engine = basic_engine_trained
+    post = PostProcessing(engine)
+    y_val = {'key1': [1, 2, 3, 4, 5], 'key2': [2, 3, 4, 5, 6]}
+    y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5], 'key2': [0.2, 0.3, 0.4, 0.5, 0.6]}
+    post._plot_validation_multi(y_val, y_val_std)
+    # Check if the plots were created and saved
+    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}")
+    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key2.{engine.out_format}")
+
+def test_plot_validation_multi_with_empty_values(basic_engine_trained):
+    """
+    Test the _plot_validation_multi method with empty values
+    """
+    engine = basic_engine_trained
+    post = PostProcessing(engine)
+    y_val = {}
+    y_val_std = {}
+    with pytest.raises(ValueError) as excinfo:
+        post._plot_validation_multi(y_val, y_val_std)
+    assert "y_val and y_val_std cannot be empty" in str(excinfo.value)
     
\ No newline at end of file
-- 
GitLab