# -*- coding: utf-8 -*-
"""
Test the PostProcessing class in bayesvalidrox.
Tests are available (or planned, where a section below is still empty) for
the following methods of class PostProcessing:
    init
    plot_moments
    valid_metamodel
    check_accuracy
    plot_seq_design
    sobol_indices
    check_reg_quality
    eval_pce_model_3d
    _get_sample
    _eval_model
    _plot_validation
    _plot_validation_multi
"""
import sys
sys.path.append("../src/")
import numpy as np
import pytest

from bayesvalidrox.post_processing.post_processing import PostProcessing
from bayesvalidrox.surrogate_models.inputs import Input
from bayesvalidrox.surrogate_models.input_space import InputSpace
from bayesvalidrox.surrogate_models.exp_designs import ExpDesigns
from bayesvalidrox.surrogate_models.surrogate_models import MetaModel
from bayesvalidrox.surrogate_models.polynomial_chaos import PCE
from bayesvalidrox.pylink.pylink import PyLinkForwardModel as PL
from bayesvalidrox.surrogate_models.engine import Engine


@pytest.fixture
def basic_engine():
    """
    Engine wrapping a plain (non-PCE) MetaModel over a single standard
    normal input, with one output named 'Z'.
    """
    inputs = Input()
    inputs.add_marginals()
    marginal = inputs.Marginals[0]
    marginal.dist_type = 'normal'
    marginal.parameters = [0, 1]

    # Arguments are evaluated left to right: MetaModel, PL, ExpDesigns.
    eng = Engine(MetaModel(inputs), PL(), ExpDesigns(inputs))
    eng.out_names = ['Z']
    eng.emulator = True
    return eng

@pytest.fixture
def pce_engine():
    """
    Engine wrapping a degree-1 PCE fitted to three training points of a
    single standard normal input, with one output named 'Z'.
    """
    inputs = Input()
    inputs.add_marginals()
    marginal = inputs.Marginals[0]
    marginal.dist_type = 'normal'
    marginal.parameters = [0, 1]

    design = ExpDesigns(inputs)
    design.init_param_space(max_deg=1)
    # The training points lie exactly on Z = 0.4 + 0.1 * x.
    design.X = np.array([[0], [1], [0.5]])
    design.Y = {'Z': [[0.4], [0.5], [0.45]]}
    design.x_values = [0]

    surrogate = PCE(inputs)
    surrogate.fit(design.X, design.Y)
    eng = Engine(surrogate, PL(), design)
    eng.out_names = ['Z']
    eng.emulator = True
    return eng

#%% Test PostProcessing init

def test_postprocessing_noengine():
    """
    Construct PostProcessing without an engine.

    Placeholder: the expected behaviour for a missing engine is not yet
    pinned down here. The test is reported as skipped instead of passing
    silently with an empty body, which would misrepresent coverage.
    """
    pytest.skip('PostProcessing without an engine is not tested yet')

def test_postprocessing(basic_engine) -> None:
    """
    PostProcessing can be constructed from a basic (non-PCE) engine.
    """
    PostProcessing(basic_engine)
    
    
#%% plot_moments

def test_plot_moments_pce(pce_engine) -> None:
    """
    Plot moments for PCE metamodel.

    The pce_engine training data lie on Z = 0.4 + 0.1 * x with a standard
    normal input, so the mean should be about 0.4 and the standard
    deviation about 0.1.
    """
    post = PostProcessing(pce_engine)
    mean, stdev = post.plot_moments()
    # Same checks for both dicts, mean first and then stdev.
    for moments, expected in ((mean, 0.4), (stdev, 0.1)):
        assert list(moments.keys()) == ['Z']
        assert moments['Z'].shape == (1,)
        assert moments['Z'][0] == pytest.approx(expected, abs=0.01)
    
def test_plot_moments_pcebar(pce_engine) -> None:
    """
    Plot moments for PCE metamodel as a bar plot.

    Expected values match the line plot: mean about 0.4 and standard
    deviation about 0.1 for the pce_engine training data.
    """
    post = PostProcessing(pce_engine)
    mean, stdev = post.plot_moments(plot_type='bar')
    # Same checks for both dicts, mean first and then stdev.
    for moments, expected in ((mean, 0.4), (stdev, 0.1)):
        assert list(moments.keys()) == ['Z']
        assert moments['Z'].shape == (1,)
        assert moments['Z'][0] == pytest.approx(expected, abs=0.01)
    

#%% valid_metamodel

#%% check_accuracy

def test_check_accuracy_pce(pce_engine) -> None:
    """
    Check accuracy for PCE metamodel against its own training data.

    Smoke test: it passes as long as check_accuracy runs without raising.
    """
    post = PostProcessing(pce_engine)
    design = pce_engine.ExpDesign
    post.check_accuracy(samples=design.X, outputs=design.Y)


#%% plot_seq_design
#%% sobol_indices

def test_sobol_indices_nopce(basic_engine) -> None:
    """
    Sobol indices are rejected with an AttributeError for a non-PCE
    metamodel.
    """
    expected_msg = 'Sobol indices only support PCE-type models!'
    post = PostProcessing(basic_engine)
    with pytest.raises(AttributeError) as err_info:
        post.sobol_indices()
    assert str(err_info.value) == expected_msg
    
def test_sobol_indices_pce(pce_engine) -> None:
    """
    Calculate sobol indices for PCE metamodel.

    The pce_engine has a single input parameter, so the whole output
    variance is attributed to it and its Sobol index should be 1.
    """
    engine = pce_engine
    post = PostProcessing(engine)
    sobol = post.sobol_indices()
    assert list(sobol.keys()) == ['Z']
    assert sobol['Z'].shape == (1, 1)
    # The index is a ratio of floating-point variances, so compare with a
    # tolerance rather than exact equality.
    assert sobol['Z'][0, 0] == pytest.approx(1.0)

#%% check_reg_quality
#%% eval_pce_model_3d

def test_eval_pce_model_3d_nopce(basic_engine) -> None:
    """
    The 3d evaluation is rejected with an AttributeError for a non-PCE
    metamodel.
    """
    expected_msg = 'This evaluation only support PCE-type models!'
    post = PostProcessing(basic_engine)
    with pytest.raises(AttributeError) as err_info:
        post.eval_pce_model_3d()
    assert str(err_info.value) == expected_msg
    

#%% _get_sample
#%% _eval_model
#%% _plot_validation

def test_plot_validation_nopce(basic_engine) -> None:
    """
    Validation plotting is rejected with an AttributeError for a non-PCE
    metamodel.
    """
    expected_msg = 'This evaluation only support PCE-type models!'
    post = PostProcessing(basic_engine)
    with pytest.raises(AttributeError) as err_info:
        post._plot_validation()
    assert str(err_info.value) == expected_msg
    
#%% _plot_validation_multi
    
def test_plot_validation_multi_nopce(basic_engine) -> None:
    """
    Multi-output validation plotting is rejected with an AttributeError for
    a non-PCE metamodel.
    """
    expected_msg = 'This evaluation only support PCE-type models!'
    post = PostProcessing(basic_engine)
    with pytest.raises(AttributeError) as err_info:
        post._plot_validation_multi()
    assert str(err_info.value) == expected_msg