From 0cad6155f5bb04a743138d7b3f56a666031ea366 Mon Sep 17 00:00:00 2001 From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com> Date: Wed, 9 Oct 2024 12:25:24 +0200 Subject: [PATCH] [fix] Add output-option to check_reg_quality Also added related PCE-tests --- src/bayesvalidrox/post_processing/post_processing.py | 10 ++++++++-- tests/test_PostProcessing.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py index e68e5120e..d5ad267b5 100644 --- a/src/bayesvalidrox/post_processing/post_processing.py +++ b/src/bayesvalidrox/post_processing/post_processing.py @@ -827,7 +827,7 @@ class PostProcessing: return self.total_sobol # ------------------------------------------------------------------------- - def check_reg_quality(self, n_samples=1000, samples=None): + def check_reg_quality(self, n_samples=1000, samples=None, outputs=None): """ Checks the quality of the metamodel for single output models based on: https://towardsdatascience.com/how-do-you-check-the-quality-of-your-regression-model-in-python-fa61759ff685 @@ -839,6 +839,9 @@ class PostProcessing: Number of parameter sets to use for the check. The default is 1000. samples : array of shape (n_samples, n_params), optional Parameter sets to use for the check. The default is None. + outputs : dict, optional + Output dictionary with model outputs for all given output types in + `Model.Output.names`. The default is None. Returns ------- @@ -852,7 +855,10 @@ class PostProcessing: self.n_samples = samples.shape[0] # Evaluate the original and the surrogate model - y_val = self._eval_model(samples, key_str='valid') + if outputs is None: + y_val = self._eval_model(samples, key_str='valid') + else: + y_val = outputs y_pce_val, _ = self.engine.eval_metamodel(samples=samples) # Open a pdf for the plots diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 79cb666f7..ba905d647 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -124,6 +124,7 @@ def test_check_accuracy_pce(pce_engine) -> None: #%% plot_seq_design + #%% sobol_indices def test_sobol_indices_nopce(basic_engine) -> None: @@ -148,6 +149,15 @@ def test_sobol_indices_pce(pce_engine) -> None: assert sobol['Z'][0,0] == 1 #%% check_reg_quality + +def test_check_reg_quality_pce(pce_engine) -> None: + """ + Check the regression quality for PCE metamodel + """ + engine = pce_engine + post = PostProcessing(engine) + post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y) + #%% eval_pce_model_3d def test_eval_pce_model_3d_nopce(basic_engine) -> None: -- GitLab