From c323ed966385f7df191605a9750303b115e071b1 Mon Sep 17 00:00:00 2001
From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de>
Date: Fri, 18 Mar 2022 10:56:35 +0100
Subject: [PATCH] [bayesinference] add valid_metrics option to be passed as a
 list and parallelize the inf_entropy calculations.

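Only the metrics requested via `valid_metrics` are evaluated during the
bootstrap iterations: the rejection step, KLD and the information
entropy are skipped when not needed, and the posterior expectation of
the (log) prior density required for the information entropy is now
computed in parallel with multiprocessing.

A minimal usage sketch (the trained MetaModel object and the
create_inference() call follow the usual bayesvalidrox workflow and
are only assumed here):

    from bayesvalidrox.bayes_inference.bayes_inference import BayesInference

    # MetaModel: a trained bayesvalidrox meta-model (assumed available)
    BayesObj = BayesInference(
        MetaModel,
        n_bootstrap_itrs=100,
        valid_metrics=['log_BME', 'KLD', 'inf_entropy']
        )
    BayesObj.create_inference()  # assumed entry point of the workflow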
---
 .../bayes_inference/bayes_inference.py        | 88 ++++++++++++-------
 1 file changed, 54 insertions(+), 34 deletions(-)

diff --git a/src/bayesvalidrox/bayes_inference/bayes_inference.py b/src/bayesvalidrox/bayes_inference/bayes_inference.py
index 0e82c5901..0b9965edf 100644
--- a/src/bayesvalidrox/bayes_inference/bayes_inference.py
+++ b/src/bayesvalidrox/bayes_inference/bayes_inference.py
@@ -12,6 +12,7 @@ import scipy.linalg as spla
 import seaborn as sns
 import corner
 import h5py
+import multiprocessing
 import gc
 from sklearn.metrics import mean_squared_error, r2_score
 from sklearn import preprocessing
@@ -100,6 +101,13 @@ class BayesInference:
         User defined perturbed data. The default is `[]`.
     bootstrap_noise : float, optional
         A noise level to perturb the data set. The default is `0.05`.
+    valid_metrics : list, optional
+        List of the validation metrics to compute. The following metrics
+        are supported:
+
+        1. log_BME : logarithm of the Bayesian model evidence
+        2. KLD : Kullback-Leibler divergence
+        3. inf_entropy : information entropy
+
+        The default is `['log_BME']`.
     plot_post_pred : bool, optional
         Plot posterior predictive plots. The default is `True`.
     plot_map_pred : bool, optional
@@ -120,7 +128,8 @@ class BayesInference:
                  selected_indices=None, samples=None, n_samples=500000,
                  measured_data=None, inference_method='rejection',
                  mcmc_params=None, bayes_loocv=False, n_bootstrap_itrs=1,
-                 perturbed_data=[], bootstrap_noise=0.05, plot_post_pred=True,
+                 perturbed_data=[], bootstrap_noise=0.05,
+                 valid_metrics=['log_BME'], plot_post_pred=True,
                  plot_map_pred=False, max_a_posteriori='mean',
                  corner_title_fmt='.3e'):
 
@@ -140,6 +149,7 @@ class BayesInference:
         self.bayes_loocv = bayes_loocv
         self.n_bootstrap_itrs = n_bootstrap_itrs
         self.bootstrap_noise = bootstrap_noise
+        self.valid_metrics = valid_metrics
         self.plot_post_pred = plot_post_pred
         self.plot_map_pred = plot_map_pred
         self.max_a_posteriori = max_a_posteriori
@@ -315,8 +325,9 @@ class BayesInference:
 
             # Start the likelihood-BME computations for the perturbed data
             for itr_idx, data in tqdm(
-                    enumerate(self.perturbed_data), ascii=True,
-                    desc="Boostraping the BME calculations"
+                    enumerate(self.perturbed_data),
+                    total=self.n_bootstrap_itrs,
+                    desc="Boostraping the BME calculations", ascii=True
                     ):
 
                 # ---------------- Likelihood calculation ----------------
@@ -358,54 +369,63 @@ class BayesInference:
                                       dtype=np.float128))
                     )
 
-                # Rejection Step
-                # Random numbers between 0 and 1
-                unif = np.random.rand(1, self.n_samples)[0]
-
-                # Reject the poorly performed prior
-                Likelihoods = np.exp(logLikelihoods[:, itr_idx],
-                                     dtype=np.float64)
-                accepted = (Likelihoods/np.max(Likelihoods)) >= unif
-                posterior = self.samples[accepted]
-
-                # Posterior-based expectation of likelihoods
-                postExpLikelihoods = np.mean(
-                    logLikelihoods[:, itr_idx][accepted]
-                    )
-
-                # Posterior-based expectation of prior densities
-                postExpPrior = np.mean(
-                    np.log([MetaModel.ExpDesign.JDist.pdf(posterior.T)])
-                    )
-
-                # Calculate Kullback-Leibler Divergence
-                KLD[itr_idx] = postExpLikelihoods - log_BME[itr_idx]
-
-                # Information Entropy based on Entropy paper Eq. 38
-                inf_entropy[itr_idx] = log_BME[itr_idx] - postExpPrior - \
-                    postExpLikelihoods
-
                 # TODO: BME correction when using Emulator
                 # if self.emulator:
                 #     BME_Corr[itr_idx] = self._corr_factor_BME(
                 #         data, total_sigma2, posterior
                 #         )
 
+                # Rejection Step
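+                # (needed when either KLD or inf_entropy is requested, as
+                #  both rely on the accepted posterior samples and the
+                #  posterior expectation of the log-likelihood)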
+                if 'kld' in list(map(str.lower, self.valid_metrics)) or \
+                   'inf_entropy' in list(map(str.lower, self.valid_metrics)):
+                    # Random numbers between 0 and 1
+                    unif = np.random.rand(1, self.n_samples)[0]
+
+                    # Reject poorly performing prior samples
+                    Likelihoods = np.exp(logLikelihoods[:, itr_idx],
+                                         dtype=np.float64)
+                    accepted = (Likelihoods/np.max(Likelihoods)) >= unif
+                    posterior = self.samples[accepted]
+
+                    # Posterior-based expectation of likelihoods
+                    postExpLikelihoods = np.mean(
+                        logLikelihoods[:, itr_idx][accepted]
+                        )
+
+                    # Calculate Kullback-Leibler Divergence
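+                    # KLD = E_post[log-likelihood] - log_BME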
+                    KLD[itr_idx] = postExpLikelihoods - log_BME[itr_idx]
+
+                # Posterior-based expectation of prior densities
+                if 'inf_entropy' in list(map(str.lower, self.valid_metrics)):
+                    n_thread = int(0.875 * multiprocessing.cpu_count())
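+                    # Evaluate the joint prior pdf in parallel: split the
+                    # accepted posterior samples (columns of posterior.T)
+                    # into n_thread chunks, evaluate each chunk in its own
+                    # worker, then average the log of the combined results.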
+                    with multiprocessing.Pool(n_thread) as p:
+                        postExpPrior = np.mean(np.log(np.concatenate(
+                            p.map(MetaModel.ExpDesign.JDist.pdf,
+                                  np.array_split(posterior.T, n_thread,
+                                                 axis=1))
+                            )))
+                    # Information Entropy based on Entropy paper Eq. 38
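+                    # inf_entropy = log_BME - E_post[log-prior]
+                    #               - E_post[log-likelihood]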
+                    inf_entropy[itr_idx] = log_BME[itr_idx] - postExpPrior - \
+                        postExpLikelihoods
+
                 # Clear memory
                 gc.collect(generation=2)
 
-            # ---------------- Store BME, Likelihoods for all ----------------
+            # ---------- Store metrics for perturbed data sets ---------------
             # Likelihoods (Size: n_samples, n_bootstrap_itr)
             self.log_likes = logLikelihoods
 
             # BME (log), KLD, infEntropy (Size: 1,n_bootstrap_itr)
             self.log_BME = log_BME
-            self.KLD = KLD
-            self.inf_entropy = inf_entropy
-
             # TODO: BMECorrFactor (log) (Size: 1,n_bootstrap_itr)
             # if self.emulator: self.BMECorrFactor = BME_Corr
 
+            if 'kld' in list(map(str.lower, self.valid_metrics)):
+                self.KLD = KLD
+            if 'inf_entropy' in list(map(str.lower, self.valid_metrics)):
+                self.inf_entropy = inf_entropy
+
             # BME = BME + BMECorrFactor
             if self.emulator:
                 self.log_BME = self.log_BME  # + self.BMECorrFactor
-- 
GitLab