diff --git a/src/bayesvalidrox/post_processing/post_processing.py b/src/bayesvalidrox/post_processing/post_processing.py index 0873d3c07cf4e4348e93704b700d86cfee5ee459..d67e3529fcc6eec98b63e390d905722faeb719e3 100644 --- a/src/bayesvalidrox/post_processing/post_processing.py +++ b/src/bayesvalidrox/post_processing/post_processing.py @@ -374,7 +374,7 @@ class PostProcessing: for plotidx, plot in enumerate(plot_list): fig, ax = plt.subplots() seq_dict = seq_list[plotidx] - name_util = list(seq_dict.keys()) + name_util = list(seq_dict.keys()) # TODO: same as engine.out_names? if len(name_util) == 0: continue diff --git a/tests/test_PostProcessing.py b/tests/test_PostProcessing.py index 205cbde579833f9e5fa5bf2f412736e1f7a6d43c..4867a34b7f531f83dff1b41834920776d2cbc36b 100644 --- a/tests/test_PostProcessing.py +++ b/tests/test_PostProcessing.py @@ -94,7 +94,11 @@ def basic_engine_sequential(): mod = PL() expdes = ExpDesigns(inp) - expdes.n_samples = 3 + expdes.n_init_samples = 4 + expdes.n_max_samples = 4 + expdes.X = np.array([[0, 0], [1, 1], [0.5, 0.5], [0.1, 0.5]]) # Two input dimensions + expdes.Y = {'Z': [[0.4], [0.5], [0.3], [0.4]]} # Output values + expdes.x_values = np.array([0]) engine = Engine(mm, mod, expdes) engine.out_names = ['Z'] @@ -103,13 +107,18 @@ def basic_engine_sequential(): engine.train_sequential() - engine.SeqModifiedLOO = {'Z': np.array([0.1, 0.2, 0.3])} - engine.seqValidError = {'Z': np.array([0.15, 0.25, 0.35])} - engine.SeqKLD = {'Z': np.array([0.05, 0.1, 0.15])} - engine.SeqBME = {'Z': np.array([0.02, 0.04, 0.06])} - engine.seqRMSEMean = {'Z': np.array([0.12, 0.14, 0.16])} - engine.seqRMSEStd = {'Z': np.array([0.03, 0.05, 0.07])} - engine.SeqDistHellinger = {'Z': np.array([0.08, 0.09, 0.1])} + engine.SeqModifiedLOO = {'DKL_rep_1': np.array([[1.31565589e-10], + [1.31413432e-10]])} + engine.seqValidError = {} + engine.SeqKLD = {'DKL_rep_1': np.array([[2.6296851 ], + [2.60875351]])} + engine.SeqBME = {'DKL_rep_1': np.array([[-19.33941695], + [-19.29572507]])} + engine.seqRMSEMean = {'DKL_rep_1': np.array([[1.02174823], + [1.02174727]])} + engine.seqRMSEStd = {'DKL_rep_1': np.array([[0.76724993], + [0.76725023]])} + engine.SeqDistHellinger = {} for key, array in engine.SeqModifiedLOO.items(): assert np.all(array != None), f"Array {key} contains None values." @@ -326,10 +335,17 @@ def test_plot_seq_design_diagnostics(basic_engine_sequential): Test the plot_seq_design_diagnostics method """ engine = basic_engine_sequential + engine.ExpDesign.n_max_samples = 4 + engine.ExpDesign.n_init_samples = 3 + post = PostProcessing(engine) post.plot_seq_design_diagnostics() # Check if the plot was created and saved - assert os.path.exists(f"./{post.out_dir}/Seq_Design_Diagnostics.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/seq_BME.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/seq_KLD.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/seq_Modified_LOO_error.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/seq_RMSEMean.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/seq_RMSEStd.{post.out_format}") def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained): """ @@ -341,7 +357,7 @@ def test_plot_seq_design_diagnostics_with_custom_values(basic_engine_trained): post = PostProcessing(engine) post.plot_seq_design_diagnostics() # Check if the plot was created and saved - assert os.path.exists(f"./{post.out_dir}/Seq_Design_Diagnostics.{post.out_format}") + assert os.path.exists(f"./{post.out_dir}/seq_design_diagnostics/Seq_Design_Diagnostics.{post.out_format}") def test_plot_seq_design_diagnostics_with_empty_values(basic_engine_trained): """