From a231eb750b1187fc95e502ccccf06a52b356ce40 Mon Sep 17 00:00:00 2001 From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de> Date: Fri, 12 Aug 2022 16:53:17 +0200 Subject: [PATCH] [surrogate] fix bugs in the PCE implementation of the Rosenblatt transformation for dependent variables. --- .../surrogate_models/exp_designs.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/bayesvalidrox/surrogate_models/exp_designs.py b/src/bayesvalidrox/surrogate_models/exp_designs.py index 8224053a0..2caa086e5 100644 --- a/src/bayesvalidrox/surrogate_models/exp_designs.py +++ b/src/bayesvalidrox/surrogate_models/exp_designs.py @@ -170,12 +170,6 @@ class ExpDesigns: samples = self.random_sampler(int(n_samples)).T return samples.T - # # Transform samples to the original space - # if transform: - # orig_samples = self.transform(samples.T) - # return orig_samples, samples.T - # else: - # return samples.T # ------------------------------------------------------------------------- def generate_ED(self, n_samples, sampling_method='random', transform=False, @@ -260,14 +254,14 @@ class ExpDesigns: # Transform samples to the original space if transform: - orig_samples = self.transform( + tr_samples = self.transform( samples, method=sampling_method ) - if sampling_method == 'user': - return samples, orig_samples + if sampling_method == 'user' or not self.apce: + return samples, tr_samples else: - return orig_samples, samples + return tr_samples, samples else: return samples @@ -327,8 +321,10 @@ class ExpDesigns: self.polycoeffs = {} for parIdx in tqdm(range(ndim), ascii=True, desc="Computing orth. polynomial coeffs"): - poly_coeffs = apoly_construction(self.raw_data[parIdx], - max_deg) + poly_coeffs = apoly_construction( + self.raw_data[parIdx], + max_deg + ) self.polycoeffs[f'p_{parIdx+1}'] = poly_coeffs # Extract moments @@ -648,6 +644,7 @@ class ExpDesigns: # Compute invCDF_y(cdfx) tr_X[:, par_i] = inv_cdf(cdfx[:, par_i]) + return tr_X # ------------------------------------------------------------------------- -- GitLab