Skip to content
Snippets Groups Projects
Commit 260b4d80 authored by kohlhaasrebecca's avatar kohlhaasrebecca
Browse files

Fix for the BMC issue

parent b5e57edd
No related branches found
No related tags found
3 merge requests!29Preparation for release 1.1.0: fixes and test for pages,!28Merge to circumvent issues,!26Fix for the BMC issue
Pipeline #45579 passed
...@@ -226,8 +226,6 @@ class BayesInference: ...@@ -226,8 +226,6 @@ class BayesInference:
self.log_BME = None self.log_BME = None
self.KLD = None self.KLD = None
self.__mean_pce_prior_pred = None self.__mean_pce_prior_pred = None
if perturbed_data is None:
perturbed_data = []
self.engine = engine self.engine = engine
self.Discrepancy = discrepancy self.Discrepancy = discrepancy
self.emulator = emulator self.emulator = emulator
...@@ -268,6 +266,9 @@ class BayesInference: ...@@ -268,6 +266,9 @@ class BayesInference:
self.__model_prior_pred = None self.__model_prior_pred = None
self.MCMC_Obj = None self.MCMC_Obj = None
# Empty perturbed data init
if self.perturbed_data is None:
self.perturbed_data = []
# System settings # System settings
if os.name == 'nt': if os.name == 'nt':
print('') print('')
...@@ -315,6 +316,7 @@ class BayesInference: ...@@ -315,6 +316,7 @@ class BayesInference:
# Convert measured_data to a data frame # Convert measured_data to a data frame
if not isinstance(self.measured_data, pd.DataFrame): if not isinstance(self.measured_data, pd.DataFrame):
self.measured_data = pd.DataFrame(self.measured_data) self.measured_data = pd.DataFrame(self.measured_data)
# Extract the total number of measurement points # Extract the total number of measurement points
if self.name.lower() == 'calib': if self.name.lower() == 'calib':
......
...@@ -254,37 +254,39 @@ class BayesModelComparison: ...@@ -254,37 +254,39 @@ class BayesModelComparison:
# Generate data # Generate data
# TODO: generate the datset only if it does not exist yet # TODO: generate the datset only if it does not exist yet
# TODO: shape of this is still ok
self.just_data = self.generate_dataset( self.just_data = self.generate_dataset(
model_dict, True, n_bootstrap=self.n_bootstrap) model_dict, True, n_bootstrap=self.n_bootstrap)
# Run inference for each model if this is not available # Run inference for each model if this is not available
if self.just_bayes_dict is None: #if self.just_bayes_dict is None:
self.just_bayes_dict = {} self.just_bayes_dict = {}
for model in model_dict.keys(): for model in model_dict.keys():
print("-"*20) print("-"*20)
print("Bayesian inference of {}.\n".format(model)) print("Bayesian inference of {}.\n".format(model))
BayesOpts = BayesInference(model_dict[model]) BayesOpts = BayesInference(model_dict[model])
# Set BayesInference options # Set BayesInference options
for key, value in opts_dict.items(): for key, value in opts_dict.items():
if key in BayesOpts.__dict__.keys(): if key in BayesOpts.__dict__.keys():
if key == "Discrepancy" and isinstance(value, dict): if key == "Discrepancy" and isinstance(value, dict):
setattr(BayesOpts, key, value[model]) setattr(BayesOpts, key, value[model])
else: else:
setattr(BayesOpts, key, value) setattr(BayesOpts, key, value)
# Pass justifiability data as perturbed data # Pass justifiability data as perturbed data
BayesOpts.bmc = True BayesOpts.bmc = True
BayesOpts.emulator= self.emulator BayesOpts.emulator= self.emulator
BayesOpts.just_analysis = True BayesOpts.just_analysis = True
BayesOpts.perturbed_data = self.just_data BayesOpts.perturbed_data = self.just_data
self.just_bayes_dict[model] = BayesOpts.create_inference() self.just_bayes_dict[model] = BayesOpts.create_inference()
print("-"*20) print("-"*20)
# Compute model weights # Compute model weights
# TODO: shape of this now ok as well
self.BME_dict = dict() self.BME_dict = dict()
for modelName, bayesObj in self.bayes_dict.items(): for modelName, bayesObj in self.just_bayes_dict.items():
self.BME_dict[modelName] = np.exp(bayesObj.log_BME, dtype=self.dtype) self.BME_dict[modelName] = np.exp(bayesObj.log_BME, dtype=self.dtype)
# BME correction in BayesInference class # BME correction in BayesInference class
...@@ -293,7 +295,7 @@ class BayesModelComparison: ...@@ -293,7 +295,7 @@ class BayesModelComparison:
# Split the model weights and save in a dict # Split the model weights and save in a dict
list_ModelWeights = np.split( list_ModelWeights = np.split(
just_model_weights, self.model_weights.shape[1]/self.n_meas, axis=1) just_model_weights, just_model_weights.shape[1]/self.n_meas, axis=1)
self.just_model_weights_dict = {key: weights for key, weights in self.just_model_weights_dict = {key: weights for key, weights in
zip(model_names, list_ModelWeights)} zip(model_names, list_ModelWeights)}
...@@ -458,7 +460,7 @@ class BayesModelComparison: ...@@ -458,7 +460,7 @@ class BayesModelComparison:
return model_weights return model_weights
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def plot_just_analysis(self, model_weights_dict): def plot_just_analysis(self):
""" """
Visualizes the confusion matrix and the model wights for the Visualizes the confusion matrix and the model wights for the
justifiability analysis. justifiability analysis.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment