Skip to content
Snippets Groups Projects
Commit 9bd057a4 authored by kohlhaasrebecca's avatar kohlhaasrebecca
Browse files

Fix tests

parent 7b92b89d
No related branches found
No related tags found
1 merge request!37Fix/post processing
......@@ -38,6 +38,9 @@ class PostProcessing:
"""
def __init__(self, engine, name="calib", out_dir=""):
# PostProcessing only available for trained engines
if not engine.trained:
raise AttributeError('PostProcessing can only be performed on trained engines.')
self.engine = engine
self.name = name
self.par_names = self.engine.ExpDesign.par_names
......@@ -663,16 +666,6 @@ class PostProcessing:
comments="",
)
# Check if the x_values match the number of metamodel outputs
# TODO: How relevant is this check?
if np.array(x_values_orig).shape[0] != total_sobol_all[outputs[0]].shape[1]:
print(
"The number of MetaModel outputs does not match the x_values"
" specified in ExpDesign. Images are created with "
"equidistant numbers on the x-axis"
)
x_values_orig = np.arange(0, 1, total_sobol_all[outputs[0]].shape[0])
# Plot Sobol' indices
self.plot_type = plot_type
for i_order in range(1, max_order + 1):
......@@ -686,7 +679,7 @@ class PostProcessing:
return sobol_all, total_sobol_all
def plot_sobol(self, outputs, par_names, sobol_type="sobol", i_order=0):
def plot_sobol(self, par_names, outputs, sobol_type="sobol", i_order=0):
"""
Generate plots for each output in the given set of Sobol' indices.
......@@ -697,20 +690,22 @@ class PostProcessing:
if sobol_type == "totalsobol":
sobol = self.totalsobol
fig = plt.figure()
for _, output in enumerate(outputs):
x = (
self.x_values[output]
if isinstance(self.x_values, dict)
else self.x_values
)
sobol_ = sobol[output][0]
sobol_ = sobol[output]
if sobol_type == 'sobol':
sobol_ = sobol_[0]
# Compute quantiles
q_5 = np.quantile(sobol[output], q=0.05, axis=0)
q_97_5 = np.quantile(sobol[output], q=0.975, axis=0)
if self.plot_type == "bar":
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
dict1 = {self.xlabel: x}
dict2 = dict(zip(par_names, sobol_))
......@@ -731,6 +726,8 @@ class PostProcessing:
ax.set_ylabel("Total Sobol indices, $S^T$")
else:
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
for i, sobol_indices in enumerate(sobol_):
plt.plot(
x,
......
......@@ -49,6 +49,7 @@ def basic_engine():
def pce_engine():
inp = Input()
inp.add_marginals()
inp.Marginals[0].name = 'x'
inp.Marginals[0].dist_type = 'normal'
inp.Marginals[0].parameters = [0, 1]
expdes = ExpDesigns(inp)
......@@ -63,6 +64,7 @@ def pce_engine():
engine = Engine(mm, mod, expdes)
engine.out_names = ['Z']
engine.emulator = True
engine.trained = True
return engine
#%% Test PostProcessing init
......@@ -72,6 +74,12 @@ def test_postprocessing_noengine():
def test_postprocessing(basic_engine) -> None:
engine = basic_engine
with pytest.raises(AttributeError) as excinfo:
PostProcessing(engine)
assert str(excinfo.value) == 'PostProcessing can only be performed on trained engines.'
def test_postprocessing_pce(pce_engine) -> None:
engine = pce_engine
PostProcessing(engine)
......@@ -127,16 +135,6 @@ def test_check_accuracy_pce(pce_engine) -> None:
#%% sobol_indices
def test_sobol_indices_nopce(basic_engine) -> None:
"""
Calculate sobol indices for non-PCE metamodel
"""
engine = basic_engine
post = PostProcessing(engine)
with pytest.raises(AttributeError) as excinfo:
post.sobol_indices()
assert str(excinfo.value) == 'Sobol indices only support PCE-type models!'
def test_sobol_indices_pce(pce_engine) -> None:
"""
Calculate sobol indices for PCE metamodel
......@@ -167,41 +165,9 @@ def test_check_reg_quality_pce(pce_engine) -> None:
#%% eplot_metamodel_3d
def test_plot_metamodel_3d_nopce(basic_engine) -> None:
"""
3d eval of non-PCE metamodel
"""
engine = basic_engine
post = PostProcessing(engine)
with pytest.raises(AttributeError) as excinfo:
post.plot_metamodel_3d()
assert str(excinfo.value) == 'This function is only applicable if the MetaModel input dimension is 2.'
#%% _get_sample
#%% _eval_model
#%% _plot_validation
def test_plot_validation_nopce(basic_engine) -> None:
"""
Plot validation of non-PCE metamodel
"""
engine = basic_engine
samples = engine.ExpDesign.generate_samples(10,'random')
post = PostProcessing(engine)
with pytest.raises(AttributeError) as excinfo:
post._plot_validation(samples)
assert str(excinfo.value) == 'This evaluation only support PCE-type models!'
#%% _plot_validation_multi
def test_plot_validation_multi_nopce(basic_engine) -> None:
"""
Plot multi-validation of non-PCE metamodel
"""
engine = basic_engine
post = PostProcessing(engine)
with pytest.raises(AttributeError) as excinfo:
post._plot_validation_multi()
assert str(excinfo.value) == 'This evaluation only support PCE-type models!'
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment