#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This test deals with the surrogate modeling of the Ishigami function.

You will see how to:
    Check the quality of your regression model
    Perform sensitivity analysis via Sobol Indices

Author: Farid Mohammadi, M.Sc.
E-Mail: farid.mohammadi@iws.uni-stuttgart.de
Department of Hydromechanics and Modelling of Hydrosystems (LH2)
Institute for Modelling Hydraulic and Environmental Systems (IWS), University
of Stuttgart, www.iws.uni-stuttgart.de/lh2/
Pfaffenwaldring 61
70569 Stuttgart

"""

import numpy as np
import joblib

# import bayesvalidrox
# Add BayesValidRox path
import sys
sys.path.append("../../src/bayesvalidrox/")

from bayesvalidrox import (
    PyLinkForwardModel, Input, ExpDesigns, PCE, PostProcessing,
    BayesInference, Discrepancy, Engine
)

import matplotlib
matplotlib.use('agg')

if __name__ == "__main__":

    # =====================================================
    # =============   COMPUTATIONAL MODEL  ================
    # =====================================================
    Model = PyLinkForwardModel()

    # Define model options
    Model.link_type = 'Function'
    Model.py_file = 'Ishigami'
    Model.name = 'Ishigami'

    Model.Output.names = ['Z']
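
    # The wrapped model is the Ishigami function
    #     f(X) = sin(X1) + a*sin(X2)**2 + b*X3**4*sin(X1),
    # commonly parameterized with a=7 and b=0.1. As a purely illustrative
    # sketch (the exact signature and return format that PyLinkForwardModel
    # expects from 'Ishigami.py' may differ), the function could be
    # evaluated as:
    #
    #     def ishigami(xx, a=7, b=0.1):
    #         x1, x2, x3 = xx[:, 0], xx[:, 1], xx[:, 2]
    #         return np.sin(x1) + a*np.sin(x2)**2 + b*x3**4*np.sin(x1)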

    # =====================================================
    # =========   PROBABILISTIC INPUT MODEL  ==============
    # =====================================================
    Inputs = Input()

    Inputs.add_marginals()
    Inputs.Marginals[0].name = '$X_1$'
    Inputs.Marginals[0].dist_type = 'unif'
    Inputs.Marginals[0].parameters = [-np.pi, np.pi]

    Inputs.add_marginals()
    Inputs.Marginals[1].name = '$X_2$'
    Inputs.Marginals[1].dist_type = 'unif'
    Inputs.Marginals[1].parameters = [-np.pi, np.pi]

    Inputs.add_marginals()
    Inputs.Marginals[2].name = '$X_3$'
    Inputs.Marginals[2].dist_type = 'unif'
    Inputs.Marginals[2].parameters = [-np.pi, np.pi]

    # =====================================================
    # ======  POLYNOMIAL CHAOS EXPANSION METAMODELS  ======
    # =====================================================
    MetaModelOpts = PCE(Inputs, Model)

    # Select your metamodel method
    # 1) PCE (Polynomial Chaos Expansion) 2) aPCE (arbitrary PCE)
    # 3) GPE (Gaussian Process Emulator)
    MetaModelOpts.meta_model_type = 'aPCE'

    # ------------------------------------------------
    # ------------- PCE Specification ----------------
    # ------------------------------------------------
    # Select the sparse least-square minimization method for
    # the PCE coefficients calculation:
    # 1)OLS: Ordinary Least Square  2)BRR: Bayesian Ridge Regression
    # 3)LARS: Least angle regression  4)ARD: Bayesian ARD Regression
    # 5)FastARD: Fast Bayesian ARD Regression
    # 6)BCS: Bayesian Compressive Sensing
    # 7)OMP: Orthogonal Matching Pursuit
    # 8)VBL: Variational Bayesian Learning
    # 9)EBL: Empirical Bayesian Learning
    MetaModelOpts.pce_reg_method = 'BCS'

    # Specify the maximum degree to be compared by the adaptive algorithm:
    # the degree with the lowest Leave-One-Out cross-validation (LOO)
    # error (i.e. the highest score = 1 - LOO error) is chosen as the final
    # metamodel. pce_deg accepts the degree as a scalar or a range.
    MetaModelOpts.pce_deg = 14
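    # Alternatively, a range of degrees can be compared, e.g.:
    # MetaModelOpts.pce_deg = np.arange(1, 15)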

    # Hyperbolic truncation with the q-quasi-norm, 0 < q <= 1 (default=1):
    # only polynomial multi-indices whose q-norm does not exceed pce_deg
    # are kept in the basis.
    MetaModelOpts.pce_q_norm = 1.0

    # Print summary of the regression results
    # MetaModelOpts.verbose = True

    # ------------------------------------------------
    # ------ Experimental Design Configuration -------
    # ------------------------------------------------
    ExpDesign = ExpDesigns(Inputs)

    # One-shot (normal) or Sequential Adaptive (sequential) Design
    ExpDesign.method = 'normal'
    ExpDesign.n_init_samples = 50

    # Sampling methods
    # 1) random 2) latin_hypercube 3) sobol 4) halton 5) hammersley 6) korobov
    # 7) chebyshev(FT) 8) grid(FT) 9) nested_grid(FT) 10)user
    ExpDesign.sampling_method = 'latin_hypercube'

    # Provide the experimental design object with an HDF5 file
    # MetaModelOpts.ExpDesign.hdf5_file = 'ExpDesign_Ishigami.hdf5'

    # Sequential experimental design (needed only for sequential ExpDesign)
    ExpDesign.n_new_samples = 1
    ExpDesign.n_max_samples = 200  # 150
    ExpDesign.mod_LOO_threshold = 1e-16

    # ------------------------------------------------
    # ------- Sequential Design configuration --------
    # ------------------------------------------------
    # 1) None 2) 'equal' 3)'epsilon-decreasing' 4) 'adaptive'
    ExpDesign.tradeoff_scheme = None
    # MetaModelOpts.ExpDesign.n_replication = 50
    # -------- Exploration ------
    # 1)'Voronoi' 2)'random' 3)'latin_hypercube' 4)'dual annealing'
    ExpDesign.explore_method = 'latin_hypercube'

    # Use when 'dual annealing' chosen
    ExpDesign.max_func_itr = 200

    # Use when 'Voronoi' or 'random' or 'latin_hypercube' chosen
    ExpDesign.n_canddidate = 1000
    ExpDesign.n_cand_groups = 4

    # -------- Exploitation ------
    # 1)'BayesOptDesign' 2)'VarOptDesign' 3)'alphabetic' 4)'Space-filling'
    ExpDesign.exploit_method = 'Space-filling'

    # BayesOptDesign -> when data is available
    # 1)DKL (Kullback-Leibler Divergence) 2)DPP (D-Posterior-precision)
    # 3)APP (A-Posterior-precision)
    # MetaModelOpts.ExpDesign.util_func = 'DKL'

    # VarBasedOptDesign -> when data is not available
    # Only with Voronoi >>> 1)Entropy 2)EIGF, 3)ALM, 4)LOOCV
    ExpDesign.util_func = 'ALM'

    # alphabetic
    # 1)D-Opt (D-Optimality) 2)A-Opt (A-Optimality)
    # 3)K-Opt (K-Optimality)
    # MetaModelOpts.ExpDesign.util_func = 'D-Opt'

    ExpDesign.valid_samples = np.load("data/valid_samples.npy")
    ExpDesign.valid_model_runs = {'Z': np.load("data/valid_outputs.npy")}
    # >>>>>>>>>>>>>>>>>>>>>> Build Surrogate <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
    MetaModelOpts.ExpDesign = ExpDesign
    engine = Engine(MetaModelOpts, Model, ExpDesign)
    engine.start_engine()
    # engine.train_sequential()
    engine.train_normal()

    # Save PCE models
    with open(f'PCEModel_{Model.name}.pkl', 'wb') as output:
        joblib.dump(engine.MetaModel, output, 2)
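
    # The pickled surrogate can be reloaded later, e.g. with
    # loaded_metamodel = joblib.load(f'PCEModel_{Model.name}.pkl')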

    # =====================================================
    # =========  POST PROCESSING OF METAMODELS  ===========
    # =====================================================
    PostPCE = PostProcessing(engine)

    # Plot to check validation visually.
    PostPCE.valid_metamodel(n_samples=200)

    # Check the quality of your regression model
    PostPCE.check_reg_quality()

    # PostPCE.eval_PCEmodel_3D()
    # Compute and print RMSE error
    PostPCE.check_accuracy(n_samples=3000)

    # Plot the evolution of the KLD, BME, and Modified LOOCV error
    if ExpDesign.method == 'sequential':
        PostPCE.plot_seq_design_diagnostics()

    # Plot the Sobol indices
    total_sobol = PostPCE.sobol_indices(plot_type='bar')
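
    # For reference (assuming 'Ishigami.py' uses the common a=7, b=0.1
    # parameterization), the analytical Sobol indices of the Ishigami function
    # are approximately S1=0.314, S2=0.442, S3=0 (first order) and
    # ST1=0.558, ST2=0.442, ST3=0.244 (total), which the PCE-based estimates
    # in total_sobol should reproduce closely.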