From 9bd057a46de1f62e6d2a6f22b03659d2c402b145 Mon Sep 17 00:00:00 2001 From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com> Date: Thu, 14 Nov 2024 14:17:43 +0100 Subject: [PATCH] Fix tests --- .../post_processing/post_processing.py | 23 ++++----- tests/test_PostProcessing.py | 50 +++---------------- 2 files changed, 18 insertions(+), 55 deletions(-) diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py index c3b833b42..a4a5f55b5 100644 --- a/src/bayesvalidrox/post_processing/post_processing.py +++ b/src/bayesvalidrox/post_processing/post_processing.py @@ -38,6 +38,9 @@ class PostProcessing: """ def __init__(self, engine, name="calib", out_dir=""): + # PostProcessing only available for trained engines + if not engine.trained: + raise AttributeError('PostProcessing can only be performed on trained engines.') self.engine = engine self.name = name self.par_names = self.engine.ExpDesign.par_names @@ -663,16 +666,6 @@ class PostProcessing: comments="", ) - # Check if the x_values match the number of metamodel outputs - # TODO: How relevant is this check? - if np.array(x_values_orig).shape[0] != total_sobol_all[outputs[0]].shape[1]: - print( - "The number of MetaModel outputs does not match the x_values" - " specified in ExpDesign. Images are created with " - "equidistant numbers on the x-axis" - ) - x_values_orig = np.arange(0, 1, total_sobol_all[outputs[0]].shape[0]) - # Plot Sobol' indices self.plot_type = plot_type for i_order in range(1, max_order + 1): @@ -686,7 +679,7 @@ class PostProcessing: return sobol_all, total_sobol_all - def plot_sobol(self, outputs, par_names, sobol_type="sobol", i_order=0): + def plot_sobol(self, par_names, outputs, sobol_type="sobol", i_order=0): """ Generate plots for each output in the given set of Sobol' indices. @@ -697,20 +690,22 @@ class PostProcessing: if sobol_type == "totalsobol": sobol = self.totalsobol - fig = plt.figure() for _, output in enumerate(outputs): x = ( self.x_values[output] if isinstance(self.x_values, dict) else self.x_values ) - sobol_ = sobol[output][0] + sobol_ = sobol[output] + if sobol_type == 'sobol': + sobol_ = sobol_[0] # Compute quantiles q_5 = np.quantile(sobol[output], q=0.05, axis=0) q_97_5 = np.quantile(sobol[output], q=0.975, axis=0) if self.plot_type == "bar": + fig = plt.figure() ax = fig.add_axes([0, 0, 1, 1]) dict1 = {self.xlabel: x} dict2 = dict(zip(par_names, sobol_)) @@ -731,6 +726,8 @@ class PostProcessing: ax.set_ylabel("Total Sobol indices, $S^T$") else: + fig = plt.figure() + ax = fig.add_axes([0, 0, 1, 1]) for i, sobol_indices in enumerate(sobol_): plt.plot( x, diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index eb3d19466..3edddf686 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -49,6 +49,7 @@ def basic_engine(): def pce_engine(): inp = Input() inp.add_marginals() + inp.Marginals[0].name = 'x' inp.Marginals[0].dist_type = 'normal' inp.Marginals[0].parameters = [0, 1] expdes = ExpDesigns(inp) @@ -63,6 +64,7 @@ def pce_engine(): engine = Engine(mm, mod, expdes) engine.out_names = ['Z'] engine.emulator = True + engine.trained = True return engine #%% Test PostProcessing init @@ -72,6 +74,12 @@ def test_postprocessing_noengine(): def test_postprocessing(basic_engine) -> None: engine = basic_engine + with pytest.raises(AttributeError) as excinfo: + PostProcessing(engine) + assert str(excinfo.value) == 'PostProcessing can only be performed on trained engines.' + +def test_postprocessing_pce(pce_engine) -> None: + engine = pce_engine PostProcessing(engine) @@ -127,16 +135,6 @@ def test_check_accuracy_pce(pce_engine) -> None: #%% sobol_indices -def test_sobol_indices_nopce(basic_engine) -> None: - """ - Calculate sobol indices for non-PCE metamodel - """ - engine = basic_engine - post = PostProcessing(engine) - with pytest.raises(AttributeError) as excinfo: - post.sobol_indices() - assert str(excinfo.value) == 'Sobol indices only support PCE-type models!' - def test_sobol_indices_pce(pce_engine) -> None: """ Calculate sobol indices for PCE metamodel @@ -167,41 +165,9 @@ def test_check_reg_quality_pce(pce_engine) -> None: #%% eplot_metamodel_3d -def test_plot_metamodel_3d_nopce(basic_engine) -> None: - """ - 3d eval of non-PCE metamodel - """ - engine = basic_engine - post = PostProcessing(engine) - with pytest.raises(AttributeError) as excinfo: - post.plot_metamodel_3d() - assert str(excinfo.value) == 'This function is only applicable if the MetaModel input dimension is 2.' - - #%% _get_sample #%% _eval_model #%% _plot_validation -def test_plot_validation_nopce(basic_engine) -> None: - """ - Plot validation of non-PCE metamodel - """ - engine = basic_engine - samples = engine.ExpDesign.generate_samples(10,'random') - post = PostProcessing(engine) - with pytest.raises(AttributeError) as excinfo: - post._plot_validation(samples) - assert str(excinfo.value) == 'This evaluation only support PCE-type models!' - #%% _plot_validation_multi - -def test_plot_validation_multi_nopce(basic_engine) -> None: - """ - Plot multi-validation of non-PCE metamodel - """ - engine = basic_engine - post = PostProcessing(engine) - with pytest.raises(AttributeError) as excinfo: - post._plot_validation_multi() - assert str(excinfo.value) == 'This evaluation only support PCE-type models!' \ No newline at end of file -- GitLab