Skip to content
Snippets Groups Projects
Commit d5937852 authored by Alina Lacheim's avatar Alina Lacheim
Browse files

added GPE engine for PostProcessing test and tests with the GPE engine

parent 900d958d
No related branches found
No related tags found
1 merge request!37Fix/post processing
Pipeline #51412 passed
......@@ -29,6 +29,7 @@ from bayesvalidrox.surrogate_models.surrogate_models import MetaModel
from bayesvalidrox.surrogate_models.polynomial_chaos import PCE
from bayesvalidrox.pylink.pylink import PyLinkForwardModel as PL
from bayesvalidrox.surrogate_models.engine import Engine
from bayesvalidrox import GPESkl
@pytest.fixture
......@@ -67,6 +68,28 @@ def pce_engine():
engine.trained = True
return engine
@pytest.fixture
def gpe_engine():
inp = Input()
inp.add_marginals()
inp.Marginals[0].name = 'x'
inp.Marginals[0].dist_type = 'normal'
inp.Marginals[0].parameters = [0, 1]
expdes = ExpDesigns(inp)
expdes.init_param_space(max_deg=1)
expdes.X = np.array([[0], [1], [0.5]])
expdes.Y = {'Z': [[0.4], [0.5], [0.45]]}
expdes.x_values = [0]
mm = GPESkl(inp)
mm.fit(expdes.X, expdes.Y)
mod = PL()
engine = Engine(mm, mod, expdes)
engine.out_names = ['Z']
engine.emulator = True
engine.trained = True
return engine
#%% Test PostProcessing init
def test_postprocessing_noengine():
......@@ -82,7 +105,9 @@ def test_postprocessing_pce(pce_engine) -> None:
engine = pce_engine
PostProcessing(engine)
def test_postprocessing_gpe(gpe_engine) -> None:
engine = gpe_engine
PostProcessing(engine)
#%% plot_moments
def test_plot_moments_pce(pce_engine) -> None:
......@@ -117,7 +142,37 @@ def test_plot_moments_pcebar(pce_engine) -> None:
assert stdev['Z'].shape == (1,)
assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
def test_plot_moments_gpe(gpe_engine) -> None:
"""
Plot moments for GPE metamodel
"""
engine = gpe_engine
post = PostProcessing(engine)
mean, stdev = post.plot_moments()
# Check the mean dict
assert list(mean.keys()) == ['Z']
assert mean['Z'].shape == (1,)
assert mean['Z'][0] == pytest.approx(0.4, abs=0.01)
# Check the stdev dict
assert list(stdev.keys()) == ['Z']
assert stdev['Z'].shape == (1,)
assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
def test_plot_moments_gpebar(gpe_engine) -> None:
"""
Plot moments for GPE metamodel with bar-plot
"""
engine = gpe_engine
post = PostProcessing(engine)
mean, stdev = post.plot_moments(plot_type='bar')
# Check the mean dict
assert list(mean.keys()) == ['Z']
assert mean['Z'].shape == (1,)
assert mean['Z'][0] == pytest.approx(0.4, abs=0.01)
# Check the stdev dict
assert list(stdev.keys()) == ['Z']
assert stdev['Z'].shape == (1,)
assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01)
#%% valid_metamodel
#%% check_accuracy
......@@ -130,7 +185,13 @@ def test_check_accuracy_pce(pce_engine) -> None:
post = PostProcessing(engine)
post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y)
def test_check_accuracy_gpe(gpe_engine) -> None:
"""
Check accuracy for GPE metamodel
"""
engine = gpe_engine
post = PostProcessing(engine)
post.check_accuracy(samples = engine.ExpDesign.X, outputs = engine.ExpDesign.Y)
#%% plot_seq_design
#%% sobol_indices
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment