From 7b92b89de7927e073bec938c32e8c619244a5edf Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Thu, 14 Nov 2024 11:52:42 +0100
Subject: [PATCH] More linting and reformatting of PostProcessing

---
 .../post_processing/post_processing.py       | 628 +++++++++++-------
 1 file changed, 383 insertions(+), 245 deletions(-)

diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py
index b9e73ff23..c3b833b42 100644
--- a/src/bayesvalidrox/post_processing/post_processing.py
+++ b/src/bayesvalidrox/post_processing/post_processing.py
@@ -15,9 +15,9 @@
 import matplotlib.pyplot as plt
 from matplotlib import ticker
 from matplotlib.offsetbox import AnchoredText
 from matplotlib.patches import Patch
+
 # Load the mplstyle
-plt.style.use(os.path.join(os.path.split(__file__)[0],
-                           '../', 'bayesvalidrox.mplstyle'))
+plt.style.use(os.path.join(os.path.split(__file__)[0], "../", "bayesvalidrox.mplstyle"))
 
 
 class PostProcessing:
     """
@@ -37,21 +37,21 @@ class PostProcessing:
 
     """
 
-    def __init__(self, engine, name='calib', out_dir=''):
+    def __init__(self, engine, name="calib", out_dir=""):
         self.engine = engine
         self.name = name
         self.par_names = self.engine.ExpDesign.par_names
         self.x_values = self.engine.ExpDesign.x_values
 
-        self.out_dir = f'./{out_dir}/Outputs_PostProcessing_{self.name}/'
+        self.out_dir = f"./{out_dir}/Outputs_PostProcessing_{self.name}/"
 
         # Open a pdf for the plots
         if not os.path.exists(self.out_dir):
             os.makedirs(self.out_dir)
 
         # Initialize attributes
-        self.plot_type = ''
-        self.xlabel = 'Time [s]'
+        self.plot_type = ""
+        self.xlabel = "Time [s]"
         self.mc_reference = None
         self.sobol = None
         self.totalsobol = None
@@ -87,11 +87,11 @@ class PostProcessing:
 
         """
-        bar_plot = True if plot_type == 'bar' else False
+        bar_plot = bool(plot_type == "bar")
         meta_model_type = self.engine.MetaModel.meta_model_type
 
         # Read Monte-Carlo reference
-        self.mc_reference = self.engine.Model.read_observation('mc_ref')
+        self.mc_reference = self.engine.Model.read_observation("mc_ref")
 
         # Compute the moments with the PCEModel object
         self.means, self.stds = self.engine.MetaModel.calculate_moments()
 
@@ -107,31 +107,65 @@ class PostProcessing:
 
         # Plot: bar plot or line plot
         if bar_plot:
-            ax[0].bar(list(map(str, self.x_values)), mean_data, color='b',
-                      width=0.25)
-            ax[1].bar(list(map(str, self.x_values)), std_data, color='b',
-                      width=0.25)
+            ax[0].bar(
+                list(map(str, self.x_values)), mean_data, color="b", width=0.25
+            )
+            ax[1].bar(
+                list(map(str, self.x_values)), std_data, color="b", width=0.25
+            )
             ax[0].legend(labels=[meta_model_type])
             ax[1].legend(labels=[meta_model_type])
         else:
-            ax[0].plot(self.x_values, mean_data, lw=3, color='k', marker='x',
-                       label=meta_model_type)
-            ax[1].plot(self.x_values, std_data, lw=3, color='k', marker='x',
-                       label=meta_model_type)
+            ax[0].plot(
+                self.x_values,
+                mean_data,
+                lw=3,
+                color="k",
+                marker="x",
+                label=meta_model_type,
+            )
+            ax[1].plot(
+                self.x_values,
+                std_data,
+                lw=3,
+                color="k",
+                marker="x",
+                label=meta_model_type,
+            )
 
         if self.mc_reference is not None:
             if bar_plot:
-                ax[0].bar(list(map(str, self.x_values)), self.mc_reference['mean'],
-                          color='r', width=0.25)
-                ax[1].bar(list(map(str, self.x_values)), self.mc_reference['std'],
-                          color='r', width=0.25)
+                ax[0].bar(
+                    list(map(str, self.x_values)),
+                    self.mc_reference["mean"],
+                    color="r",
+                    width=0.25,
+                )
+                ax[1].bar(
+                    list(map(str, self.x_values)),
+                    self.mc_reference["std"],
+                    color="r",
+                    width=0.25,
+                )
                 ax[0].legend(labels=[meta_model_type])
                 ax[1].legend(labels=[meta_model_type])
             else:
-                ax[0].plot(self.x_values, self.mc_reference['mean'], lw=3, marker='x',
-                           color='r', label='Ref.')
-                ax[1].plot(self.x_values, self.mc_reference['std'], lw=3, marker='x',
-                           color='r', label='Ref.')
+                ax[0].plot(
+                    self.x_values,
+                    self.mc_reference["mean"],
+                    lw=3,
+                    marker="x",
+                    color="r",
+                    label="Ref.",
+                )
+                ax[1].plot(
+                    self.x_values,
+                    self.mc_reference["std"],
+                    lw=3,
+                    marker="x",
+                    color="r",
+                    label="Ref.",
+                )
 
         # Label the axes and provide a title
         ax[0].set_xlabel(self.xlabel)
@@ -140,20 +174,17 @@ class PostProcessing:
         ax[1].set_ylabel(key)
 
         # Provide a title
-        ax[0].set_title('Mean of ' + key)
-        ax[1].set_title('Std of ' + key)
+        ax[0].set_title("Mean of " + key)
+        ax[1].set_title("Std of " + key)
 
         if not bar_plot:
-            ax[0].legend(loc='best')
-            ax[1].legend(loc='best')
+            ax[0].legend(loc="best")
+            ax[1].legend(loc="best")
 
         plt.tight_layout()
 
         # save the current figure
-        fig.savefig(
-            f'{self.out_dir}Mean_Std_PCE_{key}.pdf',
-            bbox_inches='tight'
-        )
+        fig.savefig(f"{self.out_dir}Mean_Std_PCE_{key}.pdf", bbox_inches="tight")
 
         return self.means, self.stds
 
@@ -181,15 +212,11 @@ class PostProcessing:
         else:
             n_samples = samples.shape[0]
 
-        # Extract x_values
-        x_values = self.engine.ExpDesign.x_values
-
         if model_out_dict is not None:
             self.model_out_dict = model_out_dict
         else:
-            self.model_out_dict = self._eval_model(samples, key_str='valid')
-        self.pce_out_mean, self.pce_out_std = self.engine.eval_metamodel(
-            samples)
+            self.model_out_dict = self._eval_model(samples, key_str="valid")
+        self.pce_out_mean, self.pce_out_std = self.engine.eval_metamodel(samples)
 
         try:
             key = self.engine.out_names[1]
@@ -205,11 +232,13 @@ class PostProcessing:
 
         # Zip the subdirectories
         self.engine.Model.zip_subdirs(
-            f'{self.engine.Model.name}valid', f'{self.engine.Model.name}valid_')
+            f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
+        )
 
         # Zip the subdirectories
         self.engine.Model.zip_subdirs(
-            f'{self.engine.Model.name}valid', f'{self.engine.Model.name}valid_')
+            f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
+        )
 
     # -------------------------------------------------------------------------
     def check_accuracy(self, n_samples=None, samples=None, outputs=None) -> None:
@@ -235,15 +264,20 @@ class PostProcessing:
 
         """
         # Set the number of samples
         if samples is None and n_samples is None:
-            raise AttributeError("Please provide either samples or pass the number"
-                                 " of samples!")
+            raise AttributeError(
+                "Please provide either samples or pass the number of samples!"
+            )
         n_samples = samples.shape[0] if samples is not None else n_samples
 
         # Generate random samples if necessary
         samples = self._get_sample(n_samples) if samples is None else samples
 
         # Run the original model with the generated samples
-        outputs = self._eval_model(samples, key_str='validSet') if outputs is None else outputs
+        outputs = (
+            self._eval_model(samples, key_str="validSet")
+            if outputs is None
+            else outputs
+        )
 
         # Run the PCE model with the generated samples
         metamod_outputs, _ = self.engine.eval_metamodel(samples)
 
         self.rmse = {}
         self.valid_error = {}
 
         # Loop over the keys and compute RMSE error.
         for key in self.engine.out_names:
             # Root mean square
-            self.rmse[key] = mean_squared_error(outputs[key], metamod_outputs[key],
-                                                squared=False,
-                                                multioutput='raw_values')
+            self.rmse[key] = mean_squared_error(
+                outputs[key],
+                metamod_outputs[key],
+                squared=False,
+                multioutput="raw_values",
+            )
             # Validation error
-            self.valid_error[key] = (self.rmse[key]**2) / \
-                np.var(outputs[key], ddof=1, axis=0)
+            self.valid_error[key] = (self.rmse[key] ** 2) / np.var(
+                outputs[key], ddof=1, axis=0
+            )
 
             # Print a report table
             print(f"\n>>>>> Errors of {key} <<<<<")
             print("\nIndex | RMSE | Validation Error")
-            print('-'*35)
-            print('\n'.join(f'{i+1} | {k:.3e} | {j:.3e}' for i, (k, j)
-                            in enumerate(zip(self.rmse[key],
-                                             self.valid_error[key]))))
+            print("-" * 35)
+            print(
+                "\n".join(
+                    f"{i+1} | {k:.3e} | {j:.3e}"
+                    for i, (k, j) in enumerate(
+                        zip(self.rmse[key], self.valid_error[key])
+                    )
+                )
+            )
 
         # Save error dicts in PCEModel object
         self.engine.MetaModel.rmse = self.rmse
         self.engine.MetaModel.valid_error = self.valid_error
 
@@ -287,18 +330,31 @@ class PostProcessing:
         n_init_samples = engine.ExpDesign.n_init_samples
         n_total_samples = engine.ExpDesign.X.shape[0]
 
-        newpath = f'Outputs_PostProcessing_{self.name}/seq_design_diagnostics/'
+        newpath = f"Outputs_PostProcessing_{self.name}/seq_design_diagnostics/"
         if not os.path.exists(newpath):
             os.makedirs(newpath)
 
-        plot_list = ['Modified LOO error', 'Validation error', 'KLD', 'BME',
-                     'RMSEMean', 'RMSEStd', 'Hellinger distance']
-        seq_list = [engine.SeqModifiedLOO, engine.seqValidError,
-                    engine.SeqKLD, engine.SeqBME, engine.seqRMSEMean,
-                    engine.seqRMSEStd, engine.SeqDistHellinger]
-
-        markers = ('x', 'o', 'd', '*', '+')
-        colors = ('k', 'darkgreen', 'b', 'navy', 'darkred')
+        plot_list = [
+            "Modified LOO error",
+            "Validation error",
+            "KLD",
+            "BME",
+            "RMSEMean",
+            "RMSEStd",
+            "Hellinger distance",
+        ]
+        seq_list = [
+            engine.SeqModifiedLOO,
+            engine.seqValidError,
+            engine.SeqKLD,
+            engine.SeqBME,
+            engine.seqRMSEMean,
+            engine.seqRMSEStd,
+            engine.SeqDistHellinger,
+        ]
+
+        markers = ("x", "o", "d", "*", "+")
+        colors = ("k", "darkgreen", "b", "navy", "darkred")
 
         # Plot the evolution of the diagnostic criteria of the
         # Sequential Experimental Design.
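For reference: the validation error computed in check_accuracy above is the squared RMSE normalized by the sample variance of the reference outputs, which is essentially 1 - R^2 up to the ddof convention. A minimal standalone sketch of the same computation, not part of the patch and with illustrative array names:

    import numpy as np
    from sklearn.metrics import mean_squared_error

    rng = np.random.default_rng(0)
    y_ref = rng.random((100, 5))                          # reference model outputs
    y_sur = y_ref + 0.01 * rng.standard_normal((100, 5))  # surrogate predictions

    # RMSE per output location, as in check_accuracy
    rmse = mean_squared_error(y_ref, y_sur, squared=False, multioutput="raw_values")
    # Validation error: squared RMSE normalized by the output variance
    valid_error = rmse**2 / np.var(y_ref, ddof=1, axis=0)
    print(rmse, valid_error)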
@@ -322,27 +378,41 @@ class PostProcessing:
 
             for util in util_funcs:
                 sorted_seq = {}
-                n_runs = min(seq_dict[f'{util}_rep_{i + 1}'].shape[0] for i in range(n_reps))
+                n_runs = min(
+                    seq_dict[f"{util}_rep_{i + 1}"].shape[0] for i in range(n_reps)
+                )
 
                 for run_idx in range(n_runs):
                     values = []
                     for key in seq_dict.keys():
                         if util in key:
                             values.append(seq_dict[key][run_idx].mean())
-                    sorted_seq['SeqItr_'+str(run_idx)] = np.array(values)
+                    sorted_seq["SeqItr_" + str(run_idx)] = np.array(values)
                 sorted_seq_opt[util] = sorted_seq
 
             # BoxPlot
             def draw_plot(data, labels, edge_color, fill_color, idx):
-                pos = labels - (idx-1)
-                bp = plt.boxplot(data, positions=pos, labels=labels,
-                                 patch_artist=True, sym='', widths=0.75)
-                elements = ['boxes', 'whiskers', 'fliers', 'means',
-                            'medians', 'caps']
+                pos = labels - (idx - 1)
+                bp = plt.boxplot(
+                    data,
+                    positions=pos,
+                    labels=labels,
+                    patch_artist=True,
+                    sym="",
+                    widths=0.75,
+                )
+                elements = [
+                    "boxes",
+                    "whiskers",
+                    "fliers",
+                    "means",
+                    "medians",
+                    "caps",
+                ]
                 for element in elements:
                     plt.setp(bp[element], color=edge_color[idx])
-                for patch in bp['boxes']:
+                for patch in bp["boxes"]:
                     patch.set(facecolor=fill_color[idx])
 
             if engine.ExpDesign.n_new_samples != 1:
                 step1 = engine.ExpDesign.n_new_samples
                 step2 = 1
             else:
                 step1 = 5
                 step2 = 5
-            edge_color = ['red', 'blue', 'green']
-            fill_color = ['tan', 'cyan', 'lightgreen']
+            edge_color = ["red", "blue", "green"]
+            fill_color = ["tan", "cyan", "lightgreen"]
             plot_label = plot
             # Plot for different Utility Functions
             for idx, util in enumerate(util_funcs):
                 all_errors = np.empty((n_reps, 0))
                 for key in list(sorted_seq_opt[util].keys()):
                     errors = sorted_seq_opt.get(util, {}).get(key)[:, None]
                     all_errors = np.hstack((all_errors, errors))
 
                 # Special cases for BME and KLD
-                if plot in['KLD', 'BME']:
+                if plot in ["KLD", "BME"]:
                     # BME convergence if refBME is provided
                     if ref_bme_kld is not None:
                         ref_value = None
-                        if plot == 'BME':
+                        if plot == "BME":
                             ref_value = ref_bme_kld[0]
-                            plot_label = r'BME/BME$^{Ref.}$'
-                        if plot == 'KLD':
+                            plot_label = r"BME/BME$^{Ref.}$"
+                        if plot == "KLD":
                             ref_value = ref_bme_kld[1]
-                            plot_label = '$D_{KL}[p(\\theta|y_*),p(\\theta)]'\
-                                ' / D_{KL}^{Ref.}[p(\\theta|y_*), '\
-                                'p(\\theta)]$'
+                            plot_label = (
+                                "$D_{KL}[p(\\theta|y_*),p(\\theta)]"
+                                " / D_{KL}^{Ref.}[p(\\theta|y_*), "
+                                "p(\\theta)]$"
+                            )
 
                         # Difference between BME/KLD and the ref. values
-                        all_errors = np.divide(all_errors,
-                                               np.full((all_errors.shape),
-                                                       ref_value))
+                        all_errors = np.divide(
+                            all_errors, np.full((all_errors.shape), ref_value)
+                        )
 
                         # Plot baseline for zero, i.e. no difference
-                        plt.axhline(y=1.0, xmin=0, xmax=1, c='green',
-                                    ls='--', lw=2)
+                        plt.axhline(y=1.0, xmin=0, xmax=1, c="green", ls="--", lw=2)
 
                 # Plot each UtilFuncs
-                labels = np.arange(
-                    n_init_samples, n_total_samples+1, step1)
-                draw_plot(all_errors[:, ::step2], labels, edge_color,
-                          fill_color, idx)
+                labels = np.arange(n_init_samples, n_total_samples + 1, step1)
+                draw_plot(
+                    all_errors[:, ::step2], labels, edge_color, fill_color, idx
+                )
 
             plt.xticks(labels, labels)
             # Set the major and minor locators
             ax.xaxis.set_major_locator(ticker.AutoLocator())
             ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
-            ax.xaxis.grid(True, which='major', linestyle='-')
-            ax.xaxis.grid(True, which='minor', linestyle='--')
+            ax.xaxis.grid(True, which="major", linestyle="-")
+            ax.xaxis.grid(True, which="minor", linestyle="--")
 
             # Legend
             legend_elements = []
             for idx, util in enumerate(util_funcs):
-                legend_elements.append(Patch(facecolor=fill_color[idx],
-                                             edgecolor=edge_color[idx],
-                                             label=util))
-            plt.legend(handles=legend_elements[::-1], loc='best')
+                legend_elements.append(
+                    Patch(
+                        facecolor=fill_color[idx],
+                        edgecolor=edge_color[idx],
+                        label=util,
+                    )
+                )
+            plt.legend(handles=legend_elements[::-1], loc="best")
 
-            if plot not in ['BME','KLD']:
-                plt.yscale('log')
+            if plot not in ["BME", "KLD"]:
+                plt.yscale("log")
             plt.autoscale(True)
-            plt.xlabel('\\# of training samples')
+            plt.xlabel("\\# of training samples")
             plt.ylabel(plot_label)
             plt.title(plot)
 
             # save the current figure
-            plot_name = plot.replace(' ', '_')
-            fig.savefig(
-                f'./{newpath}/seq_{plot_name}.pdf',
-                bbox_inches='tight'
-            )
+            plot_name = plot.replace(" ", "_")
+            fig.savefig(f"./{newpath}/seq_{plot_name}.pdf", bbox_inches="tight")
             # Destroy the current plot
             plt.close()
 
             # Save arrays into files
-            f = open(f'./{newpath}/seq_{plot_name}.txt', 'w')
+            f = open(f"./{newpath}/seq_{plot_name}.txt", "w")
             f.write(str(sorted_seq_opt))
             f.close()
         else:
@@ -432,45 +504,58 @@ class PostProcessing:
                     step = engine.ExpDesign.n_new_samples
                 else:
                     step = 1
-                x_idx = np.arange(n_init_samples, n_total_samples+1, step)
+                x_idx = np.arange(n_init_samples, n_total_samples + 1, step)
                 if n_total_samples not in x_idx:
                     x_idx = np.hstack((x_idx, n_total_samples))
 
-                if plot in ['KLD', 'BME']:
+                if plot in ["KLD", "BME"]:
                     # BME convergence if refBME is provided
                     if ref_bme_kld is not None:
-                        if plot == 'BME':
+                        if plot == "BME":
                             ref_value = ref_bme_kld[0]
-                            plot_label = r'BME/BME$^{Ref.}$'
-                        if plot == 'KLD':
+                            plot_label = r"BME/BME$^{Ref.}$"
+                        if plot == "KLD":
                             ref_value = ref_bme_kld[1]
-                            plot_label = '$D_{KL}[p(\\theta|y_*),p(\\theta)]'\
-                                ' / D_{KL}^{Ref.}[p(\\theta|y_*), '\
-                                'p(\\theta)]$'
+                            plot_label = (
+                                "$D_{KL}[p(\\theta|y_*),p(\\theta)]"
+                                " / D_{KL}^{Ref.}[p(\\theta|y_*), "
+                                "p(\\theta)]$"
+                            )
 
                         # Difference between BME/KLD and the ref. values
-                        values = np.divide(seq_values,
-                                           np.full((seq_values.shape),
-                                                   ref_value))
+                        values = np.divide(
+                            seq_values, np.full((seq_values.shape), ref_value)
+                        )
 
                        # Plot baseline for zero, i.e. no difference
-                        plt.axhline(y=1.0, xmin=0, xmax=1, c='green',
-                                    ls='--', lw=2)
+                        plt.axhline(y=1.0, xmin=0, xmax=1, c="green", ls="--", lw=2)
 
                         # Set the limits
                         plt.ylim([1e-1, 1e1])
 
                         # Create the plots
-                        plt.semilogy(x_idx, values, marker=markers[idx],
-                                     color=colors[idx], ls='--', lw=2,
-                                     label=name.split("_rep", 1)[0])
+                        plt.semilogy(
+                            x_idx,
+                            values,
+                            marker=markers[idx],
+                            color=colors[idx],
+                            ls="--",
+                            lw=2,
+                            label=name.split("_rep", 1)[0],
+                        )
                     else:
                         plot_label = plot
 
                         # Create the plots
-                        plt.plot(x_idx, seq_values, marker=markers[idx],
-                                 color=colors[idx], ls='--', lw=2,
-                                 label=name.split("_rep", 1)[0])
+                        plt.plot(
+                            x_idx,
+                            seq_values,
+                            marker=markers[idx],
+                            color=colors[idx],
+                            ls="--",
+                            lw=2,
+                            label=name.split("_rep", 1)[0],
+                        )
 
                 else:
                     plot_label = plot
 
                     # Plot the error evolution for each output
                     # print(x_idx.shape)
                     # print(seq_values.mean(axis=1).shape)
-                    plt.semilogy(x_idx, seq_values.mean(axis=1),
-                                 marker=markers[idx], ls='--', lw=2,
-                                 color=colors[idx],
-                                 label=name.split("_rep", 1)[0])
+                    plt.semilogy(
+                        x_idx,
+                        seq_values.mean(axis=1),
+                        marker=markers[idx],
+                        ls="--",
+                        lw=2,
+                        color=colors[idx],
+                        label=name.split("_rep", 1)[0],
+                    )
 
             # Set the major and minor locators
             ax.xaxis.set_major_locator(ticker.AutoLocator())
             ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
-            ax.xaxis.grid(True, which='major', linestyle='-')
-            ax.xaxis.grid(True, which='minor', linestyle='--')
-
-            ax.tick_params(axis='both', which='major', direction='in',
-                           width=3, length=10)
-            ax.tick_params(axis='both', which='minor', direction='in',
-                           width=2, length=8)
-            plt.xlabel('Number of runs')
+            ax.xaxis.grid(True, which="major", linestyle="-")
+            ax.xaxis.grid(True, which="minor", linestyle="--")
+
+            ax.tick_params(
+                axis="both", which="major", direction="in", width=3, length=10
+            )
+            ax.tick_params(
+                axis="both", which="minor", direction="in", width=2, length=8
+            )
+            plt.xlabel("Number of runs")
             plt.ylabel(plot_label)
             plt.title(plot)
             plt.legend(frameon=True)
 
             # save the current figure
-            plot_name = plot.replace(' ', '_')
-            fig.savefig(
-                f'./{newpath}/seq_{plot_name}.pdf',
-                bbox_inches='tight'
-            )
+            plot_name = plot.replace(" ", "_")
+            fig.savefig(f"./{newpath}/seq_{plot_name}.pdf", bbox_inches="tight")
             # Destroy the current plot
             plt.close()
 
             # ---------------- Saving arrays into files ---------------
-            np.save(f'./{newpath}/seq_{plot_name}.npy', seq_values)
-
+            np.save(f"./{newpath}/seq_{plot_name}.npy", seq_values)
 
     # -------------------------------------------------------------------------
-    def sobol_indices(self, plot_type: str = None, save:bool=True):
+    def sobol_indices(self, plot_type: str = None, save: bool = True):
         """
         Visualizes and writes out Sobol' and Total Sobol' indices of the
         trained metamodel. One file is created for each index and output key.
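Background for the indices handled below, not part of the patch: the first-order Sobol' index S_i = Var(E[Y|X_i]) / Var(Y) measures the share of output variance explained by input X_i alone, while the total index additionally counts all interaction terms involving X_i. A brute-force, library-independent sketch for a toy additive model, with all names illustrative:

    import numpy as np

    rng = np.random.default_rng(0)
    n = 100_000
    x1 = rng.uniform(-1, 1, n)
    x2 = rng.uniform(-1, 1, n)
    y = 2.0 * x1 + 0.5 * x2  # analytic values: S_1 = 4/4.25 ~ 0.94, S_2 ~ 0.06

    def first_order_index(xi, y, n_bins=50):
        # Estimate Var(E[Y|X_i]) by binning X_i and averaging Y within each bin
        edges = np.linspace(xi.min(), xi.max(), n_bins + 1)
        idx = np.clip(np.digitize(xi, edges) - 1, 0, n_bins - 1)
        between = 0.0
        for b in range(n_bins):
            mask = idx == b
            if mask.any():
                between += mask.sum() * (y[mask].mean() - y.mean()) ** 2
        return between / (len(y) * y.var())

    print(first_order_index(x1, y), first_order_index(x2, y))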
@@ -537,10 +625,10 @@ class PostProcessing:
 
         """
         # This function currently only supports PCE/aPCE
         metamod = self.engine.MetaModel
-        if not hasattr(metamod, 'meta_model_type'):
-            raise AttributeError('Sobol indices only support PCE-type models!')
-        if metamod.meta_model_type.lower() not in ['pce', 'apce']:
-            raise AttributeError('Sobol indices only support PCE-type models!')
+        if not hasattr(metamod, "meta_model_type"):
+            raise AttributeError("Sobol indices only support PCE-type models!")
+        if metamod.meta_model_type.lower() not in ["pce", "apce"]:
+            raise AttributeError("Sobol indices only support PCE-type models!")
 
         # Extract the necessary variables
         max_order = np.max(metamod._pce_deg)
@@ -550,6 +638,7 @@ class PostProcessing:
         sobol_all, total_sobol_all = metamod.sobol, metamod.total_sobol
         self.sobol, self.totalsobol = sobol_all, total_sobol_all
 
+        # TODO: move this to the PCE class?
         # Save indices
         if save:
             for _, output in enumerate(outputs):
@@ -587,26 +676,34 @@ class PostProcessing:
 
         # Plot Sobol' indices
         self.plot_type = plot_type
         for i_order in range(1, max_order + 1):
-            par_names_i = list(combinations(self.par_names, i_order)) if (i_order !=1) else self.par_names
-            self.plot_sobol(par_names_i, outputs, sobol_type = 'sobol', i_order = i_order)
-        self.plot_sobol(self.par_names, outputs, sobol_type = 'totalsobol')
+            par_names_i = (
+                list(combinations(self.par_names, i_order))
+                if (i_order != 1)
+                else self.par_names
+            )
+            self.plot_sobol(outputs, par_names_i, sobol_type="sobol", i_order=i_order)
+        self.plot_sobol(outputs, self.par_names, sobol_type="totalsobol")
 
         return sobol_all, total_sobol_all
 
-    def plot_sobol(self, outputs, par_names, sobol_type = 'sobol', i_order = 0):
+    def plot_sobol(self, outputs, par_names, sobol_type="sobol", i_order=0):
         """
         Generate plots for each output in the given set of Sobol' indices.
-
+
         """
         sobol = None
-        if sobol_type == 'sobol':
+        if sobol_type == "sobol":
             sobol = self.sobol[i_order]
-        if sobol_type == 'totalsobol':
+        if sobol_type == "totalsobol":
             sobol = self.totalsobol
 
         fig = plt.figure()
         for _, output in enumerate(outputs):
-            x = self.x_values[output] if isinstance(self.x_values, dict) else self.x_values
+            x = (
+                self.x_values[output]
+                if isinstance(self.x_values, dict)
+                else self.x_values
+            )
             sobol_ = sobol[output][0]
 
             # Compute quantiles
             q_5 = np.quantile(sobol[output], q=0.05, axis=0)
             q_97_5 = np.quantile(sobol[output], q=0.975, axis=0)
 
             if self.plot_type == "bar":
                 ax = fig.add_axes([0, 0, 1, 1])
                 df = pd.DataFrame({self.xlabel: x, **dict(zip(par_names, sobol_))})
                 df.plot(
                     x=self.xlabel,
                     y=par_names,
                     kind="bar",
                     ax=ax,
                     rot=0,
                     colormap="Dark2",
                     yerr=q_97_5 - q_5,
                 )
-                if sobol_type =='sobol':
+                if sobol_type == "sobol":
                     ax.set_ylabel("Sobol indices, $S$")
-                elif sobol_type == 'totalsobol':
+                elif sobol_type == "totalsobol":
                     ax.set_ylabel("Total Sobol indices, $S^T$")
             else:
                 for i, sobol_idx in enumerate(sobol_):
                     plt.plot(
                         x, sobol_idx, label=par_names[i], marker="x", lw=2.5
                     )
                     plt.fill_between(x, q_5[i], q_97_5[i], alpha=0.15)
 
-                if sobol_type =='sobol':
+                if sobol_type == "sobol":
                     ax.set_ylabel("Sobol indices, $S$")
-                elif sobol_type == 'totalsobol':
+                elif sobol_type == "totalsobol":
                     ax.set_ylabel("Total Sobol indices, $S^T$")
 
             plt.xlabel(self.xlabel)
             plt.legend(loc="best", frameon=True)
 
-            if sobol_type == 'sobol':
+            if sobol_type == "sobol":
                 plt.title(f"{i_order} order Sobol' indices of {output}")
                 fig.savefig(
                     f"{self.out_dir}Sobol_indices_{i_order}_{output}.pdf",
                     bbox_inches="tight",
                 )
-            elif sobol_type == 'totalsobol':
+            elif sobol_type == "totalsobol":
                 plt.title(f"Total Sobol' indices of {output}")
                 fig.savefig(
-                    f"{self.out_dir}TotalSobol_indices_{output}.pdf", bbox_inches="tight"
+                    f"{self.out_dir}TotalSobol_indices_{output}.pdf",
+                    bbox_inches="tight",
                 )
             plt.clf()
 
     # -------------------------------------------------------------------------
-    def check_reg_quality(self, n_samples: int = 1000, samples=None, outputs: dict = None) -> None:
+    def check_reg_quality(
+        self, n_samples: int = 1000, samples=None, outputs: dict = None
+    ) -> None:
         """
         Checks the quality of the metamodel for single output models based on:
         https://towardsdatascience.com/how-do-you-check-the-quality-of-your-regression-model-in-python-fa61759ff685
 
@@ -680,7 +780,7 @@ class PostProcessing:
             Output dictionary with model outputs for all given output types in
             `engine.out_names`. The default is None.
 
-        Return
-        ------
+        Returns
+        -------
         None
 
@@ -692,7 +792,7 @@ class PostProcessing:
 
         # Evaluate the original and the surrogate model
         if outputs is None:
-            y_val = self._eval_model(samples, key_str='valid')
+            y_val = self._eval_model(samples, key_str="valid")
         else:
             y_val = outputs
         y_pce_val, _ = self.engine.eval_metamodel(samples=samples)
 
@@ -709,18 +809,24 @@ class PostProcessing:
             fig1 = plt.figure()
             for i, par in enumerate(self.engine.ExpDesign.par_names):
                 plt.title(f"{key}: Residuals vs. {par}")
{par}") - plt.scatter( - x=samples[:, i], y=residuals, color='blue', edgecolor='k') + plt.scatter(x=samples[:, i], y=residuals, color="blue", edgecolor="k") plt.grid(True) xmin, xmax = min(samples[:, i]), max(samples[:, i]) - plt.hlines(y=0, xmin=xmin*0.9, xmax=xmax*1.1, color='red', - lw=3, linestyle='--') + plt.hlines( + y=0, + xmin=xmin * 0.9, + xmax=xmax * 1.1, + color="red", + lw=3, + linestyle="--", + ) plt.xlabel(par) - plt.ylabel('Residuals') + plt.ylabel("Residuals") # save the current figure - fig1.savefig(f'./{self.out_dir}/Residuals_vs_Par_{i+1}.pdf', - bbox_inches='tight') + fig1.savefig( + f"./{self.out_dir}/Residuals_vs_Par_{i+1}.pdf", bbox_inches="tight" + ) # Destroy the current plot plt.close() @@ -728,26 +834,28 @@ class PostProcessing: # Check the assumptions of linearity and independence fig2 = plt.figure() plt.title(f"{key}: Residuals vs. fitted values") - plt.scatter(x=y_pce_val_, y=residuals, color='blue', edgecolor='k') + plt.scatter(x=y_pce_val_, y=residuals, color="blue", edgecolor="k") plt.grid(True) xmin, xmax = min(y_val_), max(y_val_) - plt.hlines(y=0, xmin=xmin*0.9, xmax=xmax*1.1, color='red', lw=3, - linestyle='--') + plt.hlines( + y=0, xmin=xmin * 0.9, xmax=xmax * 1.1, color="red", lw=3, linestyle="--" + ) plt.xlabel(key) - plt.ylabel('Residuals') + plt.ylabel("Residuals") # save the current figure - fig2.savefig(f'./{self.out_dir}/Fitted_vs_Residuals.pdf', - bbox_inches='tight') + fig2.savefig( + f"./{self.out_dir}/Fitted_vs_Residuals.pdf", bbox_inches="tight" + ) # Destroy the current plot plt.close() # ------ Histogram of normalized residuals ------ fig3 = plt.figure() - resid_pearson = residuals / (max(residuals)-min(residuals)) - plt.hist(resid_pearson, bins=20, edgecolor='k') - plt.ylabel('Count') - plt.xlabel('Normalized residuals') + resid_pearson = residuals / (max(residuals) - min(residuals)) + plt.hist(resid_pearson, bins=20, edgecolor="k") + plt.ylabel("Count") + plt.xlabel("Normalized residuals") plt.title(f"{key}: Histogram of normalized residuals") # Normality (Shapiro-Wilk) test of the residuals @@ -757,14 +865,16 @@ class PostProcessing: ann_text = "The residuals seem to come from a Gaussian Process." else: ann_text = "The normality assumption may not hold." - at = AnchoredText(ann_text, prop={'size':30}, frameon=True, - loc='upper left') + at = AnchoredText( + ann_text, prop={"size": 30}, frameon=True, loc="upper left" + ) at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2") ax.add_artist(at) # save the current figure - fig3.savefig(f'./{self.out_dir}/Hist_NormResiduals.pdf', - bbox_inches='tight') + fig3.savefig( + f"./{self.out_dir}/Hist_NormResiduals.pdf", bbox_inches="tight" + ) # Destroy the current plot plt.close() @@ -779,15 +889,16 @@ class PostProcessing: plt.grid(True) # save the current figure - plt.savefig(f'./{self.out_dir}/QQPlot_NormResiduals.pdf', - bbox_inches='tight') + plt.savefig( + f"./{self.out_dir}/QQPlot_NormResiduals.pdf", bbox_inches="tight" + ) # Destroy the current plot plt.close() # ------------------------------------------------------------------------- def plot_metamodel_3d(self, n_samples=10): """ - Visualize the results of a PCE MetaModel as a 3D surface over two input + Visualize the results of a PCE MetaModel as a 3D surface over two input parameters. 
 
         Parameters
@@ -807,32 +918,42 @@ class PostProcessing:
 
         """
         if self.engine.ExpDesign.ndim != 2:
             raise AttributeError(
-                'This function is only applicable if the MetaModel input dimension is 2.')
+                "This function is only applicable if the MetaModel input dimension is 2."
+            )
         samples = self.engine.ExpDesign.generate_samples(n_samples)
         samples = np.sort(np.sort(samples, axis=1), axis=0)
         mean, _ = self.engine.eval_metamodel(samples=samples)
 
         if self.engine.emulator:
-            title = 'MetaModel'
+            title = "MetaModel"
         else:
-            title = 'Model'
+            title = "Model"
         x, y = np.meshgrid(samples[:, 0], samples[:, 1])
         for name in self.engine.out_names:
             for t in range(mean[name].shape[1]):
                 fig = plt.figure()
-                ax = plt.axes(projection='3d')
-                ax.plot_surface(x, y, np.atleast_2d(mean[name][:, t]), rstride=1, cstride=1,
-                                cmap='viridis', edgecolor='none')
+                ax = plt.axes(projection="3d")
+                ax.plot_surface(
+                    x,
+                    y,
+                    np.atleast_2d(mean[name][:, t]),
+                    rstride=1,
+                    cstride=1,
+                    cmap="viridis",
+                    edgecolor="none",
+                )
                 ax.set_title(title)
-                ax.set_xlabel('$x_1$')
-                ax.set_ylabel('$x_2$')
-                ax.set_zlabel('$f(x_1,x_2)$')
+                ax.set_xlabel("$x_1$")
+                ax.set_ylabel("$x_2$")
+                ax.set_zlabel("$f(x_1,x_2)$")
                 plt.grid()
 
                 # save the figure to file
-                fig.savefig(f'./{self.out_dir}/3DPlot_{title}_{name}{t}.pdf',
-                            bbox_inches='tight')
+                fig.savefig(
+                    f"./{self.out_dir}/3DPlot_{title}_{name}{t}.pdf",
+                    bbox_inches="tight",
+                )
                 plt.close(fig)
 
     # -------------------------------------------------------------------------
@@ -853,12 +974,12 @@ class PostProcessing:
 
         """
         samples = self.engine.ExpDesign.generate_samples(
-            n_samples,
-            sampling_method='random')
+            n_samples, sampling_method="random"
+        )
 
         return samples
 
     # -------------------------------------------------------------------------
-    def _eval_model(self, samples, key_str='Valid'):
+    def _eval_model(self, samples, key_str="Valid"):
         """
         Evaluates Forward Model on the given samples.
 
@@ -875,9 +996,8 @@ class PostProcessing:
            Dictionary of results.
""" - #samples = self._get_sample() - model_outs, _ = self.engine.Model.run_model_parallel( - samples, key_str=key_str) + # samples = self._get_sample() + model_outs, _ = self.engine.Model.run_model_parallel(samples, key_str=key_str) return model_outs @@ -894,12 +1014,10 @@ class PostProcessing: """ # This function currently only supports PCE/aPCE metamod = self.engine.MetaModel - if not hasattr(metamod, 'meta_model_type'): - raise AttributeError( - 'This evaluation only support PCE-type models!') - if metamod.meta_model_type.lower() not in ['pce', 'apce']: - raise AttributeError( - 'This evaluation only support PCE-type models!') + if not hasattr(metamod, "meta_model_type"): + raise AttributeError("This evaluation only support PCE-type models!") + if metamod.meta_model_type.lower() not in ["pce", "apce"]: + raise AttributeError("This evaluation only support PCE-type models!") # get the samples y_pce_val = self.pce_out_mean @@ -915,34 +1033,39 @@ class PostProcessing: x_new = np.linspace(np.min(y_pce_val[key]), np.max(y_val[key]), 100) y_predicted = regression_model.predict(x_new[:, np.newaxis]) - plt.scatter(y_pce_val[key], y_val[key], color='gold', linewidth=2) - plt.plot(x_new, y_predicted, color='k') + plt.scatter(y_pce_val[key], y_val[key], color="gold", linewidth=2) + plt.plot(x_new, y_predicted, color="k") # Calculate the adjusted R_squared and RMSE # the total number of explanatory variables in the model # (not including the constant term) + # TODO: this should work without PCE-specific values length_list = [] - for key, value in metamod._coeffs_dict['b_1'][key].items(): + for key, value in metamod._coeffs_dict["b_1"][key].items(): length_list.append(len(value)) n_predictors = min(length_list) n_samples = y_pce_val[key].shape[0] r_2 = r2_score(y_pce_val[key], y_val[key]) - adj_r2 = 1 - (1 - r_2) * (n_samples - 1) / \ - (n_samples - n_predictors - 1) + adj_r2 = 1 - (1 - r_2) * (n_samples - 1) / (n_samples - n_predictors - 1) rmse = mean_squared_error(y_pce_val[key], y_val[key], squared=False) - plt.annotate(f'RMSE = {rmse:.3f}\n Adjusted $R^2$ = {adj_r2:.3f}', - xy=(0.05, 0.85), xycoords='axes fraction') + plt.annotate( + f"RMSE = {rmse:.3f}\n Adjusted $R^2$ = {adj_r2:.3f}", + xy=(0.05, 0.85), + xycoords="axes fraction", + ) plt.ylabel("Original Model") plt.xlabel("PCE Model") plt.grid() # save the current figure - plot_name = key.replace(' ', '_') - fig.savefig(f'./{self.out_dir}/Model_vs_PCEModel_{plot_name}.pdf', - bbox_inches='tight') + plot_name = key.replace(" ", "_") + fig.savefig( + f"./{self.out_dir}/Model_vs_PCEModel_{plot_name}.pdf", + bbox_inches="tight", + ) # Destroy the current plot plt.close() @@ -956,7 +1079,7 @@ class PostProcessing: Parameters ---------- x_values : list or array - List of x values. + List of x values. x_axis : str, optional Label of the x axis. The default is "x [m]". 
@@ -974,16 +1097,14 @@ class PostProcessing:
 
         """
         # This function currently only supports PCE/aPCE
-        if not hasattr(self.engine.MetaModel, 'meta_model_type'):
-            raise AttributeError(
-                'This evaluation only support PCE-type models!')
-        if self.engine.MetaModel.meta_model_type.lower() not in ['pce', 'apce']:
-            raise AttributeError(
-                'This evaluation only support PCE-type models!')
+        if not hasattr(self.engine.MetaModel, "meta_model_type"):
+            raise AttributeError("This evaluation only supports PCE-type models!")
+        if self.engine.MetaModel.meta_model_type.lower() not in ["pce", "apce"]:
+            raise AttributeError("This evaluation only supports PCE-type models!")
 
         # List of markers and colors
-        color = cycle((['b', 'g', 'r', 'y', 'k']))
-        marker = cycle(('x', 'd', '+', 'o', '*'))
+        color = cycle((["b", "g", "r", "y", "k"]))
+        marker = cycle(("x", "d", "+", "o", "*"))
 
         # Plot the model vs PCE model
         fig = plt.figure()
@@ -993,28 +1114,45 @@ class PostProcessing:
             y_val = self.model_out_dict[key]
 
             for idx in range(y_val.shape[0]):
-                plt.plot(self.x_values, y_val[idx], color=next(color), marker=next(marker),
-                         label='$Y_{%s}^M$' % (idx+1))
+                plt.plot(
+                    self.x_values,
+                    y_val[idx],
+                    color=next(color),
+                    marker=next(marker),
+                    label="$Y_{%s}^M$" % (idx + 1),
+                )
 
-                plt.plot(self.x_values, y_pce_val[idx], color=next(color), marker=next(marker),
-                         linestyle='--',
-                         label='$Y_{%s}^{PCE}$' % (idx+1))
-                plt.fill_between(self.x_values, y_pce_val[idx]-1.96*y_pce_val_std[idx],
-                                 y_pce_val[idx]+1.96*y_pce_val_std[idx],
-                                 color=next(color), alpha=0.15)
+                plt.plot(
+                    self.x_values,
+                    y_pce_val[idx],
+                    color=next(color),
+                    marker=next(marker),
+                    linestyle="--",
+                    label="$Y_{%s}^{PCE}$" % (idx + 1),
+                )
+                plt.fill_between(
+                    self.x_values,
+                    y_pce_val[idx] - 1.96 * y_pce_val_std[idx],
+                    y_pce_val[idx] + 1.96 * y_pce_val_std[idx],
+                    color=next(color),
+                    alpha=0.15,
+                )
 
             # Calculate the RMSE
             rmse = mean_squared_error(y_pce_val, y_val, squared=False)
-            r_2 = r2_score(y_pce_val[idx].reshape(-1, 1),
-                           y_val[idx].reshape(-1, 1))
+            r_2 = r2_score(y_pce_val[idx].reshape(-1, 1), y_val[idx].reshape(-1, 1))
 
-            plt.annotate(f'RMSE = {rmse:.3f}\n $R^2$ = {r_2:.3f}',
-                         xy=(0.85, 0.1), xycoords='axes fraction')
+            plt.annotate(
+                f"RMSE = {rmse:.3f}\n $R^2$ = {r_2:.3f}",
+                xy=(0.85, 0.1),
+                xycoords="axes fraction",
+            )
 
             plt.ylabel(key)
             plt.xlabel(self.xlabel)
-            plt.legend(loc='best')
+            plt.legend(loc="best")
             plt.grid()
 
-            key = key.replace(' ', '_')
-            fig.savefig(f'./{self.out_dir}/Model_vs_PCEModel_{key}.pdf',
-                        bbox_inches='tight')
+            key = key.replace(" ", "_")
+            fig.savefig(
+                f"./{self.out_dir}/Model_vs_PCEModel_{key}.pdf", bbox_inches="tight"
+            )
             plt.close()
-- 
GitLab
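For anyone trying out the reformatted module: a typical call sequence on an already-trained engine looks like the sketch below. Engine construction and training are untouched by this patch and are omitted; the wrapper function and the out_dir value are illustrative, not part of bayesvalidrox.

    from bayesvalidrox.post_processing.post_processing import PostProcessing

    def run_diagnostics(engine):
        # `engine` is assumed to be a fully trained bayesvalidrox Engine
        post = PostProcessing(engine, name="calib", out_dir="example")
        post.plot_moments(plot_type="bar")     # Mean_Std_PCE_<key>.pdf per output
        post.check_accuracy(n_samples=200)     # prints RMSE / validation-error table
        post.sobol_indices(plot_type="bar")    # Sobol' index plots (PCE/aPCE only)
        post.check_reg_quality(n_samples=200)  # residual diagnostics, single-output models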