Commit 17427e4b authored by Alina Lacheim

test for PostProcessing

parent 9b8ebbc4
1 merge request: !37 Fix/post processing
Pipeline #52234 failed
@@ -7,19 +7,19 @@ Class PostProcessing:
    plot_moments
    valid_metamodel
    check_accuracy
    plot_seq_design_diagnostics
    sobol_indices
    plot_sobol
    check_reg_quality
    plot_metamodel_3d
    _get_sample
    _eval_model
    _plot_validation
    _plot_validation_multi
"""
import sys
sys.path.append("../src/")
import numpy as np
import pytest
import os
from bayesvalidrox.post_processing.post_processing import PostProcessing
from bayesvalidrox.surrogate_models.inputs import Input
@@ -46,6 +46,14 @@ def basic_engine():
    engine.emulator = True
    return engine

@pytest.fixture
def basic_engine_trained():
    # Set up a basic trained engine fixture
    engine = type('Engine', (object,), {})()
    engine.trained = True
    engine.ExpDesign = type('ExpDesign', (object,), {'X': [[0.1], [0.2], [0.3]], 'Y': [[1], [2], [3]]})
    return engine
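
# Aside (illustrative sketch, not part of this change set): the dynamic
# type('Engine', (object,), {})() calls above build plain attribute bags.
# The standard library's types.SimpleNamespace expresses the same stub more
# readably, assuming the tests only need attribute access; the fixture name
# below is hypothetical.
@pytest.fixture
def basic_engine_trained_ns():
    from types import SimpleNamespace
    # Equivalent stub to basic_engine_trained, one namespace per object
    engine = SimpleNamespace(trained=True)
    engine.ExpDesign = SimpleNamespace(X=[[0.1], [0.2], [0.3]], Y=[[1], [2], [3]])
    return engine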

@pytest.fixture
def pce_engine():
    inp = Input()
@@ -59,6 +67,36 @@ def pce_engine():
    expdes.Y = {'Z': [[0.4], [0.5], [0.45]]}
    expdes.x_values = [0]
    mm = PCE(inp)
    mm.fit(expdes.X, expdes.Y)
    mod = PL()
    engine = Engine(mm, mod, expdes)
    engine.out_names = ['Z']
    engine.emulator = True
    engine.trained = True
    return engine
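
# Aside (illustrative sketch): the engine fixtures in this file differ mainly
# in the metamodel class they fit; a hypothetical helper along these lines
# could fold the shared wiring into one place (the helper name and signature
# are illustrative, not part of the original commit):
def _make_trained_engine(metamodel, expdes):
    # Fit the metamodel on the design and wrap it in a trained Engine
    metamodel.fit(expdes.X, expdes.Y)
    engine = Engine(metamodel, PL(), expdes)
    engine.out_names = ['Z']
    engine.emulator = True
    engine.trained = True
    return engine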

@pytest.fixture
def pce_engine_3d_plot():
    inp = Input()
    inp.add_marginals()
    inp.Marginals[0].name = 'x1'
    inp.Marginals[0].dist_type = 'normal'
    inp.Marginals[0].parameters = [0, 1]
    inp.add_marginals()
    inp.Marginals[1].name = 'x2'
    inp.Marginals[1].dist_type = 'normal'
    inp.Marginals[1].parameters = [0, 1]
    expdes = ExpDesigns(inp)
    expdes.init_param_space(max_deg=1)
    expdes.X = np.array([[0, 0], [1, 1]])  # two input dimensions
    expdes.Y = {'Z': [[0.4], [0.5]]}
    expdes.x_values = [0, 1]
    mm = PCE(inp)
    mm.fit(expdes.X, expdes.Y)
    mod = PL()
@@ -71,10 +109,12 @@ def pce_engine():
@pytest.fixture
def gpe_engine():
    inp = Input()
    inp.add_marginals()
    inp.Marginals[0].name = 'x'
    inp.Marginals[0].dist_type = 'normal'
    inp.Marginals[0].parameters = [0, 1]
    expdes = ExpDesigns(inp)
    expdes.init_param_space(max_deg=1)
    expdes.X = np.array([[0], [1], [0.5]])
@@ -90,12 +130,42 @@ def gpe_engine():
    engine.trained = True
    return engine

@pytest.fixture
def gpe_engine_3d_plot():
    inp = Input()
    # Add two prior distributions
    inp.add_marginals()
    inp.Marginals[0].name = 'x1'
    inp.Marginals[0].dist_type = 'normal'
    inp.Marginals[0].parameters = [0, 1]
    inp.add_marginals()
    inp.Marginals[1].name = 'x2'
    inp.Marginals[1].dist_type = 'normal'
    inp.Marginals[1].parameters = [0, 1]
    expdes = ExpDesigns(inp)
    expdes.init_param_space(max_deg=1)
    # Create samples with two input dimensions
    expdes.X = np.array([[0, 0], [1, 1]])  # 2D array: (number of samples, number of priors)
    expdes.Y = {'Z': [[0.4], [0.5]]}  # target values
    expdes.x_values = [0, 1]  # input values for both dimensions
    mm = GPESkl(inp)
    mm.fit(expdes.X, expdes.Y)
    mod = PL()
    engine = Engine(mm, mod, expdes)
    engine.out_names = ['Z']
    engine.emulator = True
    engine.trained = True
    return engine

#%% Test PostProcessing init

def test_postprocessing_noengine():
    None

def test_postprocessing_untrained_engine(basic_engine) -> None:
    engine = basic_engine
    with pytest.raises(AttributeError) as excinfo:
        PostProcessing(engine)
@@ -109,7 +179,6 @@ def test_postprocessing_gpe(gpe_engine) -> None:
    engine = gpe_engine
    PostProcessing(engine)

#%% plot_moments

def test_plot_moments_pce(pce_engine) -> None:
    """
    Plot moments for PCE metamodel
@@ -173,7 +242,49 @@ def test_plot_moments_gpebar(gpe_engine) -> None:
    assert list(stdev.keys()) == ['Z']
    assert stdev['Z'].shape == (1,)
    assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)

def test_plot_moments_with_invalid_model_type() -> None:
    """
    Plot moments with invalid model type
    """
    engine = type('Engine', (object,), {})()
    engine.model_type = 'INVALID'
    engine.trained = True
    post = PostProcessing(engine)
    with pytest.raises(ValueError) as excinfo:
        post.plot_moments()
    assert "Invalid model type" in str(excinfo.value)

#%% valid_metamodel

def test_plot_validation_multi_pce(pce_engine):
    engine = pce_engine
    post = PostProcessing(engine)
    out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])}
    out_std = {'Z': np.array([[0.1], [0.1], [0.1]])}
    post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
    post._plot_validation_multi(out_mean, out_std)

def test_plot_validation_multi_gpe(gpe_engine):
    engine = gpe_engine
    post = PostProcessing(engine)
    out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])}
    out_std = {'Z': np.array([[0.1], [0.1], [0.1]])}
    post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
    post._plot_validation_multi(out_mean, out_std)

def test_valid_metamodel_pce(pce_engine):
    engine = pce_engine
    post = PostProcessing(engine)
    samples = np.array([[0], [1], [0.5]])
    model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
    post.valid_metamodel(samples=samples, model_out_dict=model_out_dict)

def test_valid_metamodel_gpe(gpe_engine):
    engine = gpe_engine
    post = PostProcessing(engine)
    samples = np.array([[0], [1], [0.5]])
    model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])}
    post.valid_metamodel(samples=samples, model_out_dict=model_out_dict)
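
# Note (a reading of the tests above, not from the source): valid_metamodel
# evaluates the trained metamodel at `samples` and plots it against the
# reference values in `model_out_dict`; each output key ('Z' here) maps to an
# array of shape (n_samples, n_points), e.g. (3, 1) in these fixtures.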

#%% check_accuracy
@@ -192,8 +303,41 @@ def test_check_accuracy_gpe(gpe_engine) -> None:
    engine = gpe_engine
    post = PostProcessing(engine)
    post.check_accuracy(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)

#%% plot_seq_design_diagnostics

def test_plot_seq_design_diagnostics(basic_engine_trained):
    """
    Test the plot_seq_design_diagnostics method
    """
    engine = basic_engine_trained
    post = PostProcessing(engine)
    post.plot_seq_design_diagnostics()
    # Check if the plot was created and saved
    assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}")

def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained):
    """
    Test the plot_seq_design_diagnostics method with custom values
    """
    engine = basic_engine_trained
    engine.ExpDesign.X = [[0.1], [0.3], [0.5], [0.7], [0.9]]
    engine.ExpDesign.Y = [[2], [4], [6], [8], [10]]
    post = PostProcessing(engine)
    post.plot_seq_design_diagnostics()
    # Check if the plot was created and saved
    assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}")

def test_plot_seq_design_diagnostics_with_empty_values(basic_engine_trained):
    """
    Test the plot_seq_design_diagnostics method with empty values
    """
    engine = basic_engine_trained
    engine.ExpDesign.X = []
    engine.ExpDesign.Y = []
    post = PostProcessing(engine)
    with pytest.raises(ValueError) as excinfo:
        post.plot_seq_design_diagnostics()
    assert "ExpDesign.X and ExpDesign.Y cannot be empty" in str(excinfo.value)

#%% sobol_indices

def test_sobol_indices_pce(pce_engine) -> None:
@@ -214,6 +358,17 @@ def test_sobol_indices_pce(pce_engine) -> None:
    assert sobol[1]['Z'].shape == (1,1,1)
    assert sobol[1]['Z'][0,0] == 1

def test_sobol_indices_with_invalid_model_type(basic_engine_trained) -> None:
    """
    Calculate Sobol indices with invalid model type
    """
    engine = basic_engine_trained
    post = PostProcessing(engine)
    post.model_type = 'INVALID'
    with pytest.raises(ValueError) as excinfo:
        post.sobol_indices()
    assert "Invalid model type" in str(excinfo.value)

#%% check_reg_quality

def test_check_reg_quality_pce(pce_engine) -> None:
@@ -224,11 +379,103 @@ def test_check_reg_quality_pce(pce_engine) -> None:
    post = PostProcessing(engine)
    post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)

def test_check_reg_quality_gpe(gpe_engine) -> None:
    """
    Check the regression quality for GPE metamodel
    """
    engine = gpe_engine
    post = PostProcessing(engine)
    post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)
    # Add assertions to check the quality metrics if available

def test_check_reg_quality_with_invalid_samples(pce_engine) -> None:
    """
    Check the regression quality with invalid samples
    """
    engine = pce_engine
    post = PostProcessing(engine)
    with pytest.raises(AttributeError) as excinfo:
        post.check_reg_quality(outputs=engine.ExpDesign.Y)
    assert "Samples cannot be empty" in str(excinfo.value)

def test_check_reg_quality_with_invalid_outputs(pce_engine) -> None:
    """
    Check the regression quality with invalid outputs
    """
    engine = pce_engine
    post = PostProcessing(engine)
    with pytest.raises(ValueError) as excinfo:
        post.check_reg_quality(samples=engine.ExpDesign.X, outputs=[])
    assert "Please provide the outputs of the model!" in str(excinfo.value)

#%% plot_metamodel_3d

def test_plot_metamodel_3d_pce(pce_engine_3d_plot) -> None:
    """
    Test the plot_metamodel_3d method for PCE metamodel
    """
    engine = pce_engine_3d_plot
    post = PostProcessing(engine)
    post.plot_metamodel_3d()
    # Check if the plot was created and saved
    assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}")

def test_plot_metamodel_3d_gpe(gpe_engine_3d_plot) -> None:
    """
    Test the plot_metamodel_3d method for GPE metamodel
    """
    engine = gpe_engine_3d_plot
    post = PostProcessing(engine)
    post.plot_metamodel_3d()
    # Check if the plot was created and saved
    assert os.path.exists(f"./{engine.out_dir}/Metamodel_3D.{engine.out_format}")

def test_plot_metamodel_3d_with_invalid_data(pce_engine_3d_plot) -> None:
    """
    Test the plot_metamodel_3d method with invalid data
    """
    engine = pce_engine_3d_plot
    engine.ExpDesign.X = []
    post = PostProcessing(engine)
    with pytest.raises(ValueError) as excinfo:
        post.plot_metamodel_3d()
    assert "Input data cannot be empty" in str(excinfo.value)

#%% _plot_validation_multi

def test_plot_validation_multi(basic_engine_trained):
    """
    Test the _plot_validation_multi method
    """
    engine = basic_engine_trained
    post = PostProcessing(engine)
    y_val = {'key1': [1, 2, 3, 4, 5]}
    y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5]}
    post._plot_validation_multi(y_val, y_val_std)
    # Check if the plot was created and saved
    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}")

def test_plot_validation_multi_with_multiple_keys(basic_engine_trained):
    """
    Test the _plot_validation_multi method with multiple keys
    """
    engine = basic_engine_trained
    post = PostProcessing(engine)
    y_val = {'key1': [1, 2, 3, 4, 5], 'key2': [2, 3, 4, 5, 6]}
    y_val_std = {'key1': [0.1, 0.2, 0.3, 0.4, 0.5], 'key2': [0.2, 0.3, 0.4, 0.5, 0.6]}
    post._plot_validation_multi(y_val, y_val_std)
    # Check if the plots were created and saved
    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key1.{engine.out_format}")
    assert os.path.exists(f"./{engine.out_dir}/Model_vs_Model_key2.{engine.out_format}")

def test_plot_validation_multi_with_empty_values(basic_engine_trained):
    """
    Test the _plot_validation_multi method with empty values
    """
    engine = basic_engine_trained
    post = PostProcessing(engine)
    y_val = {}
    y_val_std = {}
    with pytest.raises(ValueError) as excinfo:
        post._plot_validation_multi(y_val, y_val_std)
    assert "y_val and y_val_std cannot be empty" in str(excinfo.value)
\ No newline at end of file