From b80b2140f9c49ba5a32880112c927aedc2a04689 Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Wed, 27 Nov 2024 10:06:46 +0100
Subject: [PATCH] Remove PostProcessing._eval_model

---
 .../post_processing/post_processing.py        | 197 +++++++++---------
 1 file changed, 97 insertions(+), 100 deletions(-)

diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py
index b812c859f..6e2ee60a2 100644
--- a/src/bayesvalidrox/post_processing/post_processing.py
+++ b/src/bayesvalidrox/post_processing/post_processing.py
@@ -40,10 +40,12 @@ class PostProcessing:
     def __init__(self, engine, name="calib", out_dir="", out_format="pdf"):
         # PostProcessing only available for trained engines
         if not engine.trained:
-            raise AttributeError('PostProcessing can only be performed on trained engines.')
+            raise AttributeError(
+                "PostProcessing can only be performed on trained engines."
+            )
         self.engine = engine
         self.name = name
-        self.out_format = out_format 
+        self.out_format = out_format
         self.par_names = self.engine.ExpDesign.par_names
         self.x_values = self.engine.ExpDesign.x_values
 
@@ -72,7 +74,7 @@ class PostProcessing:
     # -------------------------------------------------------------------------
     def plot_moments(self, plot_type: str = None):
         """
-        Plots the moments in a user defined output format (standard is pdf) in the directory 
+        Plots the moments in a user-defined output format (standard is pdf) in the directory
         `Outputs_PostProcessing`.
 
         Parameters
@@ -90,7 +92,6 @@ class PostProcessing:
             Standard deviation of the model outputs.
 
         """
-
         bar_plot = bool(plot_type == "bar")
         meta_model_type = self.engine.MetaModel.meta_model_type
 
@@ -100,8 +101,7 @@ class PostProcessing:
         # Compute the moments with the PCEModel object
         self.means, self.stds = self.engine.MetaModel.calculate_moments()
 
-        # Plot the best fit line, set the linewidth (lw), color and
-        # transparency (alpha) of the line
+        # Plot the best fit line
         for key in self.engine.out_names:
             fig, ax = plt.subplots(nrows=1, ncols=2)
 
@@ -184,11 +184,11 @@ class PostProcessing:
             if not bar_plot:
                 ax[0].legend(loc="best")
                 ax[1].legend(loc="best")
-
             plt.tight_layout()
-
-            # save the current figure
-            fig.savefig(f"{self.out_dir}Mean_Std_PCE_{key}.{self.out_format}", bbox_inches="tight")
+            fig.savefig(
+                f"{self.out_dir}Mean_Std_PCE_{key}.{self.out_format}",
+                bbox_inches="tight",
+            )
 
         return self.means, self.stds
 
@@ -209,6 +209,9 @@ class PostProcessing:
         x_axis : str, optional
             Label of x axis. The default is `'Time [s]'`.
 
+        Returns
+        -------
+        None.
 
         """
         if samples is None:
@@ -219,21 +222,12 @@ class PostProcessing:
         if model_out_dict is not None:
             self.model_out_dict = model_out_dict
         else:
-            self.model_out_dict = self._eval_model(samples, key_str="valid")
+            self.model_out_dict, _ = self.engine.Model.run_model_parallel(samples, key_str="valid")
         self.pce_out_mean, self.pce_out_std = self.engine.eval_metamodel(samples)
 
-        try:
-            key = self.engine.out_names[1]
-        except IndexError:
-            key = self.engine.out_names[0]
-
-        n_obs = self.model_out_dict[key].shape[1]
-
-        # if n_obs == 1:
-        #     self._plot_validation()
-        # else:
         self._plot_validation_multi()
 
+        # TODO: decide whether zipping the model-run subdirectories belongs in check_accuracy
         # Zip the subdirectories
         self.engine.Model.zip_subdirs(
             f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
@@ -265,6 +259,10 @@ class PostProcessing:
         AttributeError
             When neither n_samples nor samples are provided.
 
+        Returns
+        -------
+        None.
+
         """
         # Set the number of samples
         if samples is None and n_samples is None:
@@ -277,11 +275,8 @@ class PostProcessing:
         samples = self._get_sample(n_samples) if samples is None else samples
 
         # Run the original model with the generated samples
-        outputs = (
-            self._eval_model(samples, key_str="validSet")
-            if outputs is None
-            else outputs
-        )
+        if outputs is None:
+            outputs, _ = self.engine.Model.run_model_parallel(samples, key_str="valid")
 
         # Run the PCE model with the generated samples
         metamod_outputs, _ = self.engine.eval_metamodel(samples)
@@ -290,7 +285,7 @@ class PostProcessing:
         self.valid_error = {}
         # Loop over the keys and compute RMSE error.
         for key in self.engine.out_names:
-            # Root mena square
+            # Root mean square
             self.rmse[key] = mean_squared_error(
                 outputs[key],
                 metamod_outputs[key],
@@ -314,7 +309,7 @@ class PostProcessing:
                     )
                 )
             )
-        # Save error dicts in PCEModel object
+        # Save error dicts in MetaModel object
         self.engine.MetaModel.rmse = self.rmse
         self.engine.MetaModel.valid_error = self.valid_error
 
@@ -329,6 +324,10 @@ class PostProcessing:
         ref_BME_KLD : array, optional
             Reference BME and KLD . The default is `None`.
 
+        Returns
+        -------
+        None.
+
         """
         engine = self.engine
         n_init_samples = engine.ExpDesign.n_init_samples
@@ -494,7 +493,10 @@ class PostProcessing:
 
                 # save the current figure
                 plot_name = plot.replace(" ", "_")
-                fig.savefig(f"./{newpath}/seq_{plot_name}.{self.out_format}", bbox_inches="tight")
+                fig.savefig(
+                    f"./{newpath}/seq_{plot_name}.{self.out_format}",
+                    bbox_inches="tight",
+                )
                 # Destroy the current plot
                 plt.close()
                 # Save arrays into files
@@ -597,7 +599,10 @@ class PostProcessing:
 
                 # save the current figure
                 plot_name = plot.replace(" ", "_")
-                fig.savefig(f"./{newpath}/seq_{plot_name}.{self.out_format}", bbox_inches="tight")
+                fig.savefig(
+                    f"./{newpath}/seq_{plot_name}.{self.out_format}",
+                    bbox_inches="tight",
+                )
                 # Destroy the current plot
                 plt.close()
 
@@ -626,6 +631,13 @@ class PostProcessing:
         AttributeError
             MetaModel in given Engine needs to be of type 'pce' or 'apce'.
 
+        Returns
+        -------
+        sobol_all : dict
+            All possible Sobol' indices for the given metamodel.
+        total_sobol_all : dict
+            All Total Sobol' indices for the given metamodel.
+
         """
         # This function currently only supports PCE/aPCE
         metamod = self.engine.MetaModel
@@ -680,10 +692,34 @@ class PostProcessing:
 
         return sobol_all, total_sobol_all
 
-    def plot_sobol(self, par_names, outputs, sobol_type="sobol", i_order=0):
+    def plot_sobol(
+        self,
+        par_names: list,
+        outputs: list,
+        sobol_type: str = "sobol",
+        i_order: int = 0,
+    ) -> None:
         """
         Generate plots for each output in the given set of Sobol' indices.
 
+        Parameters
+        ----------
+        par_names : list
+            Parameter names for each Sobol' index.
+        outputs : list
+            Output names to be plotted.
+        sobol_type : string, optional
+            Type of Sobol' indices to visualize. Can be either
+            'sobol' or 'totalsobol'. The default is 'sobol'.
+        i_order : int, optional
+            Order of Sobol' index that should be plotted.
+            This parameter is only applied for sobol_type = 'sobol'.
+            The default is 0.
+
+        Returns
+        -------
+        None.
+
         """
         sobol = None
         if sobol_type == "sobol":
@@ -698,7 +734,7 @@ class PostProcessing:
                 else self.x_values
             )
             sobol_ = sobol[output]
-            if sobol_type == 'sobol':
+            if sobol_type == "sobol":
                 sobol_ = sobol_[0]
 
             # Compute quantiles
@@ -789,67 +825,59 @@ class PostProcessing:
             n_samples = samples.shape[0]
 
         # Evaluate the original and the surrogate model
-        if outputs is None:
-            y_val = self._eval_model(samples, key_str="valid")
-        else:
-            y_val = outputs
+        y_val = outputs
+        if y_val is None:
+            y_val, _ = self.engine.Model.run_model_parallel(samples, key_str="valid")
         y_pce_val, _ = self.engine.eval_metamodel(samples=samples)
 
         # Fit the data(train the model)
         for key in y_pce_val.keys():
-
-            y_pce_val_ = y_pce_val[key]
-            y_val_ = y_val[key]
-            residuals = y_val_ - y_pce_val_
+            residuals = y_val[key] - y_pce_val[key]
 
             # ------ Residuals vs. predicting variables ------
             # Check the assumptions of linearity and independence
-            fig1 = plt.figure()
             for i, par in enumerate(self.engine.ExpDesign.par_names):
-                plt.title(f"{key}: Residuals vs. {par}")
                 plt.scatter(x=samples[:, i], y=residuals, color="blue", edgecolor="k")
+                plt.title(f"{key}: Residuals vs. {par}")
                 plt.grid(True)
-                xmin, xmax = min(samples[:, i]), max(samples[:, i])
                 plt.hlines(
                     y=0,
-                    xmin=xmin * 0.9,
-                    xmax=xmax * 1.1,
+                    xmin=min(samples[:, i]) * 0.9,
+                    xmax=max(samples[:, i]) * 1.1,
                     color="red",
                     lw=3,
                     linestyle="--",
                 )
                 plt.xlabel(par)
                 plt.ylabel("Residuals")
-
-                # save the current figure
-                fig1.savefig(
-                    f"./{self.out_dir}/Residuals_vs_Par_{i+1}.{self.out_format}", bbox_inches="tight"
+                plt.savefig(
+                    f"./{self.out_dir}/Residuals_vs_Par_{i+1}.{self.out_format}",
+                    bbox_inches="tight",
                 )
-                # Destroy the current plot
                 plt.close()
 
             # ------ Fitted vs. residuals ------
             # Check the assumptions of linearity and independence
-            fig2 = plt.figure()
+            plt.scatter(x=y_pce_val[key], y=residuals, color="blue", edgecolor="k")
             plt.title(f"{key}: Residuals vs. fitted values")
-            plt.scatter(x=y_pce_val_, y=residuals, color="blue", edgecolor="k")
             plt.grid(True)
-            xmin, xmax = min(y_val_), max(y_val_)
             plt.hlines(
-                y=0, xmin=xmin * 0.9, xmax=xmax * 1.1, color="red", lw=3, linestyle="--"
+                y=0,
+                xmin=min(y_val[key]) * 0.9,
+                xmax=max(y_val[key]) * 1.1,
+                color="red",
+                lw=3,
+                linestyle="--",
             )
             plt.xlabel(key)
             plt.ylabel("Residuals")
-
-            # save the current figure
-            fig2.savefig(
-                f"./{self.out_dir}/Fitted_vs_Residuals.{self.out_format}", bbox_inches="tight"
+            plt.savefig(
+                f"./{self.out_dir}/Fitted_vs_Residuals.{self.out_format}",
+                bbox_inches="tight",
             )
-            # Destroy the current plot
             plt.close()
 
             # ------ Histogram of normalized residuals ------
-            fig3 = plt.figure()
             resid_pearson = residuals / (max(residuals) - min(residuals))
             plt.hist(resid_pearson, bins=20, edgecolor="k")
             plt.ylabel("Count")
@@ -868,12 +896,10 @@ class PostProcessing:
             )
             at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
             ax.add_artist(at)
-
-            # save the current figure
-            fig3.savefig(
-                f"./{self.out_dir}/Hist_NormResiduals.{self.out_format}", bbox_inches="tight"
+            plt.savefig(
+                f"./{self.out_dir}/Hist_NormResiduals.{self.out_format}",
+                bbox_inches="tight",
             )
-            # Destroy the current plot
             plt.close()
 
             # ------ Q-Q plot of the normalized residuals ------
@@ -885,12 +911,10 @@ class PostProcessing:
             plt.ylabel("Sample quantiles")
             plt.title(f"{key}: Q-Q plot of normalized residuals")
             plt.grid(True)
-
-            # save the current figure
             plt.savefig(
-                f"./{self.out_dir}/QQPlot_NormResiduals.{self.out_format}", bbox_inches="tight"
+                f"./{self.out_dir}/QQPlot_NormResiduals.{self.out_format}",
+                bbox_inches="tight",
             )
-            # Destroy the current plot
             plt.close()
 
     # -------------------------------------------------------------------------
@@ -922,10 +946,7 @@ class PostProcessing:
         samples = np.sort(np.sort(samples, axis=1), axis=0)
         mean, _ = self.engine.eval_metamodel(samples=samples)
 
-        if self.engine.emulator:
-            title = "MetaModel"
-        else:
-            title = "Model"
+        title = "MetaModel" if self.engine.emulator else "Model"
         x, y = np.meshgrid(samples[:, 0], samples[:, 1])
         for name in self.engine.out_names:
             for t in range(mean[name].shape[1]):
@@ -946,8 +967,6 @@ class PostProcessing:
                 ax.set_zlabel("$f(x_1,x_2)$")
 
                 plt.grid()
-
-                # save the figure to file
                 fig.savefig(
                     f"./{self.out_dir}/3DPlot_{title}_{name}{t}.{self.out_format}",
                     bbox_inches="tight",
@@ -976,29 +995,6 @@ class PostProcessing:
         )
         return samples
 
-    # -------------------------------------------------------------------------
-    def _eval_model(self, samples, key_str="Valid"):
-        """
-        Evaluates Forward Model on the given samples.
-
-        Parameters
-        ----------
-        samples : array of shape (n_samples, n_params)
-            Samples to evaluate the model at.
-        key_str : str, optional
-            Key string pass to the model. The default is 'Valid'.
-
-        Returns
-        -------
-        model_outs : dict
-            Dictionary of results.
-
-        """
-        # samples = self._get_sample()
-        model_outs, _ = self.engine.Model.run_model_parallel(samples, key_str=key_str)
-
-        return model_outs
-
     # -------------------------------------------------------------------------
     def _plot_validation(self):
         """
@@ -1042,7 +1038,7 @@ class PostProcessing:
             for key, value in metamod._coeffs_dict["b_1"][key].items():
                 length_list.append(len(value))
             n_predictors = min(length_list)
-    
+
             n_samples = y_pce_val[key].shape[0]
 
             r_2 = r2_score(y_pce_val[key], y_val[key])
@@ -1152,6 +1148,7 @@ class PostProcessing:
             plt.grid()
             key = key.replace(" ", "_")
             fig.savefig(
-                f"./{self.out_dir}/Model_vs_PCEModel_{key}.{self.out_format}", bbox_inches="tight"
+                f"./{self.out_dir}/Model_vs_PCEModel_{key}.{self.out_format}",
+                bbox_inches="tight",
             )
             plt.close()
-- 
GitLab