diff --git a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
index 693afc82dd95e9a44072e504b45d09c794cb3097..4d829f61a118b7544298e469272fc05f5d0f0b7f 100644
--- a/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
+++ b/src/bayesvalidrox/bayes_inference/bayes_model_comparison.py
@@ -20,6 +20,27 @@ plt.style.use(os.path.join(os.path.split(__file__)[0],
 
 
 class BayesModelComparison:
+    """
+    A class to perform Bayesian model comparison.
+
+
+    Attributes
+    ----------
+    justifiability : bool, optional
+        Whether to perform the justifiability analysis. The default is
+        `True`.
+    perturbed_data : array of shape (n_bootstrap_itrs, n_obs), optional
+        User-defined perturbed data. The default is `None`.
+    n_bootstarp : int
+        Number of bootstrap iterations. The default is `1000`.
+    data_noise_level : float
+        A noise level to perturb the data set. The default is `0.01`.
+    just_n_meas : int
+        Number of measurements considered for visualization of the
+        justifiability results. The default is `2`.
+
+    """
+
     def __init__(self, justifiability=True, perturbed_data=None,
                  n_bootstarp=1000, data_noise_level=0.01, just_n_meas=2):
 
@@ -31,6 +52,35 @@ class BayesModelComparison:
 
     # --------------------------------------------------------------------------
     def create_model_comparison(self, model_dict, opts_dict):
+        """
+        Starts the two-stage model comparison.
+        Stage I: Compare models using Bayes factors.
+        Stage II: Compare models via justifiability analysis.
+
+        Parameters
+        ----------
+        model_dict : dict
+            A dictionary including the metamodels.
+        opts_dict : dict
+            A dictionary containing the `BayesInference` options.
+
+        Example:
+
+        >>> opts_bootstrap = {
+        ...     "bootstrap": True,
+        ...     "n_samples": 10000,
+        ...     "Discrepancy": DiscrepancyOpts,
+        ...     "emulator": True,
+        ...     "plot_post_pred": True
+        ...     }
+
+        Returns
+        -------
+        output : dict
+            A dictionary containing the objects and the model weights for the
+            comparison using Bayes factors and justifiability analysis.
+
+        """
 
         # Bayes factor
         bayes_dict_bf, model_weights_dict_bf = self.compare_models(
@@ -54,17 +104,42 @@ class BayesModelComparison:
         return output
 
     # --------------------------------------------------------------------------
-    def compare_models(self, modelDict, optsDict, justifiability=False):
+    def compare_models(self, model_dict, opts_dict, justifiability=False):
+        """
+        Instantiates the `BayesInference` class for each model and passes
+        the options from `opts_dict` to it. Then, it starts the
+        computations.
+        It also creates a folder and saves the diagrams, e.g., Bayes factor
+        plot, confusion matrix, etc.
+
+        Parameters
+        ----------
+        model_dict : dict
+            A dictionary including the metamodels.
+        opts_dict : dict
+            A dictionary containing the `BayesInference` options.
+        justifiability : bool, optional
+            Whether to perform the justifiability analysis. The default is
+            `False`.
+
+        Returns
+        -------
+        bayes_dict : dict
+            A dictionary with `BayesInference` objects.
+        model_weights_dict : dict
+            A dictionary containing the model weights.
+
+        """
-        if not isinstance(modelDict, dict):
+        if not isinstance(model_dict, dict):
             raise Exception("To run model comparsion, you need to pass a "
                             "dictionary of models.")
 
         # Extract model names
-        self.model_names = [*modelDict]
+        self.model_names = [*model_dict]
 
         # Compute total number of the measurement points
-        MetaModel = list(modelDict.items())[0][1]
+        MetaModel = list(model_dict.items())[0][1]
         MetaModel.ModelObj.read_observation()
         self.n_meas = MetaModel.ModelObj.n_obs
 
@@ -77,18 +152,18 @@ class BayesModelComparison:
 
         # Create dataset
         justData = self.generate_dataset(
-            modelDict, justifiability, n_bootstarp=n_bootstarp)
+            model_dict, justifiability, n_bootstarp=n_bootstarp)
 
         # Run create Interface for each model
-        bayesDict = {}
-        for model in modelDict.keys():
+        bayes_dict = {}
+        for model in model_dict.keys():
             print("-"*20)
             print("Bayesian inference of {}.\n".format(model))
 
-            BayesOpts = BayesInference(modelDict[model])
+            BayesOpts = BayesInference(model_dict[model])
 
             # Set BayesInference options
-            for key, value in optsDict.items():
+            for key, value in opts_dict.items():
                 if key in BayesOpts.__dict__.keys():
                     if key == "Discrepancy" and isinstance(value, dict):
                         setattr(BayesOpts, key, value[model])
@@ -99,16 +174,16 @@ class BayesModelComparison:
                 BayesOpts.perturbed_data = justData
                 BayesOpts.just_analysis = justifiability
 
-            bayesDict[model] = BayesOpts.create_inference()
+            bayes_dict[model] = BayesOpts.create_inference()
             print("-"*20)
 
         # Compute model weights
         BME_Dict = dict()
-        for modelName, bayesObj in bayesDict.items():
+        for modelName, bayesObj in bayes_dict.items():
             BME_Dict[modelName] = np.exp(bayesObj.log_BME, dtype=np.float128)
 
         # BME correction in BayesInference class
-        model_weights = self.cal_modelWeight(
+        model_weights = self.cal_model_weight(
             BME_Dict, justifiability, n_bootstarp=n_bootstarp)
 
         # Plot model weights
@@ -122,7 +197,7 @@ class BayesModelComparison:
             model_weights_dict = {key: weights for key, weights in
                                   zip(model_names, list_ModelWeights)}
 
-            self.plot_JustAnalysis(model_weights_dict)
+            self.plot_just_analysis(model_weights_dict)
         else:
             # Create box plot for model weights
             self.plot_model_weights(model_weights, 'model_weights')
@@ -134,15 +209,34 @@ class BayesModelComparison:
             model_weights_dict = {key: weights for key, weights in
                                   zip(self.model_names, model_weights)}
 
-        return bayesDict, model_weights_dict
+        return bayes_dict, model_weights_dict
 
     # -------------------------------------------------------------------------
-    def generate_dataset(self, modelDict, justifiability=False,
+    def generate_dataset(self, model_dict, justifiability=False,
                          n_bootstarp=1):
+        """
+        Generates the perturbed data set for the Bayes factor calculations and
+        the data set for the justifiability analysis.
+
+        Parameters
+        ----------
+        model_dict : dict
+            A dictionary including the metamodels.
+        justifiability : bool, optional
+            Whether to perform the justifiability analysis. The default is
+            `False`.
+        n_bootstarp : int, optional
+            Number of bootstrap iterations. The default is `1`.
+
+        Returns
+        -------
+        all_just_data : array
+            Created data set.
+        """
 
         # Compute some variables
         all_just_data = []
-        metaModel = list(modelDict.items())[0][1]
+        metaModel = list(model_dict.items())[0][1]
         out_names = metaModel.ModelObj.Output.names
 
         # Perturb observations for Bayes Factor
@@ -157,7 +251,7 @@ class BayesModelComparison:
 
         # Evaluate metamodel
         runs = {}
-        for key, metaModel in modelDict.items():
+        for key, metaModel in model_dict.items():
             y_hat, _ = metaModel.eval_metamodel(nsamples=n_bootstarp)
             runs[key] = y_hat
 
@@ -166,7 +260,7 @@ class BayesModelComparison:
             y_data = self.perturbed_data[i].reshape(1, -1)
             justData = np.tril(np.repeat(y_data, y_data.shape[1], axis=0))
             # Use surrogate runs for data-generating process
-            for key, metaModel in modelDict.items():
+            for key, metaModel in model_dict.items():
                 model_data = np.array(
                     [runs[key][out][i] for out in out_names]).reshape(y_data.shape)
                 justData = np.vstack((
@@ -228,19 +322,19 @@ class BayesModelComparison:
         return final_data
 
     # -------------------------------------------------------------------------
-    def cal_modelWeight(self, BME_Dict, justifiability=False, n_bootstarp=1):
+    def cal_model_weight(self, BME_Dict, justifiability=False, n_bootstarp=1):
         """
         Normalize the BME (Asumption: Model Prior weights are equal for models)
 
         Parameters
        ----------
-        BME_Dict : TYPE
-            DESCRIPTION.
+        BME_Dict : dict
+            A dictionary containing the BME values.
 
         Returns
         -------
-        model_weights : TYPE
-            DESCRIPTION.
+        model_weights : array
+            Model weights.
 
         """
         # Stack the BME values for all models
@@ -257,19 +351,33 @@ class BayesModelComparison:
         return model_weights
 
     # -------------------------------------------------------------------------
-    def plot_JustAnalysis(self, ModelWeights_dict):
+    def plot_just_analysis(self, model_weights_dict):
+        """
+        Visualizes the confusion matrix and the model weights for the
+        justifiability analysis.
+
+        Parameters
+        ----------
+        model_weights_dict : dict
+            Model weights.
+
+        Returns
+        -------
+        None.
+
+        """
 
         directory = 'Outputs_Comparison/'
         os.makedirs(directory, exist_ok=True)
         Color = [*mcolors.TABLEAU_COLORS]
-        names = [*ModelWeights_dict]
+        names = [*model_weights_dict]
 
         model_names = [model.replace('_', '$-$') for model in self.model_names]
         for name in names:
             fig, ax = plt.subplots()
             for i, model in enumerate(model_names[1:]):
                 plt.plot(list(range(1, self.n_meas+1)),
-                         ModelWeights_dict[name][i],
+                         model_weights_dict[name][i],
                          color=Color[i], marker='o',
                          ms=10, linewidth=2, label=model
                          )
@@ -288,7 +396,7 @@ class BayesModelComparison:
         epsilon = 1 if self.just_n_meas != 1 else 0
         for index in range(0, self.n_meas+epsilon, self.just_n_meas):
             weights = np.array(
-                [ModelWeights_dict[key][:, index] for key in ModelWeights_dict]
+                [model_weights_dict[key][:, index] for key in model_weights_dict]
                 )
             g = sns.heatmap(
                 weights.T, annot=True, cmap='Blues', xticklabels=model_names,
@@ -307,7 +415,23 @@ class BayesModelComparison:
             plt.close()
 
     # -------------------------------------------------------------------------
-    def plot_model_weights(self, modelWeights, plotName):
+    def plot_model_weights(self, model_weights, plot_name):
+        """
+        Visualizes the model weights resulting from BMS using the observation
+        data.
+
+        Parameters
+        ----------
+        model_weights : array
+            Model weights.
+        plot_name : str
+            Plot name.
+
+        Returns
+        -------
+        None.
+
+        """
         font_size = 40
         # mkdir for plots
         directory = 'Outputs_Comparison/'
         os.makedirs(directory, exist_ok=True)
@@ -317,8 +441,8 @@ class BayesModelComparison:
         fig, ax = plt.subplots()
 
         # Filter data using np.isnan
-        mask = ~np.isnan(modelWeights.T)
-        filtered_data = [d[m] for d, m in zip(modelWeights, mask.T)]
+        mask = ~np.isnan(model_weights.T)
+        filtered_data = [d[m] for d, m in zip(model_weights, mask.T)]
 
         # Create the boxplot
         bp = ax.boxplot(filtered_data, patch_artist=True, showfliers=False)
@@ -366,13 +490,29 @@ class BayesModelComparison:
 
         # Save the figure
         fig.savefig(
-            f'./{directory}ModelWeights{plotName}.pdf', bbox_inches='tight'
+            f'./{directory}{plot_name}.pdf', bbox_inches='tight'
             )
         plt.close()
 
     # -------------------------------------------------------------------------
     def plot_bayes_factor(self, BME_Dict, plot_name=''):
+        """
+        Plots the Bayes factor distributions in a :math:`N_m \\times N_m`
+        matrix, where :math:`N_m` is the number of models.
+
+        Parameters
+        ----------
+        BME_Dict : dict
+            A dictionary containing the BME values of the models.
+        plot_name : str, optional
+            Plot name. The default is ''.
+
+        Returns
+        -------
+        None.
+
+        """
         font_size = 40
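For reviewers, a minimal usage sketch of the refactored API under the new snake_case names. This is an assumption-laden illustration, not part of the diff: `metamodel_a`, `metamodel_b`, `discrepancy_a`, and `discrepancy_b` are hypothetical placeholders for trained bayesvalidrox metamodels and their Discrepancy objects, assumed to be built earlier in the usual workflow; only `BayesModelComparison` comes from this module.

    from bayesvalidrox.bayes_inference.bayes_model_comparison import (
        BayesModelComparison,
    )

    # Metamodels to compare, keyed by model name (placeholders, see above)
    model_dict = {"model_a": metamodel_a, "model_b": metamodel_b}

    # Options forwarded to each BayesInference instance; per compare_models,
    # "Discrepancy" may be a dict keyed by model name
    opts_dict = {
        "bootstrap": True,
        "n_samples": 10000,
        "Discrepancy": {"model_a": discrepancy_a, "model_b": discrepancy_b},
        "emulator": True,
        "plot_post_pred": True,
    }

    # `n_bootstarp` keeps the spelling of the current keyword argument
    bmc = BayesModelComparison(justifiability=True, n_bootstarp=1000)

    # Runs both stages: Bayes factors, then justifiability analysis.
    # Internally, cal_model_weight normalizes the bootstrapped BME values,
    # i.e. w_i = BME_i / sum_j BME_j, assuming equal model prior weights.
    output = bmc.create_model_comparison(model_dict, opts_dict)
    print(output.keys())  # BayesInference objects and model weights per stage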