diff --git a/__pycache__/__init__.cpython-311.pyc b/__pycache__/__init__.cpython-311.pyc index a50469ed17864f696db773fddcadf9eca9a3dc9d..a56ee1f6b4dae7b5e88e838093159afdbc581359 100644 Binary files a/__pycache__/__init__.cpython-311.pyc and b/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py index a9aaefd34811381879ceac66adbbea5d1a9703c4..3776743e48c3bafc152410a79f45c93f582243eb 100644 --- a/src/bayesvalidrox/post_processing/post_processing.py +++ b/src/bayesvalidrox/post_processing/post_processing.py @@ -36,7 +36,16 @@ class PostProcessing: Name of the PostProcessing object to be used for saving the generated files. The default is 'calib'. out_dir : string - Output directory in which the images are placed. The default is ''. + Output directory in which the PostProcessing results are placed. + The results are contained in a subfolder '/Outputs_PostProcessing_name' + The default is ''. + out_format : string + Format of the saved plots. Supports 'png' and 'pdf'. The default is 'pdf'. + + Raises + ------ + AttributeError + `engine` must be trained. """ @@ -71,17 +80,21 @@ class PostProcessing: self.stds = None # ------------------------------------------------------------------------- - def plot_moments(self, plot_type: str = None): + def plot_moments(self, plot_type: str = 'line'): """ Plots the moments in a user defined output format (standard is pdf) in the directory `Outputs_PostProcessing`. Parameters ---------- - xlabel : str, optional - String to be displayed as x-label. The default is `'Time [s]'`. plot_type : str, optional - Options: bar or line. The default is `None`. + Supports 'bar' for barplots and 'line' + for lineplots The default is `line`. + + Raises + ------ + AttributeError + Plot type must be 'bar' or 'line'. Returns ------- @@ -91,6 +104,8 @@ class PostProcessing: Standard deviation of the model outputs. """ + if plot_type not in ['bar', 'line']: + raise AttributeError('The wanted plot-type is not supported.') bar_plot = bool(plot_type == "bar") meta_model_type = self.engine.MetaModel.meta_model_type @@ -205,12 +220,10 @@ class PostProcessing: Samples to be evaluated. The default is None. model_out_dict: dict The model runs using the samples provided. - x_axis : str, optional - Label of x axis. The default is `'Time [s]'`. Returns ------- - None. + None """ if samples is None: @@ -263,7 +276,7 @@ class PostProcessing: Returns ------- - None. + None """ # Set the number of samples @@ -324,19 +337,19 @@ class PostProcessing: Parameters ---------- - ref_BME_KLD : array, optional + ref_bme_kld : array, optional Reference BME and KLD . The default is `None`. Returns ------- - None. + None """ engine = self.engine n_init_samples = engine.ExpDesign.n_init_samples n_total_samples = engine.ExpDesign.X.shape[0] - newpath = f"Outputs_PostProcessing_{self.name}/seq_design_diagnostics/" + newpath = f"{self.out_dir}/seq_design_diagnostics/" if not os.path.exists(newpath): os.makedirs(newpath) @@ -614,26 +627,27 @@ class PostProcessing: np.save(f"./{newpath}/seq_{plot_name}.npy", seq_values) # ------------------------------------------------------------------------- - def sobol_indices(self, plot_type: str = None, save: bool = True): + def sobol_indices(self, plot_type: str = 'line', save: bool = True): """ Visualizes and writes out Sobol' and Total Sobol' indices of the trained metamodel. One file is created for each index and output key. Parameters ---------- - xlabel : str, optional - Label of the x-axis. The default is `'Time [s]'`. plot_type : str, optional - Plot type. The default is `None`. This corresponds to line plot. + Plot type, supports 'line' for lineplots and 'bar' for barplots. + The default is `line`. Bar chart can be selected by `bar`. save : bool, optional Write out the inidces as csv files if set to True. The default - is True + is True. Raises ------ AttributeError MetaModel in given Engine needs to be of type 'pce' or 'apce'. + AttributeError + Plot-type must be 'line' or 'bar'. Returns ------- @@ -649,6 +663,9 @@ class PostProcessing: raise AttributeError("Sobol indices only support PCE-type models!") if metamod.meta_model_type.lower() not in ["pce", "apce"]: raise AttributeError("Sobol indices only support PCE-type models!") + + if plot_type not in ['line', 'bar']: + raise AttributeError('The wanted plot type is not supported.') # Extract the necessary variables max_order = np.max(metamod._pce_deg) @@ -722,7 +739,7 @@ class PostProcessing: Returns ------- - None. + None """ sobol = None @@ -820,7 +837,7 @@ class PostProcessing: Returns ------- - None. + None """ if samples is None: @@ -830,11 +847,6 @@ class PostProcessing: else: n_samples = samples.shape[0] - # if samples is not isinstance(samples, np.ndarray): - # raise TypeError - - if outputs is None or outputs == {}: - raise AttributeError("Please provide the outputs of the model!") # Evaluate the original and the surrogate model y_val = outputs if y_val is None: @@ -946,7 +958,7 @@ class PostProcessing: Returns ------- - None. + None """ if self.engine.ExpDesign.ndim != 2: @@ -1000,7 +1012,7 @@ class PostProcessing: Returns ------- - None. + None """ # List of markers and colors diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 5ba76449b9f90f0ea0e093bd9d1f341a5ceb93cd..6ec891854fcfa50e9c4b7c932029b64e9e16d0e7 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -255,22 +255,6 @@ def test_plot_moments_gpebar(gpe_engine) -> None: assert stdev['Z'][0] == pytest.approx(0.1, abs=0.01) #%% valid_metamodel -def test_plot_validation_multi_pce(pce_engine): - engine = pce_engine - post = PostProcessing(engine) - out_mean = {'Z': np.array([[0.4], [0.5], [0.45], [0.4]])} - out_std = {'Z': np.array([[0.1], [0.1], [0.1], [0.1]])} - post.model_out_dict = {'Z': np.array([[0.4], [0.5],[0.3],[0.4]])} - post._plot_validation_multi(out_mean, out_std) - -def test_plot_validation_multi_gpe(gpe_engine): - engine = gpe_engine - post = PostProcessing(engine) - out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])} - out_std = {'Z': np.array([[0.1], [0.1], [0.1]])} - post.model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} - post._plot_validation_multi(out_mean, out_std) - def test_valid_metamodel_pce(pce_engine): engine = pce_engine post = PostProcessing(engine) @@ -311,7 +295,7 @@ def test_plot_seq_design_diagnostics(basic_engine_trained): post = PostProcessing(engine) post.plot_seq_design_diagnostics() # Check if the plot was created and saved - assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}") + assert os.path.exists(f"./{post.out_dir}/Seq_Design_Diagnostics.{post.out_format}") def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained): """ @@ -323,7 +307,7 @@ def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained): post = PostProcessing(engine) post.plot_seq_design_diagnostics() # Check if the plot was created and saved - assert os.path.exists(f"./{engine.out_dir}/Seq_Design_Diagnostics.{engine.out_format}") + assert os.path.exists(f"./{post.out_dir}/Seq_Design_Diagnostics.{post.out_format}") def test_plot_seq_design_diagnostics_with_empty_values(basic_engine_trained): """ @@ -428,13 +412,18 @@ def test_plot_validation_multi(pce_engine_3d_plot): assert os.path.exists(f"./{post.out_dir}/Model_vs_pceModel_Y.{post.out_format}") assert os.path.exists(f"./{post.out_dir}/Model_vs_pceModel_Z.{post.out_format}") -def test_plot_validation_multi_with_empty_values(pce_engine_3d_plot) -> None: - """ - Test the _plot_validation_multi method with empty values - """ - engine = pce_engine_3d_plot +def test_plot_validation_multi_pce(pce_engine): + engine = pce_engine post = PostProcessing(engine) - with pytest.raises(ValueError) as excinfo: - post._plot_validation_multi({},{},{}) - assert "y_val and y_val_std cannot be empty" in str(excinfo.value) - \ No newline at end of file + out_mean = {'Z': np.array([[0.4], [0.5], [0.45], [0.4]])} + out_std = {'Z': np.array([[0.1], [0.1], [0.1], [0.1]])} + model_out_dict = {'Z': np.array([[0.4], [0.5],[0.3],[0.4]])} + post._plot_validation_multi(out_mean, out_std, model_out_dict) + +def test_plot_validation_multi_gpe(gpe_engine): + engine = gpe_engine + post = PostProcessing(engine) + out_mean = {'Z': np.array([[0.4], [0.5], [0.45]])} + out_std = {'Z': np.array([[0.1], [0.1], [0.1]])} + model_out_dict = {'Z': np.array([[0.4], [0.5], [0.45]])} + post._plot_validation_multi(out_mean, out_std, model_out_dict)