#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Box-plot comparison of sparse regression solvers for the borehole example.

For every solver one pickled surrogate (``PCEModel_<model_name>.pkl``) is
loaded, the sequential-design error histories of all solvers are merged into
one dictionary per diagnostic, and one box plot per diagnostic is written.

Created on Sat Sep 10 09:44:05 2022

@author: farid
"""
# Provenance: added by commit ba90b0adc616f812e5aed2fc10d006ac7fbfbc3f
# (Farid Mohammadi, Wed, 14 Sep 2022 16:09:44 +0200),
# "[example][borehole] add the script for the solver comparison.",
# as examples/borehole/data/sparse_solver_comparison.py.

import numpy as np
import joblib
import os
import scipy.stats as st
import warnings

import sys
# Anchor the package location to this file (not to the current working
# directory) so the script can be started from anywhere.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_BVR_SRC = os.path.join(_SCRIPT_DIR, '../../../src/bayesvalidrox/')
sys.path.append(_BVR_SRC)

from pylink.pylink import PyLinkForwardModel
from surrogate_models.inputs import Input
from surrogate_models.surrogate_models import MetaModel
from post_processing.post_processing import PostProcessing
from bayes_inference.bayes_inference import BayesInference
from bayes_inference.discrepancy import Discrepancy
import matplotlib
matplotlib.use('agg')
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.ticker as ticker
from matplotlib.offsetbox import AnchoredText
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
# Load the mplstyle
plt.style.use(os.path.join(_BVR_SRC, 'bayesvalidrox.mplstyle'))


def _collect_replicate_errors(seq_dict, rep_keys):
    """Return an ``(n_reps, n_runs)`` array of mean errors for one utility.

    ``rep_keys`` are the dictionary keys of the replications of a single
    utility function. Only the runs available in *every* replication are
    used, so all rows have the same length.
    """
    n_runs = min(seq_dict[key].shape[0] for key in rep_keys)
    return np.array([[seq_dict[key][run].mean() for run in range(n_runs)]
                     for key in rep_keys])


def _draw_boxes(ax, data, labels, edge_color, fill_color, idx):
    """Draw the box series of utility function ``idx`` and its median line."""
    # Shift every series sideways so that the series of the different
    # utility functions sit next to each other around the same tick.
    pos = labels - (4*idx-6)
    bp = ax.boxplot(data, positions=pos, labels=labels,
                    patch_artist=True, sym='', widths=3)

    ax.plot(pos, np.median(data, axis=0), lw=4, color=fill_color[idx])

    for element in ('boxes', 'whiskers', 'fliers', 'means', 'medians',
                    'caps'):
        plt.setp(bp[element], color=edge_color[idx], alpha=0.6)

    for patch in bp['boxes']:
        patch.set(facecolor=fill_color[idx], alpha=0.6)


def plot_seq_design_diagnostics(meta_model, util_funcs,
                                ref_BME_KLD=None, save_fig=True,
                                out_dir='.', model_name='borehole',
                                f_size=45):
    """
    Plots the diagnostics of the sequential design (modified LOO error,
    validation error, KLD, BME, RMSE of mean/std and Hellinger distance) as
    box plots over the replications.

    A diagnostic is skipped when its dictionary is empty, when no key has a
    replication index larger than one, or (with a warning) when an entry
    ``'<util>_rep_<i>'`` is missing for one of the requested names.

    Parameters
    ----------
    meta_model : object
        Surrogate providing ``ExpDesign`` (``n_init_samples``, ``X``,
        ``n_replication``, ``n_new_samples``) and the dictionaries
        ``SeqModifiedLOO``, ``seqValidError``, ``SeqKLD``, ``SeqBME``,
        ``seqRMSEMean``, ``seqRMSEStd`` and ``SeqDistHellinger`` whose keys
        have the form ``'<util>_rep_<i>'``.
    util_funcs : str or list of str
        Name(s) used as key prefix in the dictionaries above; one box series
        is drawn per name.
    ref_BME_KLD : array, optional
        Reference BME and KLD. If given, the BME and KLD values are divided
        by ``ref_BME_KLD[0]`` and ``ref_BME_KLD[1]``. The default is `None`.
    save_fig : bool, optional
        Whether to save the figures. Saved figures are closed afterwards;
        otherwise they stay open in pyplot. The default is `True`.
    out_dir : str, optional
        Directory in which the folder ``boxplot_<model_name>`` is created.
        The default is ``'.'``.
    model_name : str, optional
        Name used for the output folder. The default is ``'borehole'``.
    f_size : int, optional
        Font size of the axis labels and tick labels. The default is 45.

    Returns
    -------
    None.

    """
    exp_design = meta_model.ExpDesign
    n_init_samples = exp_design.n_init_samples
    n_total_samples = exp_design.X.shape[0]
    # Number of replications
    n_reps = exp_design.n_replication

    # Handle if only one UtilityFunction is provided
    if not isinstance(util_funcs, list):
        util_funcs = [util_funcs]

    if save_fig:
        newpath = os.path.join(out_dir, f'boxplot_{model_name}')
        os.makedirs(newpath, exist_ok=True)

    diagnostics = {
        'Modified LOO error': meta_model.SeqModifiedLOO,
        'Validation error': meta_model.seqValidError,
        'KLD': meta_model.SeqKLD,
        'BME': meta_model.SeqBME,
        'RMSEMean': meta_model.seqRMSEMean,
        'RMSEStd': meta_model.seqRMSEStd,
        'Hellinger distance': meta_model.SeqDistHellinger,
    }

    # step1: spacing of the x tick labels, step2: stride over the runs
    if exp_design.n_new_samples != 1:
        step1 = exp_design.n_new_samples
        step2 = 1
    else:
        step1 = 10
        step2 = 10
    labels = np.arange(n_init_samples, n_total_samples+1, step1)

    edge_color = ['red', 'blue', 'green', 'black']
    fill_color = ['tan', 'cyan', 'lightgreen', 'grey']

    # Plot the evolution of the diagnostic criteria of the
    # Sequential Experimental Design.
    for plot, seq_dict in diagnostics.items():
        if not seq_dict:
            continue

        # Box plot only when Replications have been detected.
        if not any(int(name.split("rep_", 1)[1]) > 1 for name in seq_dict):
            continue

        rep_keys = {util: [f'{util}_rep_{i+1}' for i in range(n_reps)]
                    for util in util_funcs}
        missing = [key for keys in rep_keys.values() for key in keys
                   if key not in seq_dict]
        if missing:
            # E.g. a dictionary that was not merged over all solvers.
            warnings.warn(
                f"Skipping '{plot}': no entries found for {missing}.",
                stacklevel=2)
            continue

        # Create the figure only once it is known that something is drawn.
        fig, ax = plt.subplots(figsize=(27, 15))

        # Special cases for BME and KLD: convergence w.r.t. the reference
        normalise = ref_BME_KLD is not None and plot in ('BME', 'KLD')
        plot_label = plot
        if normalise:
            if plot == 'BME':
                refValue = ref_BME_KLD[0]
                plot_label = r'BME/BME$^{Ref.}$'
            else:
                refValue = ref_BME_KLD[1]
                plot_label = '$D_{KL}[p(\\theta|y_*),p(\\theta)]'\
                    ' / D_{KL}^{Ref.}[p(\\theta|y_*), '\
                    'p(\\theta)]$'

        # Plot for different Utility Functions
        for idx, util in enumerate(util_funcs):
            all_errors = _collect_replicate_errors(seq_dict, rep_keys[util])

            if normalise:
                # Ratio between BME/KLD and the ref. values
                all_errors = np.divide(all_errors,
                                       np.full((all_errors.shape),
                                               refValue))

                # Plot baseline for a ratio of one, i.e. no difference
                ax.axhline(y=1.0, xmin=0, xmax=1, c='green',
                           ls='--', lw=2)

            _draw_boxes(ax, all_errors[:, ::step2], labels, edge_color,
                        fill_color, idx)

        plt.xticks(labels)

        # Shade every second group of boxes
        for center in labels[::2]:
            ax.axvspan(center-8, center+8, alpha=0.1, color='grey')

        # Legend
        legend_elements = [
            Patch(facecolor=fill_color[idx], edgecolor=edge_color[idx],
                  label=util)
            for idx, util in enumerate(util_funcs)
        ]
        ax.legend(handles=legend_elements[::-1], loc='best')

        if plot != 'BME' and plot != 'KLD':
            ax.set_yscale('log')
        ax.autoscale(True)
        ax.yaxis.grid(True, which='minor', linestyle='--')
        ax.set_xlabel('\\# of training samples', fontsize=f_size)
        ax.set_ylabel(plot_label, fontsize=f_size)
        plt.xticks(fontsize=f_size)
        plt.yticks(fontsize=f_size)

        if save_fig:
            # save the current figure
            plot_name = plot.replace(' ', '_')
            # NOTE(review): the file name says 'ishigami' although this is
            # the borehole example -- presumably a copy-paste leftover; kept
            # so that the names of the produced files do not change.
            fig.savefig(
                os.path.join(newpath,
                             f'boxplot_solver_ishigami_{plot_name}.pdf'),
                bbox_inches='tight'
                )
            # Release the figure; clearing it alone keeps it alive in pyplot
            plt.close(fig)
    return


def main():
    """Merge the error histories of all solvers and draw the box plots."""
    # Set variables
    model_name = 'borehole'
    solvers = ['BCS', 'FastARD', 'OMP', 'OLS']
    # Directory holding one sub-folder per solver; it can be overridden by
    # the first command line argument.
    default_path = '/home/farid/bwSyncShare/Scientific_LH2/Promotion/dissertation/surrogate/data-borehole/'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    f_size = 45

    all_loo_errors = {}
    all_valid_errors = {}
    for solver in solvers:
        # reading the data from the file
        # NOTE: joblib.load unpickles arbitrary objects -- only load files
        # from a trusted source.
        pkl_path = os.path.join(path, solver, f'PCEModel_{model_name}.pkl')
        with open(pkl_path, "rb") as pkl_file:
            meta_model = joblib.load(pkl_file)

        # Update name and Concatenate
        all_valid_errors.update({key.replace('ALM', solver): value
                                 for key, value in
                                 meta_model.seqValidError.items()
                                 })
        all_loo_errors.update({key.replace('ALM', solver): value
                               for key, value in
                               meta_model.SeqModifiedLOO.items()
                               })

    # The surrogate loaded last carries the merged dictionaries; its
    # remaining attributes (e.g. ExpDesign) are used as they are.
    meta_model.seqValidError = all_valid_errors
    meta_model.SeqModifiedLOO = all_loo_errors

    # Plot box plot
    plot_seq_design_diagnostics(meta_model, solvers, out_dir=path,
                                model_name=model_name, f_size=f_size)


if __name__ == "__main__":
    main()