From 7b6a822430b39a03b0bcbfe7712ad7643211ea36 Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Mon, 9 Dec 2024 17:53:13 +0100
Subject: [PATCH] More linting and reformatting with black

---
 .../post_processing/post_processing.py        | 36 +++++++++----------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py
index 3776743e4..4d311fe4e 100644
--- a/src/bayesvalidrox/post_processing/post_processing.py
+++ b/src/bayesvalidrox/post_processing/post_processing.py
@@ -36,7 +36,7 @@ class PostProcessing:
         Name of the PostProcessing object to be used for saving the generated files.
         The default is 'calib'.
     out_dir : string
-        Output directory in which the PostProcessing results are placed. 
+        Output directory in which the PostProcessing results are placed.
         The results are contained in a subfolder '/Outputs_PostProcessing_name'
         The default is ''.
     out_format : string
@@ -80,7 +80,7 @@ class PostProcessing:
         self.stds = None
 
     # -------------------------------------------------------------------------
-    def plot_moments(self, plot_type: str = 'line'):
+    def plot_moments(self, plot_type: str = "line"):
         """
         Plots the moments in a user defined output format (standard is pdf) in the directory
         `Outputs_PostProcessing`.
@@ -104,8 +104,8 @@ class PostProcessing:
             Standard deviation of the model outputs.
 
         """
-        if plot_type not in ['bar', 'line']:
-            raise AttributeError('The wanted plot-type is not supported.')
+        if plot_type not in ["bar", "line"]:
+            raise AttributeError("The wanted plot-type is not supported.")
         bar_plot = bool(plot_type == "bar")
         meta_model_type = self.engine.MetaModel.meta_model_type
 
@@ -304,10 +304,7 @@ class PostProcessing:
         # Loop over the keys and compute RMSE error.
         for key in self.engine.out_names:
             # Root mena square
-            self.rmse[key] = root_mean_squared_error(
-                outputs[key],
-                metamod_outputs[key]
-            )
+            self.rmse[key] = root_mean_squared_error(outputs[key], metamod_outputs[key])
             # Validation error
             self.valid_error[key] = (self.rmse[key] ** 2) / np.var(
                 outputs[key], ddof=1, axis=0
@@ -377,7 +374,7 @@ class PostProcessing:
 
         # Plot the evolution of the diagnostic criteria of the
         # Sequential Experimental Design.
-        
+
         for plotidx, plot in enumerate(plot_list):
             fig, ax = plt.subplots()
             seq_dict = seq_list[plotidx]
@@ -627,7 +624,7 @@ class PostProcessing:
                 np.save(f"./{newpath}/seq_{plot_name}.npy", seq_values)
 
     # -------------------------------------------------------------------------
-    def sobol_indices(self, plot_type: str = 'line', save: bool = True):
+    def sobol_indices(self, plot_type: str = "line", save: bool = True):
         """
         Visualizes and writes out Sobol' and Total Sobol' indices of the trained metamodel.
         One file is created for each index and output key.
@@ -663,9 +660,9 @@ class PostProcessing:
             raise AttributeError("Sobol indices only support PCE-type models!")
         if metamod.meta_model_type.lower() not in ["pce", "apce"]:
             raise AttributeError("Sobol indices only support PCE-type models!")
-        
-        if plot_type not in ['line', 'bar']:
-            raise AttributeError('The wanted plot type is not supported.')
+
+        if plot_type not in ["line", "bar"]:
+            raise AttributeError("The wanted plot type is not supported.")
 
         # Extract the necessary variables
         max_order = np.max(metamod._pce_deg)
@@ -846,7 +843,7 @@ class PostProcessing:
             )
         else:
             n_samples = samples.shape[0]
-        
+
         # Evaluate the original and the surrogate model
         y_val = outputs
         if y_val is None:
@@ -995,6 +992,7 @@ class PostProcessing:
                     bbox_inches="tight",
                 )
                 plt.close(fig)
+
     # -------------------------------------------------------------------------
     def _plot_validation_multi(self, out_mean, out_std, model_out):
         """
@@ -1018,15 +1016,15 @@ class PostProcessing:
         # List of markers and colors
         color = cycle((["b", "g", "r", "y", "k"]))
         marker = cycle(("x", "d", "+", "o", "*"))
-        metamod_name = self.engine.MetaModel.meta_model_type.lower() 
-        
+        metamod_name = self.engine.MetaModel.meta_model_type.lower()
+
         # Plot the model vs PCE model
         fig = plt.figure()
         for _, key in enumerate(self.engine.out_names):
             y_val = out_mean[key]
             y_val_std = out_std[key]
             y_val = model_out[key]
-            
+
             for idx in range(y_val.shape[0]):
                 plt.plot(
                     self.x_values,
@@ -1041,7 +1039,7 @@ class PostProcessing:
                     color=next(color),
                     marker=next(marker),
                     linestyle="--",
-                    label="$Y_{{{}}}^{{{}}}$".format(idx + 1, metamod_name)
+                    label="$Y_{{{}}}^{{{}}}$".format(idx + 1, metamod_name),
                 )
                 plt.fill_between(
                     self.x_values,
@@ -1067,6 +1065,6 @@ class PostProcessing:
             key = key.replace(" ", "_")
             fig.savefig(
                 f"./{self.out_dir}/Model_vs_{metamod_name}Model_{key}.{self.out_format}",
-                bbox_inches="tight"
+                bbox_inches="tight",
             )
             plt.close()
-- 
GitLab