From d5937852a4c75a3ae029fef7ac9ac2dcfd012718 Mon Sep 17 00:00:00 2001
From: Alina Lacheim <a.lacheim@outlook.de>
Date: Thu, 14 Nov 2024 17:44:00 +0100
Subject: [PATCH] added GPE engine for PostProcessing test and tests with the
 GPE engine

---
 tests/test_PostProcessing.py | 65 ++++++++++++++++++++++++++++++++++--
 1 file changed, 63 insertions(+), 2 deletions(-)

diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py
index 3edddf686..7009d6c6c 100644
--- a/tests/test_PostProcessing.py
+++ b/tests/test_PostProcessing.py
@@ -29,6 +29,7 @@ from bayesvalidrox.surrogate_models.surrogate_models import MetaModel
 from bayesvalidrox.surrogate_models.polynomial_chaos import PCE
 from bayesvalidrox.pylink.pylink import PyLinkForwardModel as PL
 from bayesvalidrox.surrogate_models.engine import Engine
+from bayesvalidrox import GPESkl
 
 
 @pytest.fixture
@@ -67,6 +68,28 @@ def pce_engine():
     engine.trained = True
     return engine
 
+@pytest.fixture
+def gpe_engine(): 
+    inp = Input()
+    inp.add_marginals()
+    inp.Marginals[0].name = 'x'
+    inp.Marginals[0].dist_type = 'normal'
+    inp.Marginals[0].parameters = [0, 1]
+    expdes = ExpDesigns(inp)
+    expdes.init_param_space(max_deg=1)
+    expdes.X = np.array([[0], [1], [0.5]])
+    expdes.Y = {'Z': [[0.4], [0.5], [0.45]]}
+    expdes.x_values = [0]
+    
+    mm = GPESkl(inp)
+    mm.fit(expdes.X, expdes.Y)
+    mod = PL()
+    engine = Engine(mm, mod, expdes)
+    engine.out_names = ['Z']
+    engine.emulator = True
+    engine.trained = True
+    return engine
+
 #%% Test PostProcessing init
 
 def test_postprocessing_noengine():
@@ -82,7 +105,9 @@ def test_postprocessing_pce(pce_engine) -> None:
     engine = pce_engine
     PostProcessing(engine)
     
-    
+def test_postprocessing_gpe(gpe_engine) -> None:
+    engine = gpe_engine
+    PostProcessing(engine)
 #%% plot_moments
 
 def test_plot_moments_pce(pce_engine) -> None:
@@ -117,7 +142,37 @@ def test_plot_moments_pcebar(pce_engine) -> None:
     assert stdev['Z'].shape == (1,)
     assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
     
+def test_plot_moments_gpe(gpe_engine) -> None:
+    """
+    Plot moments for GPE metamodel
+    """
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    mean, stdev = post.plot_moments()
+    # Check the mean dict
+    assert list(mean.keys()) == ['Z']
+    assert mean['Z'].shape == (1,)
+    assert mean['Z'][0] == pytest.approx(0.4, abs=0.01)
+    # Check the stdev dict
+    assert list(stdev.keys()) == ['Z']
+    assert stdev['Z'].shape == (1,)
+    assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
 
+def test_plot_moments_gpebar(gpe_engine) -> None:
+    """
+    Plot moments for GPE metamodel with bar-plot
+    """
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    mean, stdev = post.plot_moments(plot_type='bar')
+    # Check the mean dict
+    assert list(mean.keys()) == ['Z']
+    assert mean['Z'].shape == (1,)
+    assert mean['Z'][0] == pytest.approx(0.4, abs=0.01)
+    # Check the stdev dict
+    assert list(stdev.keys()) == ['Z']
+    assert stdev['Z'].shape == (1,)
+    assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
 #%% valid_metamodel
 
 #%% check_accuracy
@@ -130,7 +185,13 @@ def test_check_accuracy_pce(pce_engine) -> None:
     post = PostProcessing(engine)
     post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y)
 
-
+def test_check_accuracy_gpe(gpe_engine) -> None:
+    """
+    Check accuracy for GPE metamodel
+    """
+    engine = gpe_engine
+    post = PostProcessing(engine)
+    post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y)
 #%% plot_seq_design
 
 #%% sobol_indices
-- 
GitLab