From 17427e4b0f3027aee7a21adb317fbe5c31966569 Mon Sep 17 00:00:00 2001 From: Alina Lacheim <a.lacheim@outlook.de> Date: Thu, 5 Dec 2024 11:28:42 +0100 Subject: [PATCH] test for PostProcessing --- tests/test_PostProcessing.py | 269 +++++++++++++++++++++++++++++++++-- 1 file changed, 258 insertions(+), 11 deletions(-) diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 95afa3e9e..efcffc481 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -7,19 +7,19 @@ Class PostProcessing: plot_moments valid_metamodel check_accuracy - plot_seq_design + plot_seq_design_diagnostics sobol_indices + plot_sobol check_req_quality - eval_pce_model_3d - _get_sample - _eval_model - _plot_validation + plot_metamodel_3d _plot_validation_multi """ + import sys sys.path.append("../src/") import numpy as np import pytest +import os from bayesvalidrox.post_processing.post_processing import PostProcessing from bayesvalidrox.surrogate_models.inputs import Input @@ -46,6 +46,14 @@ def basic_engine(): engine.emulator = True return engine +@pytest.fixture +def basic_engine_trained(): + # Setup a basic engine fixture + engine = type('Engine', (object,), {})() + engine.trained = True + engine.ExpDesign = type('ExpDesign', (object,), {'X': [[0.1], [0.2], [0.3]], 'Y': [[1], [2], [3]]}) + return engine + @pytest.fixture def pce_engine(): inp = Input() @@ -59,6 +67,36 @@ def pce_engine(): expdes.Y = {'Z': [[0.4], [0.5], [0.45]]} expdes.x_values = [0] + mm = PCE(inp) + mm.fit(expdes.X, expdes.Y) + mod = PL() + engine = Engine(mm, mod, expdes) + engine.out_names = ['Z'] + engine.emulator = True + engine.trained = True + return engine + +@pytest.fixture +def pce_engine_3d_plot(): + inp = Input() + inp.add_marginals() + + inp.Marginals[0].name = 'x1' + inp.Marginals[0].dist_type = 'normal' + inp.Marginals[0].parameters = [0, 1] + + inp.add_marginals() + inp.Marginals[1].name = 'x2' + inp.Marginals[1].dist_type = 'normal' + inp.Marginals[1].parameters = [0, 1] + + expdes = ExpDesigns(inp) + expdes.init_param_space(max_deg=1) + expdes.X = np.array([[0, 0], [1, 1]]) # Zwei Eingabedimensionen + expdes.Y = {'Z': [[0.4], [0.5]]} + expdes.x_values = [0, 1] + + mm = PCE(inp) mm.fit(expdes.X, expdes.Y) mod = PL() @@ -71,10 +109,12 @@ def pce_engine(): @pytest.fixture def gpe_engine(): inp = Input() + inp.add_marginals() inp.Marginals[0].name = 'x' inp.Marginals[0].dist_type = 'normal' inp.Marginals[0].parameters = [0, 1] + expdes = ExpDesigns(inp) expdes.init_param_space(max_deg=1) expdes.X = np.array([[0], [1], [0.5]]) @@ -90,12 +130,42 @@ def gpe_engine(): engine.trained = True return engine +@pytest.fixture +def gpe_engine_3d_plot(): + inp = Input() + # Füge zwei Prior-Verteilungen hinzu + inp.add_marginals() + inp.Marginals[0].name = 'x1' + inp.Marginals[0].dist_type = 'normal' + inp.Marginals[0].parameters = [0, 1] + + inp.add_marginals() + inp.Marginals[1].name = 'x2' + inp.Marginals[1].dist_type = 'normal' + inp.Marginals[1].parameters = [0, 1] + + expdes = ExpDesigns(inp) + expdes.init_param_space(max_deg=1) + # Erstelle Stichproben mit zwei Eingabedimensionen + expdes.X = np.array([[0, 0], [1, 1]]) # 2D-Array: (Anzahl der Stichproben, Anzahl der Priors) + expdes.Y = {'Z': [[0.4], [0.5]]} # Zielwerte + expdes.x_values = [0, 1] # Eingabewerte für beide Dimensionen + + mm = GPESkl(inp) + mm.fit(expdes.X, expdes.Y) + mod = PL() + engine = Engine(mm, mod, expdes) + engine.out_names = ['Z'] + engine.emulator = True + engine.trained = True + return engine + #%% Test PostProcessing init def test_postprocessing_noengine(): None -def test_postprocessing(basic_engine) -> None: +def test_postprocessing_untrained_engine(basic_engine) -> None: engine = basic_engine with pytest.raises(AttributeError) as excinfo: PostProcessing(engine) @@ -109,7 +179,6 @@ def test_postprocessing_gpe(gpe_engine) -> None: engine = gpe_engine PostProcessing(engine) #%% plot_moments - def test_plot_moments_pce(pce_engine) -> None: """ Plot moments for PCE metamodel @@ -173,7 +242,49 @@ def test_plot_moments_gpebar(gpe_engine) -> None: assert list(stdev.keys()) == ['Z'] assert stdev['Z'].shape == (1,) assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01) + +def test_plot_moments_with_invalid_model_type() -> None: + """ + Plot moments with invalid model type + """ + engine = type('Engine', (object,), {})() + engine.model_type = 'INVALID' + engine.trained = True + post = PostProcessing(engine) + with pytest.raises(ValueError) as excinfo: + post.plot_moments() + assert "Invalid model type" in str(excinfo.value) + #%% valid_metamodel +def test_plot_validation_multi_pce(pce_engine): + engine = pce_engine + post = PostProcessing(engine) + out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])} + out_std = {'Z': np.array([[0.1], [0.1], [0.1]])} + post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} + post._plot_validation_multi(out_mean, out_std) + +def test_plot_validation_multi_gpe(gpe_engine): + engine = gpe_engine + post = PostProcessing(engine) + out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])} + out_std = {'Z': np.array([[0.1], [0.1], [0.1]])} + post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} + post._plot_validation_multi(out_mean, out_std) + +def test_valid_metamodel_pce(pce_engine): + engine = pce_engine + post = PostProcessing(engine) + samples = np.array([[0], [1], [0.5]]) + model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} + post.valid_metamodel(samples=samples, model_out_dict=model_out_dict) + +def test_valid_metamodel_gpe(gpe_engine): + engine = gpe_engine + post = PostProcessing(engine) + samples = np.array([[0], [1], [0.5]]) + model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} + post.valid_metamodel(samples=samples, model_out_dict=model_out_dict) #%% check_accuracy @@ -192,8 +303,41 @@ def test_check_accuracy_gpe(gpe_engine) -> None: engine = gpe_engine post = PostProcessing(engine) post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y) -#%% plot_seq_design +#%% plot_seq_design_diagnoxtics +def test_plot_seq_design_diagnostics(basic_engine_trained): + """ + Test the plot_seq_design_diagnostics method + """ + engine = basic_engine_trained + post = PostProcessing(engine) + post.plot_seq_design_diagnostics() + # Check if the plot was created and saved + assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}") +def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained): + """ + Test the plot_seq_design_diagnostics method with custom values + """ + engine = basic_engine_trained + engine.ExpDesign.X = [[0.1], [0.3], [0.5], [0.7], [0.9]] + engine.ExpDesign.Y = [[2], [4], [6], [8], [10]] + post = PostProcessing(engine) + post.plot_seq_design_diagnostics() + # Check if the plot was created and saved + assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}") + +def test_plot_seq_design_diagnostics_with_empty_values(basic_engine_trained): + """ + Test the plot_seq_design_diagnostics method with empty values + """ + engine = basic_engine_trained + engine.ExpDesign.X = [] + engine.ExpDesign.Y = [] + post = PostProcessing(engine) + with pytest.raises(ValueError) as excinfo: + post.plot_seq_design_diagnostics() + assert "ExpDesign.X and ExpDesign.Y cannot be empty" in str(excinfo.value) + #%% sobol_indices def test_sobol_indices_pce(pce_engine) -> None: @@ -214,6 +358,17 @@ def test_sobol_indices_pce(pce_engine) -> None: assert sobol[1]['Z'].shape == (1,1,1) assert sobol[1]['Z'][0,0] == 1 +def test_sobol_indices_with_invalid_model_type(basic_engine_trained) -> None: + """ + Calculate sobol indices with invalid model type + """ + engine = basic_engine_trained + post = PostProcessing(engine) + post.model_type = 'INVALID' + with pytest.raises(ValueError) as excinfo: + post.sobol_indices() + assert "Invalid model type" in str(excinfo.value) + #%% check_reg_quality def test_check_reg_quality_pce(pce_engine) -> None: @@ -224,11 +379,103 @@ def test_check_reg_quality_pce(pce_engine) -> None: post = PostProcessing(engine) post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y) -#%% eplot_metamodel_3d +def test_check_reg_quality_gpe(gpe_engine) -> None: + """ + Check the regression quality for GPE metamodel + """ + engine = gpe_engine + post = PostProcessing(engine) + post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y) + # Add assertions to check the quality metrics if available + +def test_check_reg_quality_with_invalid_samples(pce_engine) -> None: + """ + Check the regression quality with invalid samples + """ + engine = pce_engine + post = PostProcessing(engine) + with pytest.raises(AttributeError) as excinfo: + post.check_reg_quality(outputs=engine.ExpDesign.Y) + assert "Samples cannot be empty" in str(excinfo.value) + +def test_check_reg_quality_with_invalid_outputs(pce_engine) -> None: + """ + Check the regression quality with invalid outputs + """ + engine = pce_engine + post = PostProcessing(engine) + with pytest.raises(ValueError) as excinfo: + post.check_reg_quality(samples=engine.ExpDesign.X, outputs=[]) + assert "Please provide the outputs of the model!" in str(excinfo.value) + +#%% plot_metamodel_3d +def test_plot_metamodel_3d_pce(pce_engine_3d_plot) -> None: + """ + Test the plot_metamodel_3d method for PCE metamodel + """ + engine = pce_engine_3d_plot + post = PostProcessing(engine) + post.plot_metamodel_3d() + # Check if the plot was created and saved + assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}") + +def test_plot_metamodel_3d_gpe(gpe_engine_3d_plot) -> None: + """ + Test the plot_metamodel_3d method for GPE metamodel + """ + engine = gpe_engine_3d_plot + post = PostProcessing(engine) + post.plot_metamodel_3d() + # Check if the plot was created and saved + assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}") -#%% _get_sample -#%% _eval_model +def test_plot_metamodel_3d_with_invalid_data(pce_engine_3d_plot) -> None: + """ + Test the plot_metamodel_3d method with invalid data + """ + engine = pce_engine_3d_plot + engine.ExpDesign.X = [] + post = PostProcessing(engine) + with pytest.raises(ValueError) as excinfo: + post.plot_metamodel_3d() + assert "Input data cannot be empty" in str(excinfo.value) #%% _plot_validation_multi +def test_plot_validation_multi(basic_engine_trained): + """ + Test the _plot_validation_multi method + """ + engine = basic_engine_trained + post = PostProcessing(engine) + y_val = {'key1': [1, 2, 3, 4, 5]} + y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5]} + post._plot_validation_multi(y_val, y_val_std) + # Check if the plot was created and saved + assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}") + +def test_plot_validation_multi_with_multiple_keys(basic_engine_trained): + """ + Test the _plot_validation_multi method with multiple keys + """ + engine = basic_engine_trained + post = PostProcessing(engine) + y_val = {'key1': [1, 2, 3, 4, 5], 'key2': [2, 3, 4, 5, 6]} + y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5], 'key2': [0.2, 0.3, 0.4, 0.5, 0.6]} + post._plot_validation_multi(y_val, y_val_std) + # Check if the plots were created and saved + assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}") + assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key2.{engine.out_format}") + +def test_plot_validation_multi_with_empty_values(basic_engine_trained): + """ + Test the _plot_validation_multi method with empty values + """ + engine = basic_engine_trained + post = PostProcessing(engine) + y_val = {} + y_val_std = {} + with pytest.raises(ValueError) as excinfo: + post._plot_validation_multi(y_val, y_val_std) + assert "y_val and y_val_std cannot be empty" in str(excinfo.value) \ No newline at end of file -- GitLab