From 7e4aecd095af2e988d99c2614a7a95314a5574fe Mon Sep 17 00:00:00 2001 From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de> Date: Fri, 30 Sep 2022 15:50:50 +0200 Subject: [PATCH] fix bugs related to n_bootstrap. --- .../bayes_inference/bayes_model_comparison.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py index a12e493dd..693afc82d 100644 --- a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py +++ b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py @@ -38,7 +38,7 @@ class BayesModelComparison: ) output = { - 'Bayes factor': bayes_dict_bf, + 'Bayes objects BF': bayes_dict_bf, 'Model weights BF': model_weights_dict_bf } @@ -48,7 +48,7 @@ class BayesModelComparison: model_dict, opts_dict, justifiability=True ) - output['Justifiability analysis'] = bayes_dict_ja + output['Bayes objects JA'] = bayes_dict_ja output['Model weights JA'] = model_weights_dict_ja return output @@ -70,12 +70,11 @@ class BayesModelComparison: # ----- Generate data ----- # Find n_bootstrap - if not justifiability: + if self.perturbed_data is None: n_bootstarp = self.n_bootstarp else: - # find the smallest n_samples - n_bootstarp = min([len(MetaModel.ExpDesign.X) for MetaModel - in modelDict.values()]) + n_bootstarp = self.perturbed_data.shape[0] + # Create dataset justData = self.generate_dataset( modelDict, justifiability, n_bootstarp=n_bootstarp) @@ -106,7 +105,7 @@ class BayesModelComparison: # Compute model weights BME_Dict = dict() for modelName, bayesObj in bayesDict.items(): - BME_Dict[modelName] = np.exp(bayesObj.log_BME) + BME_Dict[modelName] = np.exp(bayesObj.log_BME, dtype=np.float128) # BME correction in BayesInference class model_weights = self.cal_modelWeight( @@ -169,8 +168,7 @@ class BayesModelComparison: # Use surrogate runs for data-generating process for key, metaModel in modelDict.items(): model_data = np.array( - [runs[key][out][i] for out in out_names] - ).reshape(y_data.shape) + [runs[key][out][i] for out in out_names]).reshape(y_data.shape) justData = np.vstack(( justData, np.tril(np.repeat(model_data, model_data.shape[1], axis=0)) -- GitLab