Skip to content
Snippets Groups Projects
Commit de60ea40 authored by Alina Lacheim's avatar Alina Lacheim
Browse files

added user defined output formats in PostProcessing for diagrams

parent d5937852
No related branches found
No related tags found
1 merge request!37Fix/post processing
......@@ -37,12 +37,13 @@ class PostProcessing:
"""
def __init__(self, engine, name="calib", out_dir=""):
def __init__(self, engine, name="calib", out_dir="", out_format="pdf"):
# PostProcessing only available for trained engines
if not engine.trained:
raise AttributeError('PostProcessing can only be performed on trained engines.')
self.engine = engine
self.name = name
self.out_format = out_format
self.par_names = self.engine.ExpDesign.par_names
self.x_values = self.engine.ExpDesign.x_values
......@@ -71,7 +72,7 @@ class PostProcessing:
# -------------------------------------------------------------------------
def plot_moments(self, plot_type: str = None):
"""
Plots the moments in a pdf format in the directory
Plots the moments in a user defined output format (standard is pdf) in the directory
`Outputs_PostProcessing`.
Parameters
......@@ -187,7 +188,7 @@ class PostProcessing:
plt.tight_layout()
# save the current figure
fig.savefig(f"{self.out_dir}Mean_Std_PCE_{key}.pdf", bbox_inches="tight")
fig.savefig(f"{self.out_dir}Mean_Std_PCE_{key}.{self.out_format}", bbox_inches="tight")
return self.means, self.stds
......@@ -228,20 +229,20 @@ class PostProcessing:
n_obs = self.model_out_dict[key].shape[1]
if n_obs == 1:
self._plot_validation()
else:
self._plot_validation_multi()
# if n_obs == 1:
# self._plot_validation()
# else:
self._plot_validation_multi()
# Zip the subdirectories
self.engine.Model.zip_subdirs(
f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
)
# Zip the subdirectories
self.engine.Model.zip_subdirs(
f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
)
# Zip the subdirectories
self.engine.Model.zip_subdirs(
f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
)
# Zip the subdirectories
self.engine.Model.zip_subdirs(
f"{self.engine.Model.name}valid", f"{self.engine.Model.name}valid_"
)
# -------------------------------------------------------------------------
def check_accuracy(self, n_samples=None, samples=None, outputs=None) -> None:
......@@ -493,7 +494,7 @@ class PostProcessing:
# save the current figure
plot_name = plot.replace(" ", "_")
fig.savefig(f"./{newpath}/seq_{plot_name}.pdf", bbox_inches="tight")
fig.savefig(f"./{newpath}/seq_{plot_name}.{self.out_format}", bbox_inches="tight")
# Destroy the current plot
plt.close()
# Save arrays into files
......@@ -596,7 +597,7 @@ class PostProcessing:
# save the current figure
plot_name = plot.replace(" ", "_")
fig.savefig(f"./{newpath}/seq_{plot_name}.pdf", bbox_inches="tight")
fig.savefig(f"./{newpath}/seq_{plot_name}.{self.out_format}", bbox_inches="tight")
# Destroy the current plot
plt.close()
......@@ -748,13 +749,13 @@ class PostProcessing:
if sobol_type == "sobol":
plt.title(f"{i_order} order Sobol' indices of {output}")
fig.savefig(
f"{self.out_dir}Sobol_indices_{i_order}_{output}.pdf",
f"{self.out_dir}Sobol_indices_{i_order}_{output}.{self.out_format}",
bbox_inches="tight",
)
elif sobol_type == "totalsobol":
plt.title(f"Total Sobol' indices of {output}")
fig.savefig(
f"{self.out_dir}TotalSobol_indices_{output}.pdf",
f"{self.out_dir}TotalSobol_indices_{output}.{self.out_format}",
bbox_inches="tight",
)
plt.clf()
......@@ -822,7 +823,7 @@ class PostProcessing:
# save the current figure
fig1.savefig(
f"./{self.out_dir}/Residuals_vs_Par_{i+1}.pdf", bbox_inches="tight"
f"./{self.out_dir}/Residuals_vs_Par_{i+1}.{self.out_format}", bbox_inches="tight"
)
# Destroy the current plot
plt.close()
......@@ -842,7 +843,7 @@ class PostProcessing:
# save the current figure
fig2.savefig(
f"./{self.out_dir}/Fitted_vs_Residuals.pdf", bbox_inches="tight"
f"./{self.out_dir}/Fitted_vs_Residuals.{self.out_format}", bbox_inches="tight"
)
# Destroy the current plot
plt.close()
......@@ -870,7 +871,7 @@ class PostProcessing:
# save the current figure
fig3.savefig(
f"./{self.out_dir}/Hist_NormResiduals.pdf", bbox_inches="tight"
f"./{self.out_dir}/Hist_NormResiduals.{self.out_format}", bbox_inches="tight"
)
# Destroy the current plot
plt.close()
......@@ -887,7 +888,7 @@ class PostProcessing:
# save the current figure
plt.savefig(
f"./{self.out_dir}/QQPlot_NormResiduals.pdf", bbox_inches="tight"
f"./{self.out_dir}/QQPlot_NormResiduals.{self.out_format}", bbox_inches="tight"
)
# Destroy the current plot
plt.close()
......@@ -948,7 +949,7 @@ class PostProcessing:
# save the figure to file
fig.savefig(
f"./{self.out_dir}/3DPlot_{title}_{name}{t}.pdf",
f"./{self.out_dir}/3DPlot_{title}_{name}{t}.{self.out_format}",
bbox_inches="tight",
)
plt.close(fig)
......@@ -1041,6 +1042,7 @@ class PostProcessing:
for key, value in metamod._coeffs_dict["b_1"][key].items():
length_list.append(len(value))
n_predictors = min(length_list)
n_samples = y_pce_val[key].shape[0]
r_2 = r2_score(y_pce_val[key], y_val[key])
......@@ -1060,7 +1062,7 @@ class PostProcessing:
# save the current figure
plot_name = key.replace(" ", "_")
fig.savefig(
f"./{self.out_dir}/Model_vs_PCEModel_{plot_name}.pdf",
f"./{self.out_dir}/Model_vs_PCEModel_{plot_name}.{self.out_format}",
bbox_inches="tight",
)
......@@ -1150,6 +1152,6 @@ class PostProcessing:
plt.grid()
key = key.replace(" ", "_")
fig.savefig(
f"./{self.out_dir}/Model_vs_PCEModel_{key}.pdf", bbox_inches="tight"
f"./{self.out_dir}/Model_vs_PCEModel_{key}.{self.out_format}", bbox_inches="tight"
)
plt.close()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment