import pytest
import numpy as np
from psimpy.inference.active_learning import ActiveLearning
from psimpy.simulator.run_simulator import RunSimulator
from psimpy.simulator.mass_point_model import MassPointModel
from psimpy.sampler.latin import LHS
from psimpy.sampler.saltelli import Saltelli
from psimpy.emulator.robustgasp import ScalarGaSP, PPGaSP
from scipy.stats import uniform, norm
from scipy import optimize
from beartype.roar import BeartypeCallHintParamViolation

@pytest.mark.parametrize(
    "run_sim_obj, prior, likelihood, lhs_sampler, scalar_gasp, optimizer",
    [
        # Explicit ids name the argument that differs from the otherwise
        # valid baseline, so a failing case is identifiable at a glance
        # instead of showing up as "run_sim_obj0-prior0-...".
        pytest.param(
            MassPointModel(), uniform.pdf, norm.pdf, LHS(1), ScalarGaSP(1),
            optimize.brute,
            id="run_sim_obj-is-MassPointModel",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            None, norm.pdf, LHS(1), ScalarGaSP(1), optimize.brute,
            id="prior-is-None",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            uniform.pdf, None, LHS(1), ScalarGaSP(1), optimize.brute,
            id="likelihood-is-None",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            uniform.pdf, norm.pdf, Saltelli(1), ScalarGaSP(1), optimize.brute,
            id="lhs_sampler-is-Saltelli",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            uniform.pdf, norm.pdf, LHS(1), PPGaSP(1), optimize.brute,
            id="scalar_gasp-is-PPGaSP",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            uniform.pdf, norm.pdf, LHS(1), ScalarGaSP(1), None,
            id="optimizer-is-None",
        ),
    ],
)
def test_ActiveLearning_init_TypeError(run_sim_obj, prior, likelihood,
                                       lhs_sampler, scalar_gasp, optimizer):
    """ActiveLearning rejects arguments of the wrong type.

    Each case swaps exactly one argument of an otherwise valid call for an
    object of a different type (or ``None``); beartype's call-hint check on
    the constructor is expected to raise
    ``BeartypeCallHintParamViolation``.
    """
    ndim = 1
    bounds = np.array([[0, 1]])
    data = np.array([1, 2, 3])
    with pytest.raises(BeartypeCallHintParamViolation):
        ActiveLearning(ndim, bounds, data, run_sim_obj, prior, likelihood,
                       lhs_sampler, scalar_gasp, optimizer=optimizer)


@pytest.mark.parametrize(
    "run_sim_obj, lhs_sampler, scalar_gasp",
    [
        # Each case makes one component disagree with ndim = 1; the id says
        # which one so failures are self-describing.
        pytest.param(
            RunSimulator(MassPointModel.run,
                         ['coulomb_friction', 'turbulent_friction']),
            LHS(1), ScalarGaSP(1),
            id="run_sim_obj-two-varying-params",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            LHS(2), ScalarGaSP(1),
            id="lhs_sampler-LHS(2)",
        ),
        pytest.param(
            RunSimulator(MassPointModel.run, ['coulomb_friction']),
            LHS(1), ScalarGaSP(3),
            id="scalar_gasp-ScalarGaSP(3)",
        ),
    ],
)
def test_ActiveLearning_init_RuntimeError(run_sim_obj, lhs_sampler,
                                          scalar_gasp):
    """ActiveLearning raises RuntimeError on inconsistent dimensions.

    The call uses ``ndim = 1``. In each case exactly one of the simulator
    (two varying parameters), the sampler (``LHS(2)``) or the emulator
    (``ScalarGaSP(3)``) is configured differently. The integer given to
    ``LHS``/``ScalarGaSP`` is presumably their input dimension — confirm
    against those classes.
    """
    ndim = 1
    bounds = np.array([[0, 1]])
    data = np.array([1, 2, 3])
    prior = uniform.pdf
    likelihood = norm.pdf
    with pytest.raises(RuntimeError):
        ActiveLearning(ndim, bounds, data, run_sim_obj, prior, likelihood,
                       lhs_sampler, scalar_gasp)


@pytest.mark.parametrize(
    "scalar_gasp_mean, indicator",
    [
        ('cubic', 'entropy'),
        ('linear', 'divergence'),
    ],
)
def test_ActiveLearning_init_NotImplementedError(scalar_gasp_mean, indicator):
    """Unsupported option values make ActiveLearning raise NotImplementedError.

    One case passes ``scalar_gasp_mean='cubic'``. The other passes
    ``indicator='divergence'``. Every other argument is a valid
    one-dimensional setup.
    """
    simulator = RunSimulator(MassPointModel.run, ['coulomb_friction'])
    sampler = LHS(1)
    emulator = ScalarGaSP(1)
    options = {
        "scalar_gasp_mean": scalar_gasp_mean,
        "indicator": indicator,
    }
    # All collaborators are built before entering the raises-context, so
    # only the ActiveLearning constructor can satisfy the expectation.
    with pytest.raises(NotImplementedError):
        ActiveLearning(
            1,
            np.array([[0, 1]]),
            np.array([1, 2, 3]),
            simulator,
            uniform.pdf,
            norm.pdf,
            sampler,
            emulator,
            **options,
        )

def test_ActiveLearning_init_ValueError():
    """ActiveLearning raises ValueError for bad optimizer keyword arguments.

    No ``optimizer`` is passed, so the constructor's default is used.
    ``kwgs_optimizer`` is ``{"NS": 50}``. ``scipy.optimize.brute`` spells
    that option ``Ns``, so presumably the unknown key is what triggers the
    error. Confirm against ActiveLearning's validation of
    ``kwgs_optimizer`` and its default optimizer.
    """
    n_dim = 1
    param_bounds = np.array([[0, 1]])
    observed = np.array([1, 2, 3])
    simulator = RunSimulator(MassPointModel.run, ['coulomb_friction'])
    sampler = LHS(1)
    emulator = ScalarGaSP(1)
    bad_optimizer_kwargs = dict(NS=50)
    with pytest.raises(ValueError):
        ActiveLearning(
            n_dim, param_bounds, observed, simulator, uniform.pdf, norm.pdf,
            sampler, emulator, kwgs_optimizer=bad_optimizer_kwargs,
        )