From 0cad6155f5bb04a743138d7b3f56a666031ea366 Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Wed, 9 Oct 2024 12:25:24 +0200
Subject: [PATCH] [fix] Add outputs option to check_reg_quality

Also add a related PCE test.
---
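Usage sketch (illustrative only; assumes a trained `engine` whose
`ExpDesign` holds the training inputs `X` and outputs `Y`, as in the
PCE test fixture below):

    from bayesvalidrox.post_processing.post_processing import PostProcessing

    post = PostProcessing(engine)

    # Default behaviour: re-evaluate the original model on the samples.
    post.check_reg_quality(samples=engine.ExpDesign.X)

    # New behaviour: validate against precomputed outputs instead of
    # running the model again.
    post.check_reg_quality(samples=engine.ExpDesign.X,
                           outputs=engine.ExpDesign.Y)
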
 src/bayesvalidrox/post_processing/post_processing.py | 10 ++++++++--
 tests/test_PostProcessing.py                         | 10 ++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)
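
The new test can be run on its own with pytest, e.g.:

    pytest tests/test_PostProcessing.py::test_check_reg_quality_pce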

diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py
index e68e5120e..d5ad267b5 100644
--- a/src/bayesvalidrox/post_processing/post_processing.py
+++ b/src/bayesvalidrox/post_processing/post_processing.py
@@ -827,7 +827,7 @@ class PostProcessing:
         return self.total_sobol
 
     # -------------------------------------------------------------------------
-    def check_reg_quality(self, n_samples=1000, samples=None):
+    def check_reg_quality(self, n_samples=1000, samples=None, outputs=None):
         """
         Checks the quality of the metamodel for single output models based on:
         https://towardsdatascience.com/how-do-you-check-the-quality-of-your-regression-model-in-python-fa61759ff685
@@ -839,6 +839,9 @@ class PostProcessing:
             Number of parameter sets to use for the check. The default is 1000.
         samples : array of shape (n_samples, n_params), optional
             Parameter sets to use for the check. The default is None.
+        outputs : dict, optional
+            Dictionary of model outputs to validate against, keyed by the
+            output names in `Model.Output.names`. The default is None.
 
         Returns
         -------
@@ -852,7 +855,10 @@ class PostProcessing:
             self.n_samples = samples.shape[0]
 
         # Evaluate the original and the surrogate model
-        y_val = self._eval_model(samples, key_str='valid')
+        if outputs is None:
+            y_val = self._eval_model(samples, key_str='valid')
+        else:
+            y_val = outputs
         y_pce_val, _ = self.engine.eval_metamodel(samples=samples)
 
         # Open a pdf for the plots
diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py
index 79cb666f7..ba905d647 100644
--- a/tests/test_PostProcessing.py
+++ b/tests/test_PostProcessing.py
@@ -124,6 +124,7 @@ def test_check_accuracy_pce(pce_engine) -> None:
 
 
 #%% plot_seq_design
+
 #%% sobol_indices
 
 def test_sobol_indices_nopce(basic_engine) -> None:
@@ -148,6 +149,15 @@ def test_sobol_indices_pce(pce_engine) -> None:
     assert sobol['Z'][0,0] == 1
 
 #%% check_reg_quality
+
+def test_check_reg_quality_pce(pce_engine) -> None:
+    """
+    Check the regression quality for a PCE metamodel.
+    """
+    engine = pce_engine
+    post = PostProcessing(engine)
+    post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y)
+
 #%% eval_pce_model_3d
 
 def test_eval_pce_model_3d_nopce(basic_engine) -> None:
-- 
GitLab