From 29e986875219a0c69aad9a7dfd2bf8084abda901 Mon Sep 17 00:00:00 2001 From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de> Date: Thu, 19 May 2022 10:01:20 +0200 Subject: [PATCH] [pylink] update run_forward_parallel --- src/bayesvalidrox/pylink/pylink.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/bayesvalidrox/pylink/pylink.py b/src/bayesvalidrox/pylink/pylink.py index 62b169c8a..ca07df101 100644 --- a/src/bayesvalidrox/pylink/pylink.py +++ b/src/bayesvalidrox/pylink/pylink.py @@ -402,7 +402,7 @@ class PyLinkForwardModel(object): # ------------------------------------------------------------------------- def run_model_parallel(self, c_points, prevRun_No=0, key_str='', - mp=True): + mp=True, verbose=True): """ Runs model simulations. If mp is true (default), then the simulations are started in parallel. @@ -418,6 +418,8 @@ class PyLinkForwardModel(object): A descriptive string for validation runs. The default is `''`. mp : bool, optional Multiprocessing. The default is `True`. + verbose: bool, optional + Verbosity. The default is `True`. Returns ------- @@ -448,7 +450,7 @@ class PyLinkForwardModel(object): n_cpus = self.n_cpus # Run forward model - if n_c_points == 1: + if n_c_points == 1 or not mp: if self.link_type.lower() == 'function': group_results = Function(c_points)[np.newaxis] else: @@ -456,9 +458,9 @@ class PyLinkForwardModel(object): c_points, prevRun_No+1, key_str )[np.newaxis] - elif self.multi_process: + elif self.multi_process or mp: with multiprocessing.Pool(n_cpus) as p: - desc = f'Running forward model {key_str}' + if self.link_type.lower() == 'function': imap_var = p.imap(Function, c_points[:, np.newaxis]) else: @@ -467,8 +469,12 @@ class PyLinkForwardModel(object): [key_str]*n_c_points) imap_var = p.imap(self.run_forwardmodel, args) - group_results = list(tqdm.tqdm(imap_var, total=n_c_points, - desc=desc)) + if verbose: + desc = f'Running forward model {key_str}' + group_results = list(tqdm.tqdm(imap_var, total=n_c_points, + desc=desc)) + else: + group_results = list(imap_var) # Check for NaN for varIdx, var in enumerate(self.Output.names): -- GitLab