From b4c7666230f29b449b7bb163b35a0bc791f2f57c Mon Sep 17 00:00:00 2001 From: faridm69 <faridmohammadi69@gmail.com> Date: Thu, 16 Jul 2020 16:16:36 +0200 Subject: [PATCH] [surrogate] bug fixed --- BayesValidRox/surrogate_models/surrogate_models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/BayesValidRox/surrogate_models/surrogate_models.py b/BayesValidRox/surrogate_models/surrogate_models.py index 12e86cb53..7334412fa 100644 --- a/BayesValidRox/surrogate_models/surrogate_models.py +++ b/BayesValidRox/surrogate_models/surrogate_models.py @@ -607,7 +607,7 @@ class aPCE: # Extract the univariate polynomials on ExpDesign univ_p_val = self.univ_p_val - + for degIdx, deg in enumerate(DegreeArray): for qidx, q in enumerate(qnorm): @@ -650,7 +650,7 @@ class aPCE: break # Leave the loop, if FastARD did not converge. - if self.RegMethod == 'FastARD' and not clf_poly.converged: + if self.RegMethod == 'FastARD' and not clf_poly.converged and deg != 1: print("Degree {0} not converged!".format(deg)) break @@ -947,13 +947,13 @@ class aPCE: nitr = nSamples - self.ExpDesign.initNrSamples d = nitr if nitr != 0 and self.NofPa > 5 else 1 M_uptoMax = lambda maxDeg: np.array([math.factorial(ndim+d)//(math.factorial(ndim)*math.factorial(d)) for d in range(1,maxDeg+1)]) - deg = range(1,maxDeg+1)[np.argmin(abs(M_uptoMax(maxDeg)-nSamples*ndim*d))] + degNew = range(1,maxDeg+1)[np.argmin(abs(M_uptoMax(maxDeg)-nSamples*ndim*d))] self.q = np.array(self.q) if not np.isscalar(self.q) else np.array([self.q]) - self.DegreeArray = np.array([deg]) #np.arange(self.MinPceDegree,deg+1) #or np.array([deg]) + self.DegreeArray = np.arange(self.MinPceDegree,degNew+1) #or np.array([deg]) for deg in self.DegreeArray: # self.allBasisIndices = self.AutoVivification() - if deg not in self.allBasisIndices.keys(): + if deg not in np.fromiter(self.allBasisIndices.keys(), dtype=float): # Generate the polynomial basis indices for qidx, q in enumerate(self.q): self.allBasisIndices[str(deg)][str(q)] = self.PolyBasisIndices(degree=deg, q=q) -- GitLab