From a231eb750b1187fc95e502ccccf06a52b356ce40 Mon Sep 17 00:00:00 2001
From: Farid Mohammadi <farid.mohammadi@iws.uni-stuttgart.de>
Date: Fri, 12 Aug 2022 16:53:17 +0200
Subject: [PATCH] [surrogate] fix bugs in the PCE implementation of the
 Rosenblatt transformation for dependent variables.

---
 .../surrogate_models/exp_designs.py           | 21 ++++++++-----------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/src/bayesvalidrox/surrogate_models/exp_designs.py b/src/bayesvalidrox/surrogate_models/exp_designs.py
index 8224053a0..2caa086e5 100644
--- a/src/bayesvalidrox/surrogate_models/exp_designs.py
+++ b/src/bayesvalidrox/surrogate_models/exp_designs.py
@@ -170,12 +170,6 @@ class ExpDesigns:
             samples = self.random_sampler(int(n_samples)).T
 
         return samples.T
-        # # Transform samples to the original space
-        # if transform:
-        #     orig_samples = self.transform(samples.T)
-        #     return orig_samples, samples.T
-        # else:
-        #     return samples.T
 
     # -------------------------------------------------------------------------
     def generate_ED(self, n_samples, sampling_method='random', transform=False,
@@ -260,14 +254,14 @@ class ExpDesigns:
 
         # Transform samples to the original space
         if transform:
-            orig_samples = self.transform(
+            tr_samples = self.transform(
                 samples,
                 method=sampling_method
                 )
-            if sampling_method == 'user':
-                return samples, orig_samples
+            if sampling_method == 'user' or not self.apce:
+                return samples, tr_samples
             else:
-                return orig_samples, samples
+                return tr_samples, samples
         else:
             return samples
 
@@ -327,8 +321,10 @@ class ExpDesigns:
             self.polycoeffs = {}
             for parIdx in tqdm(range(ndim), ascii=True,
                                desc="Computing orth. polynomial coeffs"):
-                poly_coeffs = apoly_construction(self.raw_data[parIdx],
-                                                 max_deg)
+                poly_coeffs = apoly_construction(
+                    self.raw_data[parIdx],
+                    max_deg
+                    )
                 self.polycoeffs[f'p_{parIdx+1}'] = poly_coeffs
 
         # Extract moments
@@ -648,6 +644,7 @@ class ExpDesigns:
 
                 # Compute invCDF_y(cdfx)
                 tr_X[:, par_i] = inv_cdf(cdfx[:, par_i])
+
         return tr_X
 
     # -------------------------------------------------------------------------
-- 
GitLab