diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 3edddf686d66535071f0a9062684a80a5a1c8aa4..7009d6c6cfb42d3cf744c41e1f72102b6e58b10b 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -29,6 +29,7 @@ from bayesvalidrox.surrogate_models.surrogate_models import MetaModel from bayesvalidrox.surrogate_models.polynomial_chaos import PCE from bayesvalidrox.pylink.pylink import PyLinkForwardModel as PL from bayesvalidrox.surrogate_models.engine import Engine +from bayesvalidrox import GPESkl @pytest.fixture @@ -67,6 +68,28 @@ def pce_engine(): engine.trained = True return engine +@pytest.fixture +def gpe_engine(): + inp = Input() + inp.add_marginals() + inp.Marginals[0].name = 'x' + inp.Marginals[0].dist_type = 'normal' + inp.Marginals[0].parameters = [0, 1] + expdes = ExpDesigns(inp) + expdes.init_param_space(max_deg=1) + expdes.X = np.array([[0], [1], [0.5]]) + expdes.Y = {'Z': [[0.4], [0.5], [0.45]]} + expdes.x_values = [0] + + mm = GPESkl(inp) + mm.fit(expdes.X, expdes.Y) + mod = PL() + engine = Engine(mm, mod, expdes) + engine.out_names = ['Z'] + engine.emulator = True + engine.trained = True + return engine + #%% Test PostProcessing init def test_postprocessing_noengine(): @@ -82,7 +105,9 @@ def test_postprocessing_pce(pce_engine) -> None: engine = pce_engine PostProcessing(engine) - +def test_postprocessing_gpe(gpe_engine) -> None: + engine = gpe_engine + PostProcessing(engine) #%% plot_moments def test_plot_moments_pce(pce_engine) -> None: @@ -117,7 +142,37 @@ def test_plot_moments_pcebar(pce_engine) -> None: assert stdev['Z'].shape == (1,) assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01) +def test_plot_moments_gpe(gpe_engine) -> None: + """ + Plot moments for GPE metamodel + """ + engine = gpe_engine + post = PostProcessing(engine) + mean, stdev = post.plot_moments() + # Check the mean dict + assert list(mean.keys()) == ['Z'] + assert mean['Z'].shape == (1,) + assert mean['Z'][0] == pytest.approx(0.4, abs=0.01) + # Check the stdev dict + assert list(stdev.keys()) == ['Z'] + assert stdev['Z'].shape == (1,) + assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01) +def test_plot_moments_gpebar(gpe_engine) -> None: + """ + Plot moments for GPE metamodel with bar-plot + """ + engine = gpe_engine + post = PostProcessing(engine) + mean, stdev = post.plot_moments(plot_type='bar') + # Check the mean dict + assert list(mean.keys()) == ['Z'] + assert mean['Z'].shape == (1,) + assert mean['Z'][0] == pytest.approx(0.4, abs=0.01) + # Check the stdev dict + assert list(stdev.keys()) == ['Z'] + assert stdev['Z'].shape == (1,) + assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01) #%% valid_metamodel #%% check_accuracy @@ -130,7 +185,13 @@ def test_check_accuracy_pce(pce_engine) -> None: post = PostProcessing(engine) post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y) - +def test_check_accuracy_gpe(gpe_engine) -> None: + """ + Check accuracy for GPE metamodel + """ + engine = gpe_engine + post = PostProcessing(engine) + post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y) #%% plot_seq_design #%% sobol_indices