From 29e986875219a0c69aad9a7dfd2bf8084abda901 Mon Sep 17 00:00:00 2001
From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de>
Date: Thu, 19 May 2022 10:01:20 +0200
Subject: [PATCH] [pylink] update run_forward_parallel

---
 src/bayesvalidrox/pylink/pylink.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/bayesvalidrox/pylink/pylink.py b/src/bayesvalidrox/pylink/pylink.py
index 62b169c8a..ca07df101 100644
--- a/src/bayesvalidrox/pylink/pylink.py
+++ b/src/bayesvalidrox/pylink/pylink.py
@@ -402,7 +402,7 @@ class PyLinkForwardModel(object):
 
     # -------------------------------------------------------------------------
     def run_model_parallel(self, c_points, prevRun_No=0, key_str='',
-                           mp=True):
+                           mp=True, verbose=True):
         """
         Runs model simulations. If mp is true (default), then the simulations
          are started in parallel.
@@ -418,6 +418,8 @@ class PyLinkForwardModel(object):
             A descriptive string for validation runs. The default is `''`.
         mp : bool, optional
             Multiprocessing. The default is `True`.
+        verbose: bool, optional
+            Verbosity. The default is `True`.
 
         Returns
         -------
@@ -448,7 +450,7 @@ class PyLinkForwardModel(object):
             n_cpus = self.n_cpus
 
         # Run forward model
-        if n_c_points == 1:
+        if n_c_points == 1 or not mp:
             if self.link_type.lower() == 'function':
                 group_results = Function(c_points)[np.newaxis]
             else:
@@ -456,9 +458,9 @@ class PyLinkForwardModel(object):
                     c_points, prevRun_No+1, key_str
                     )[np.newaxis]
 
-        elif self.multi_process:
+        elif self.multi_process or mp:
             with multiprocessing.Pool(n_cpus) as p:
-                desc = f'Running forward model {key_str}'
+
                 if self.link_type.lower() == 'function':
                     imap_var = p.imap(Function, c_points[:, np.newaxis])
                 else:
@@ -467,8 +469,12 @@ class PyLinkForwardModel(object):
                                [key_str]*n_c_points)
                     imap_var = p.imap(self.run_forwardmodel, args)
 
-                group_results = list(tqdm.tqdm(imap_var, total=n_c_points,
-                                               desc=desc))
+                if verbose:
+                    desc = f'Running forward model {key_str}'
+                    group_results = list(tqdm.tqdm(imap_var, total=n_c_points,
+                                                   desc=desc))
+                else:
+                    group_results = list(imap_var)
 
         # Check for NaN
         for varIdx, var in enumerate(self.Output.names):
-- 
GitLab