From 7a0c5e1edfd72d0e2ba11cbe27dc7f332bfbac26 Mon Sep 17 00:00:00 2001 From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de> Date: Tue, 27 Sep 2022 18:45:52 +0200 Subject: [PATCH] [BayesInference] add option to pass a dictionary of Discrepancy objects. --- .../bayes_inference/bayes_inference.py | 30 +++++++++---------- .../bayes_inference/bayes_model_comparison.py | 17 +++++++++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/bayesvalidrox/bayes_inference/bayes_inference.py b/src/bayesvalidrox/bayes_inference/bayes_inference.py index 061283d3c..0ecbd33b7 100644 --- a/src/bayesvalidrox/bayes_inference/bayes_inference.py +++ b/src/bayesvalidrox/bayes_inference/bayes_inference.py @@ -735,9 +735,9 @@ class BayesInference: covMatrix += sigma_2 * np.eye(n_out) # Select the data points to compare - if self.selected_indices is not None: - indices = self.selected_indices - else: + try: + indices = self.selected_indices[output] + except: indices = list(range(n_out)) covMatrix = np.diag(covMatrix[indices, indices]) y_bar = y_bar[indices] @@ -1001,11 +1001,11 @@ class BayesInference: covMatrix = np.diag(tot_sigma2s) # Select the data points to compare - if self.selected_indices is not None: - indices = self.selected_indices - covMatrix = np.diag(covMatrix[indices, indices]) - else: + try: + indices = self.selected_indices[out] + except: indices = list(range(nout)) + covMatrix = np.diag(covMatrix[indices, indices]) # If sigma2 is not given, use given total_sigma2s if sigma2 is None: @@ -1045,11 +1045,11 @@ class BayesInference: # covMatrix = np.diag(sigma2 * total_sigma2s) # Select the data points to compare - if self.selected_indices is not None: + try: indices = self.selected_indices[out] - covMatrix = np.diag(covMatrix[indices, indices]) - else: + except: indices = list(range(nout)) + covMatrix = np.diag(covMatrix[indices, indices]) # Compute loglikelihood logliks[s_idx] = self._logpdf( @@ -1163,12 +1163,12 @@ class BayesInference: covMatrix = np.eye(len(y_m)) * 1/(2*np.pi) # Select the data points to compare - if self.selected_indices is not None: - indices = self.selected_indices - covMatrix = np.diag(covMatrix[indices, indices]) - covMatrix_data = np.diag(covMatrix_data[indices, indices]) - else: + try: + indices = self.selected_indices[out] + except: indices = list(range(nout)) + covMatrix = np.diag(covMatrix[indices, indices]) + covMatrix_data = np.diag(covMatrix_data[indices, indices]) # Compute likelilhood output vs data logLik_data[i] += self._logpdf( diff --git a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py index 21a0fe302..a12e493dd 100644 --- a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py +++ b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py @@ -91,7 +91,10 @@ class BayesModelComparison: # Set BayesInference options for key, value in optsDict.items(): if key in BayesOpts.__dict__.keys(): - setattr(BayesOpts, key, value) + if key == "Discrepancy" and isinstance(value, dict): + setattr(BayesOpts, key, value[model]) + else: + setattr(BayesOpts, key, value) # Pass justifiability data as perturbed data BayesOpts.perturbed_data = justData @@ -153,6 +156,12 @@ class BayesModelComparison: if not justifiability: return self.perturbed_data + # Evaluate metamodel + runs = {} + for key, metaModel in modelDict.items(): + y_hat, _ = metaModel.eval_metamodel(nsamples=n_bootstarp) + runs[key] = y_hat + # Generate data for i in range(n_bootstarp): y_data = self.perturbed_data[i].reshape(1, -1) @@ -160,7 +169,8 @@ class BayesModelComparison: # Use surrogate runs for data-generating process for key, metaModel in modelDict.items(): model_data = np.array( - [metaModel.ExpDesign.Y[out][i] for out in out_names]) + [runs[key][out][i] for out in out_names] + ).reshape(y_data.shape) justData = np.vstack(( justData, np.tril(np.repeat(model_data, model_data.shape[1], axis=0)) @@ -277,7 +287,8 @@ class BayesModelComparison: plt.close() # Confusion matrix for some measurement points - for index in range(0, self.n_meas+1, self.just_n_meas): + epsilon = 1 if self.just_n_meas != 1 else 0 + for index in range(0, self.n_meas+epsilon, self.just_n_meas): weights = np.array( [ModelWeights_dict[key][:, index] for key in ModelWeights_dict] ) -- GitLab