diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py index e68e5120e3de6187d75662cfee4098247d4f93a2..d5ad267b5c2a1a18914b4d962384cf53732f5be4 100644 --- a/src/bayesvalidrox/post_processing/post_processing.py +++ b/src/bayesvalidrox/post_processing/post_processing.py @@ -827,7 +827,7 @@ class PostProcessing: return self.total_sobol # ------------------------------------------------------------------------- - def check_reg_quality(self, n_samples=1000, samples=None): + def check_reg_quality(self, n_samples=1000, samples=None, outputs=None): """ Checks the quality of the metamodel for single output models based on: https://towardsdatascience.com/how-do-you-check-the-quality-of-your-regression-model-in-python-fa61759ff685 @@ -839,6 +839,9 @@ class PostProcessing: Number of parameter sets to use for the check. The default is 1000. samples : array of shape (n_samples, n_params), optional Parameter sets to use for the check. The default is None. + outputs : dict, optional + Output dictionary with model outputs for all given output types in + `Model.Output.names`. The default is None. Returns ------- @@ -852,7 +855,10 @@ class PostProcessing: self.n_samples = samples.shape[0] # Evaluate the original and the surrogate model - y_val = self._eval_model(samples, key_str='valid') + if outputs is None: + y_val = self._eval_model(samples, key_str='valid') + else: + y_val = outputs y_pce_val, _ = self.engine.eval_metamodel(samples=samples) # Open a pdf for the plots diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 79cb666f74e3cb069353716476f9a75d9c717054..ba905d6473a62b53182744d5f10e7836adc39c1e 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -124,6 +124,7 @@ def test_check_accuracy_pce(pce_engine) -> None: #%% plot_seq_design + #%% sobol_indices def test_sobol_indices_nopce(basic_engine) -> None: @@ -148,6 +149,15 @@ def test_sobol_indices_pce(pce_engine) -> None: assert sobol['Z'][0,0] == 1 #%% check_reg_quality + +def test_check_reg_quality_pce(pce_engine) -> None: + """ + Check the regression quality for PCE metamodel + """ + engine = pce_engine + post = PostProcessing(engine) + post.check_reg_quality(samples=engine.ExpDesign.X, outputs=engine.ExpDesign.Y) + #%% eval_pce_model_3d def test_eval_pce_model_3d_nopce(basic_engine) -> None: