From 260b4d80e0c6bdd899027aba486eb6d0b9adb20b Mon Sep 17 00:00:00 2001
From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com>
Date: Thu, 27 Jun 2024 17:46:39 +0200
Subject: [PATCH] Fix the justifiability (BMC) analysis in BayesModelComparison
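
Move the empty-list default for perturbed_data in BayesInference so it
is applied to self.perturbed_data after the attributes have been set.
In BayesModelComparison, rerun the per-model Bayesian inference instead
of reusing a cached just_bayes_dict, compute the BME dictionary from
just_bayes_dict rather than bayes_dict, split the justifiability model
weights by the width of just_model_weights itself, and drop the unused
model_weights_dict parameter from plot_just_analysis. Callers should
now invoke plot_just_analysis() without arguments.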

---
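Reviewer note on the np.split change: splitting by the width of the
array that is actually being split keeps the section count valid. A
minimal sketch with illustrative names and shapes (not taken from the
package):

    import numpy as np

    n_meas = 4                                # illustrative value
    # (n_models, n_runs * n_meas) -- illustrative shape
    just_model_weights = np.random.rand(2, 12)
    # np.split needs a section count that divides the array's own
    # width; deriving it from just_model_weights itself guarantees that
    sections = just_model_weights.shape[1] // n_meas
    list_ModelWeights = np.split(just_model_weights, sections, axis=1)
    assert all(w.shape == (2, n_meas) for w in list_ModelWeights)

Splitting by self.model_weights.shape[1], as before, can raise a
ValueError whenever that width differs from the justifiability weights.
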
 .../bayes_inference/bayes_inference.py        |  5 +++--
 .../bayes_inference/bayes_model_comparison.py | 56 ++++++++++---------
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/src/bayesvalidrox/bayes_inference/bayes_inference.py b/src/bayesvalidrox/bayes_inference/bayes_inference.py
index ade1662c9..e6748a878 100644
--- a/src/bayesvalidrox/bayes_inference/bayes_inference.py
+++ b/src/bayesvalidrox/bayes_inference/bayes_inference.py
@@ -226,8 +226,6 @@ class BayesInference:
         self.log_BME = None
         self.KLD = None
         self.__mean_pce_prior_pred = None
-        if perturbed_data is None:
-            perturbed_data = []
         self.engine = engine
         self.Discrepancy = discrepancy
         self.emulator = emulator
@@ -268,6 +266,9 @@ class BayesInference:
         self.__model_prior_pred = None
         self.MCMC_Obj = None
 
+        # Default to an empty list when no perturbed data is given
+        if self.perturbed_data is None:
+            self.perturbed_data = []
         # System settings
         if os.name == 'nt':
             print('')
diff --git a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
index 3cc0918e4..dc66e1cf5 100644
--- a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
+++ b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
@@ -254,37 +254,39 @@ class BayesModelComparison:
         
         # Generate data
         # TODO: generate the dataset only if it does not exist yet
+        # Note: the shape of the generated dataset is already correct
         self.just_data = self.generate_dataset(
             model_dict, True, n_bootstrap=self.n_bootstrap)
 
-        # Run inference for each model if this is not available
+        # Run the Bayesian inference for each model
-        if self.just_bayes_dict is None:
-            self.just_bayes_dict = {}
-            for model in model_dict.keys():
-                print("-"*20)
-                print("Bayesian inference of {}.\n".format(model))
-                BayesOpts = BayesInference(model_dict[model])
-                    
-                # Set BayesInference options
-                for key, value in opts_dict.items():
-                    if key in BayesOpts.__dict__.keys():
-                        if key == "Discrepancy" and isinstance(value, dict):
-                            setattr(BayesOpts, key, value[model])
-                        else:
-                            setattr(BayesOpts, key, value)
-    
-                # Pass justifiability data as perturbed data
-                BayesOpts.bmc = True
-                BayesOpts.emulator= self.emulator
-                BayesOpts.just_analysis = True
-                BayesOpts.perturbed_data = self.just_data
-    
-                self.just_bayes_dict[model] = BayesOpts.create_inference()
-                print("-"*20)
+        # Rerun unconditionally; a cached just_bayes_dict may be stale
+        self.just_bayes_dict = {}
+        for model in model_dict.keys():
+            print("-"*20)
+            print("Bayesian inference of {}.\n".format(model))
+            BayesOpts = BayesInference(model_dict[model])
+                
+            # Set BayesInference options
+            for key, value in opts_dict.items():
+                if key in BayesOpts.__dict__.keys():
+                    if key == "Discrepancy" and isinstance(value, dict):
+                        setattr(BayesOpts, key, value[model])
+                    else:
+                        setattr(BayesOpts, key, value)
+
+            # Pass justifiability data as perturbed data
+            BayesOpts.bmc = True
+            BayesOpts.emulator = self.emulator
+            BayesOpts.just_analysis = True
+            BayesOpts.perturbed_data = self.just_data
+
+            self.just_bayes_dict[model] = BayesOpts.create_inference()
+            print("-"*20)
 
         # Compute model weights
+        # Note: the shapes here are now consistent as well
         self.BME_dict = dict()
-        for modelName, bayesObj in self.bayes_dict.items():
+        for modelName, bayesObj in self.just_bayes_dict.items():
             self.BME_dict[modelName] = np.exp(bayesObj.log_BME, dtype=self.dtype)
 
         # BME correction in BayesInference class
@@ -293,7 +295,7 @@ class BayesModelComparison:
 
         # Split the model weights and save in a dict
         list_ModelWeights = np.split(
-            just_model_weights, self.model_weights.shape[1]/self.n_meas, axis=1)
+            just_model_weights, just_model_weights.shape[1]/self.n_meas, axis=1)
         self.just_model_weights_dict = {key: weights for key, weights in
                               zip(model_names, list_ModelWeights)}
         
@@ -458,7 +460,7 @@ class BayesModelComparison:
         return model_weights
 
     # -------------------------------------------------------------------------
-    def plot_just_analysis(self, model_weights_dict):
+    def plot_just_analysis(self):
         """
         Visualizes the confusion matrix and the model weights for the
         justifiability analysis.
-- 
GitLab