From 260b4d80e0c6bdd899027aba486eb6d0b9adb20b Mon Sep 17 00:00:00 2001 From: kohlhaasrebecca <rebecca.kohlhaas@outlook.com> Date: Thu, 27 Jun 2024 17:46:39 +0200 Subject: [PATCH] Fix for the BMC issue --- .../bayes_inference/bayes_inference.py | 6 ++- .../bayes_inference/bayes_model_comparison.py | 54 ++++++++++--------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/bayesvalidrox/bayes_inference/bayes_inference.py b/src/bayesvalidrox/bayes_inference/bayes_inference.py index ade1662c9..e6748a878 100644 --- a/src/bayesvalidrox/bayes_inference/bayes_inference.py +++ b/src/bayesvalidrox/bayes_inference/bayes_inference.py @@ -226,8 +226,6 @@ class BayesInference: self.log_BME = None self.KLD = None self.__mean_pce_prior_pred = None - if perturbed_data is None: - perturbed_data = [] self.engine = engine self.Discrepancy = discrepancy self.emulator = emulator @@ -268,6 +266,9 @@ class BayesInference: self.__model_prior_pred = None self.MCMC_Obj = None + # Empty perturbed data init + if self.perturbed_data is None: + self.perturbed_data = [] # System settings if os.name == 'nt': print('') @@ -315,6 +316,7 @@ class BayesInference: # Convert measured_data to a data frame if not isinstance(self.measured_data, pd.DataFrame): self.measured_data = pd.DataFrame(self.measured_data) + # Extract the total number of measurement points if self.name.lower() == 'calib': diff --git a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py index 3cc0918e4..dc66e1cf5 100644 --- a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py +++ b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py @@ -254,37 +254,39 @@ class BayesModelComparison: # Generate data # TODO: generate the datset only if it does not exist yet + # TODO: shape of this is still ok self.just_data = self.generate_dataset( model_dict, True, n_bootstrap=self.n_bootstrap) # Run inference for each model if this is not available - if self.just_bayes_dict is None: - self.just_bayes_dict = {} - for model in model_dict.keys(): - print("-"*20) - print("Bayesian inference of {}.\n".format(model)) - BayesOpts = BayesInference(model_dict[model]) - - # Set BayesInference options - for key, value in opts_dict.items(): - if key in BayesOpts.__dict__.keys(): - if key == "Discrepancy" and isinstance(value, dict): - setattr(BayesOpts, key, value[model]) - else: - setattr(BayesOpts, key, value) - - # Pass justifiability data as perturbed data - BayesOpts.bmc = True - BayesOpts.emulator= self.emulator - BayesOpts.just_analysis = True - BayesOpts.perturbed_data = self.just_data - - self.just_bayes_dict[model] = BayesOpts.create_inference() - print("-"*20) + #if self.just_bayes_dict is None: + self.just_bayes_dict = {} + for model in model_dict.keys(): + print("-"*20) + print("Bayesian inference of {}.\n".format(model)) + BayesOpts = BayesInference(model_dict[model]) + + # Set BayesInference options + for key, value in opts_dict.items(): + if key in BayesOpts.__dict__.keys(): + if key == "Discrepancy" and isinstance(value, dict): + setattr(BayesOpts, key, value[model]) + else: + setattr(BayesOpts, key, value) + + # Pass justifiability data as perturbed data + BayesOpts.bmc = True + BayesOpts.emulator= self.emulator + BayesOpts.just_analysis = True + BayesOpts.perturbed_data = self.just_data + + self.just_bayes_dict[model] = BayesOpts.create_inference() + print("-"*20) # Compute model weights + # TODO: shape of this now ok as well self.BME_dict = dict() - for modelName, bayesObj in self.bayes_dict.items(): + for modelName, bayesObj in self.just_bayes_dict.items(): self.BME_dict[modelName] = np.exp(bayesObj.log_BME, dtype=self.dtype) # BME correction in BayesInference class @@ -293,7 +295,7 @@ class BayesModelComparison: # Split the model weights and save in a dict list_ModelWeights = np.split( - just_model_weights, self.model_weights.shape[1]/self.n_meas, axis=1) + just_model_weights, just_model_weights.shape[1]/self.n_meas, axis=1) self.just_model_weights_dict = {key: weights for key, weights in zip(model_names, list_ModelWeights)} @@ -458,7 +460,7 @@ class BayesModelComparison: return model_weights # ------------------------------------------------------------------------- - def plot_just_analysis(self, model_weights_dict): + def plot_just_analysis(self): """ Visualizes the confusion matrix and the model wights for the justifiability analysis. -- GitLab