import numpy as np
from sklearn import linear_model
from scipy.spatial import ConvexHull
from itertools import combinations
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib import cm
import warnings
import copy
from modeling.Hinging_Hyperplane import hinge_finding_algorithm as hfa


def piecewiseLinearModel_bt(data, numOfHinges, loops, hull_acc, *existing_model):
    """
    Function as an interface to the AutoMoG algorithm.

    Since the hinge finding algorithm can end in a local minimum, the resulting models vary. For
    numOfHinges > 0, ``loops`` models are created and the one with the smallest total squared error
    (``Vn_total``) is selected.

    Parameters
    ----------
    data:           dict with input data ``data['X']`` and output data ``data['Y']``; the HHmodel
                    instances are always built from deep copies of it
    numOfHinges:    number of linear elements - 1 (0 performs a single linear regression)
    loops:          number of models to be created (the best is chosen); unused for numOfHinges == 0
    hull_acc:       accuracy passed to ``CnvHull`` whenever a new convex hull is computed
    existing_model: optional HHmodel (first extra positional argument) that is refined to
                    numOfHinges hinges instead of starting from scratch; its hull is reused if the
                    refinement succeeds

    Returns
    -------
    model:          the selected HHmodel with constraints, convex hull and MILP corner points

    Raises
    ------
    ValueError:     if numOfHinges < 0, or if loops < 1 while numOfHinges > 0
    """
    print(numOfHinges + 1, 'sections')

    if numOfHinges < 0:
        raise ValueError(f"numOfHinges must be >= 0, got {numOfHinges}")

    if numOfHinges == 0:
        # for linear regression only one loop is necessary
        new_model = True
        # work on a copy (as for the hinged models): HHmodel normalizes the dict it is given in place
        model = HHmodel(copy.deepcopy(data))
        model.hingingHyperplaneModel(numOfHinges)

        model.find_constraints(data)
    else:
        if loops < 1:
            raise ValueError(f"loops must be >= 1 to select a model, got {loops}")

        base_model = existing_model[0] if existing_model else None
        # (model, is_new) for every candidate
        fits = [_fit_hinged_model(data, numOfHinges, base_model, i) for i in range(loops)]
        err = np.array([fitted.Vn_total for fitted, _ in fits], dtype=float)

        # keep the model with the smallest total squared error
        model, new_model = fits[int(np.argmin(err))]

        model.find_constraints(data)

    if new_model:
        # a refined model keeps its hull; every other model needs a new one
        model.hull = CnvHull(data, hull_acc)

    model.calculate_milp(data, model.theta)

    return model


def _fit_hinged_model(data, numOfHinges, base_model, index):
    """
    Create one candidate model with numOfHinges > 0, retrying from scratch until fitting succeeds.

    If ``base_model`` is given, a deep copy of it is refined first; only if that fails a new model is
    fitted from ``data``. Returns ``(model, is_new)``, where ``is_new`` is False only when the
    refinement of ``base_model`` succeeded.
    """
    if base_model is not None:
        model = copy.deepcopy(base_model)
        if model.refine_hingingHyperplaneModel(numOfHinges):
            return model, False
        model = HHmodel(copy.deepcopy(data))
        print(f"redo model {index}")
    else:
        model = HHmodel(copy.deepcopy(data))

    # NOTE: like the original implementation this retries without an upper bound
    while not model.hingingHyperplaneModel(numOfHinges):
        model = HHmodel(copy.deepcopy(data))
        print(f"redo model {index}")
    return model, True


class HHmodel(object):
    """
    Hinging hyperplane model: a piecewise linear fit of ``data['Y']`` over ``data['X']``.

    The model is grown greedily with the hinge finding algorithm (``hfa``): the data set is first
    split into two linear sections, then the section with the largest squared error (``Vn``) is
    split again until ``numOfHinges + 1`` sections exist. Fitting works on normalized data;
    ``process_results`` scales data, coefficients and hinges back and keeps the normalized versions
    (``*_norm``) so that ``refine_hingingHyperplaneModel`` can continue from them.
    """

    def __init__(self, data = None, normalize=True,
                 milp = None, theta = None, from_dict = False):
        """
        Parameters
        ----------
        data:       dict with input data ``data['X']`` (2-D, one sample per row) and output data
                    ``data['Y']``. NOTE: with ``normalize=True`` this dict is normalized in place by
                    ``normalizeData``. When ``from_dict`` is True it is stored as-is in ``self.data``.
        normalize:  scale every column of X and Y by its maximum before fitting
        milp:       serialized MILP (dict), only used when ``from_dict`` is True
        theta:      coefficients of the linear sections, only used when ``from_dict`` is True
        from_dict:  True when rebuilding a model via ``HHmodel.from_dict``
        """
        # initialize via data
        if not from_dict:
            self.theta = []  # coefficients of each linear section
            self.theta_norm = None  # coefficients for normalized data
            self.delta = []  # hinges
            self.delta_norm = None  # hinges for normalized data
            self.data = []  # data for each linear section
            self.data_norm = None
            self.sign = []  # hinge sign(s) of each section (> 0: section lies on the data_plus side)
            self.numOfHinges = None
            self.Vn = []  # squared error for each linear section
            self.Vn_total = None  # squared error for whole model
            self.Vn_norm = []
            self.sum_err = []  # sum of absolute errors for each linear section
            self.sum_err_norm = []
            self.mse = None  # mean squared error
            self.noHinge_mat = {}  # section index -> data of sections that could not be further split
            self.constraints = []  # per section: hinges bounding it (see find_constraints)
            self.con_sign = []  # per section: signs of those hinges
            #
            self.data_points = copy.deepcopy(data)
            self.numData = len(data['Y'])
            self.normalized = normalize
            # normalization for numerical reasons
            if normalize:
                self.normData, self.normVar = normalizeData(data)
            else:
                self.normData = copy.deepcopy(data)
                self.normVar = {'X': np.ones(len(data['X'][0, :])), 'Y': 1}

            self.milp = None
            self.hull = None

        # initialize via HHmodel.from_dict
        else:
            self.milp  = MILP.from_dict(milp)
            self.data  = data
            self.theta = theta

    @classmethod
    def from_dict(cls, init_dict: dict):
        """Rebuild a model from the output of ``to_dict``."""
        return cls(**init_dict)

    def to_dict(self):
        """Serialize the parts needed to evaluate the model (MILP, data, theta)."""
        return {'milp': self.milp.to_dict(),
                'data': self.data,
                'theta': self.theta,
                'from_dict': True}

    def hingingHyperplaneModel(self, numOfHinges):
        """
        Fit the model from scratch using a hinging hyperplane tree.

        Parameters
        ----------
        numOfHinges:     number of hinges; the data space is divided into numOfHinges + 1 sections

        Returns
        -------
        successful:      True if the requested number of sections was created, False otherwise
        """
        self.numOfHinges = numOfHinges
        successful = False  # result for a negative numOfHinges (nothing is fitted)

        if self.numOfHinges == 0:
            # if no hinge is required, a simple linear regression is performed over the entire data set
            reg = linear_model.LinearRegression().fit(self.normData['X'], self.normData['Y'])

            coef = np.zeros(len(self.data_points['X'][0, :])+1)
            coef[0] = reg.intercept_

            for i in range(len(reg.coef_)):
                coef[i + 1] = reg.coef_[i]

            self.theta = copy.deepcopy(coef)

            successful = True

        if self.numOfHinges > 0:
            inDataset = False  # initialize the while loop
            count = 0  # iterations passed
            while not inDataset:  # iterations until a suitable first hinge is found
                count = count + 1
                if count >= 10:
                    print('No first split was found using the hinge finding Algorithm. Error in hingingHyperplaneModel')
                    return False

                # calculation of the first data split using the hinge finding algorithm
                output = hfa.hingeFindingAlgorithm(self.normData)

                inDataset = output['in_dataset']
                if inDataset:
                    self.theta.append(output['theta_plus'])
                    self.theta.append(output['theta_minus'])
                    self.data.append(output['data_plus'])
                    self.data.append(output['data_minus'])
                    self.sign.append(output['sign_plus'])
                    self.sign.append(output['sign_minus'])
                    self.delta.append(output['delta'])
                    self.Vn.append(output['Vn_plus'])
                    self.Vn.append(output['Vn_minus'])
                    self.sum_err.append(output['sum_err_plus'])
                    self.sum_err.append(output['sum_err_minus'])

            # if more than 1 hinge is required, the loop is started
            if self.numOfHinges >= 2:

                successful = self.create_hinges()

            else:
                print('specified number of hinges reached')
                successful = True

        self.process_results()
        return successful

    def refine_hingingHyperplaneModel(self, numOfHinges):
        """
        Refine an already fitted model by adding hinges until numOfHinges is reached.

        The normalized results stored by ``process_results`` are restored first, so splitting
        continues on normalized data.

        Parameters
        ----------
        numOfHinges:     number of hinges; the data space is divided into numOfHinges + 1 sections

        Returns
        -------
        successful:      True if the requested number of sections was created, False otherwise
        """
        self.delta = copy.deepcopy(self.delta_norm)
        self.theta = copy.deepcopy(self.theta_norm)
        self.data = copy.deepcopy(self.data_norm)
        self.Vn = copy.deepcopy(self.Vn_norm)
        # restore sum_err too; otherwise new (normalized) entries are mixed with denormalized ones
        self.sum_err = copy.deepcopy(self.sum_err_norm)

        self.numOfHinges = numOfHinges

        successful = self.create_hinges()

        self.process_results()
        return successful

    def create_hinges(self):
        """
        Split linear sections until the model consists of ``numOfHinges + 1`` sections.

        In every pass the section chosen by ``greedy_decision`` (largest Vn) is split with the hinge
        finding algorithm. Sections for which no hinge is found are recorded in ``noHinge_mat`` and
        the next candidate is chosen.

        Returns
        -------
        True if the requested number of sections is reached, False if splitting got stuck.
        """
        decision = []  # index of the section to split next ([] = not decided yet)
        iter1 = 0  # variable to count the passes through the while loop

        # in this loop the number of linear elements is increased by 1 in each successful pass
        while True:
            iter1 += 1  # counting the loop passes
            if iter1 == 1:
                # in the first run it must be determined which element will be split up by the HFA
                decision = self.greedy_decision()  # calculation of the element with the worst fit

            if isinstance(decision, str):  # greedy_decision returned 'break'
                print("no element could be split")
                return False

            # next data split
            inDataset = False  # to start the loop the value is set to False
            iter2 = 0  # count the passes of the inner loop
            while not inDataset:
                # if the hinge determined by the Hinge Finding Algorithm is not suitable, it is searched again
                iter2 += 1
                output = hfa.hingeFindingAlgorithm(self.data[decision])
                inDataset = output['in_dataset']
                if iter2 > 20:
                    # after more than 20 iterations it is unlikely that a suitable hinge for this data set can
                    # be found.
                    break

            # if no hinge was found, the data set is added to noHinge_mat
            self.noHingeMat_update(inDataset, decision, self.data[decision])

            if not inDataset:
                if iter1 > self.numOfHinges + 5:
                    print('could not find', self.numOfHinges, 'hinges. terminated with ', len(self.theta) - 1,
                          'hinges')
                    return False
                # the failed section is excluded now, so choose the next candidate instead of retrying it
                decision = self.greedy_decision()
                continue

            # a new data split is found: add it
            self.addNewSplit(decision, output)

            if len(self.theta) == self.numOfHinges + 1:
                # the specified number of hinges is reached; checked first so that a complete model is
                # not rejected just because none of its sections could be split any further
                print('specified number of hinges reached')
                return True

            decision = self.greedy_decision()  # calculate the element with the worst fit

            if len(self.noHinge_mat.keys()) >= len(self.data):
                # if all data sets can no longer be split up, the algorithm must be aborted
                print('maximum number of possible data splits is reached using the hingingHyperplaneModel')
                return False

    def process_results(self):
        """
        Store the normalized results and convert data, coefficients and hinges back to the original
        scale; then compute Vn/sum_err per section, Vn_total and mse.
        """
        # save normalized hinges (used by refine_hingingHyperplaneModel)
        self.delta_norm = copy.deepcopy(self.delta)
        self.data_norm = copy.deepcopy(self.data)
        self.theta_norm = copy.deepcopy(self.theta)

        # calculate Vn of the whole model
        if self.numOfHinges == 0:
            # denormalize
            self.data, self.theta, self.delta = denormalize(self.normData, self.normVar, self.theta, self.delta)

            self.Vn, self.sum_err = sosqerr(self.data, self.theta, self.normVar)

            self.Vn_total = self.Vn  # if there is no hinge these are the same
            self.mse = self.Vn_total/len(self.data_points['Y'])

        else:
            self.Vn_norm = copy.deepcopy(self.Vn)
            self.Vn = []
            self.sum_err_norm = copy.deepcopy(self.sum_err)
            self.sum_err = []

            # denormalize every section (and the hinges stored in delta)
            j = 0
            for i in range(len(self.data)):
                if not isinstance(self.delta[i], list):
                    if i > 1:
                        j += 1
                    if len(self.data) > 2:
                        if i != 1:
                            self.data[i], self.theta[i], self.delta[j] = denormalize(self.data[i], self.normVar,
                                                                                     self.theta[i],
                                                                                     self.delta[j])
                        else:
                            self.data[i], self.theta[i], _ = denormalize(self.data[i], self.normVar, self.theta[i], [])

                    else:
                        if len(self.delta) == 1:
                            self.delta = self.delta[0]  # if delta is an array in an list, unpack it

                        if i+1 == len(self.data) and len(self.data) != 1:  # don't normalize delta again
                            self.data[i], self.theta[i], _ = denormalize(self.data[i], self.normVar, self.theta[i],
                                                                         [])
                        else:
                            self.data[i], self.theta[i], self.delta = denormalize(self.data[i], self.normVar,
                                                                                  self.theta[i], self.delta)

                else:
                    # delta[i] is the list of hinges belonging to section i
                    self.data[i], self.theta[i], self.delta[i] = denormalize(self.data[i], self.normVar,
                                                                             self.theta[i], self.delta[i])

                Vn, sum_err = sosqerr(self.data[i], self.theta[i], self.normVar)
                self.Vn.append(Vn)
                self.sum_err.append(sum_err)

                self.Vn_total = sum(self.Vn)
                self.mse = self.Vn_total/len(self.data_points['Y'])

    def find_constraints(self, data):
        """
        Determine which hinges actually bound each linear section.

        For every section the hinges on its path (``self.delta[i]``) are applied to ``data``, last hinge
        first. A hinge becomes a constraint of the section only if data points lie on its other side,
        i.e. if it really cuts the data. The hinge vectors are stored in ``self.constraints[i]`` and the
        corresponding signs in ``self.con_sign[i]``.

        Parameters
        ----------
        data:   dict with the input/output data; each section works on a deep copy of it
        """
        if len(self.data) > 2:
            constraints = []
            con_sign = []

            for i in range(self.numOfHinges + 1):
                data_loop = copy.deepcopy(data)
                constraints.append([])
                con_sign.append([])
                constraint = {}  # hinge position -> 1 if the hinge cuts off data, else 0

                if isinstance(self.delta[i], list) or (len(self.delta) == 1):
                    num_theor_const = len(self.delta[i])
                else:
                    num_theor_const = 1

                if num_theor_const > 1:
                    # walk the hinges of this section from the last to the first one
                    for j in range(num_theor_const - 1, -1, -1):
                        data_plus, data_minus = hfa.dataSeparation(data_loop, self.delta[i][j])

                        if self.sign[i][j] > 0:
                            if len(data_minus['Y']) > 0:
                                constraint[j] = 1
                            else:
                                constraint[j] = 0

                            data_loop = copy.deepcopy(data_plus)
                        else:
                            if len(data_plus['Y']) > 0:
                                constraint[j] = 1
                            else:
                                constraint[j] = 0

                            data_loop = copy.deepcopy(data_minus)

                else:  # num_theor_const == 1
                    if isinstance(self.delta[i], list):
                        delta = copy.deepcopy(self.delta[i][0])
                    else:
                        delta = copy.deepcopy(self.delta[i])
                    data_plus, data_minus = hfa.dataSeparation(data_loop, delta)

                    if self.sign[i] > 0:
                        if len(data_minus['Y']) > 0:
                            constraint[0] = 1
                        else:
                            constraint[0] = 0
                    else:
                        if len(data_plus['Y']) > 0:
                            constraint[0] = 1
                        else:
                            constraint[0] = 0

                keys = list(constraint.keys())
                keys.sort()
                for key in keys:
                    if constraint[key]:
                        constraints[i].append(self.delta[i][key])
                        if isinstance(self.sign[i], list):
                            con_sign[i].append(self.sign[i][key])
                        else:
                            con_sign[i].append(self.sign[i])

            self.constraints = copy.deepcopy(constraints)
            self.con_sign = copy.deepcopy(con_sign)

        else:
            # at most two sections: the outer sections are bounded by the single hinge
            if self.numOfHinges > 0:
                constraints = []
                con_sign = []

                for i in range(self.numOfHinges + 1):
                    if i == 0:
                        constraints.append([copy.deepcopy(self.delta)])
                        con_sign.append([copy.deepcopy(self.sign[i])])
                    elif i == self.numOfHinges:
                        constraints.append([copy.deepcopy(self.delta)])
                        con_sign.append([copy.deepcopy(self.sign[i])])

                self.constraints = constraints
                self.con_sign = con_sign

    def noHingeMat_update(self, inDataset, decision, data):
        """
        Record section ``decision`` (with its ``data``) as not splittable if no hinge was found.

        ``decision`` is a section index; the ``[]`` placeholder ("no decision yet") is ignored.
        """
        if inDataset:
            return
        # Check by type: ``np.intp != []`` yields an empty array whose truth value is deprecated (an
        # error in recent NumPy), so the previous ``decision != []`` test never recorded anything.
        if isinstance(decision, (list, str)):
            return
        self.noHinge_mat[int(decision)] = data

    def greedy_decision(self):
        """
        In this function the decision for the data set which is to be used as next split is made. The data set
        which is next is the one with the worst (largest) value for Vn.

        Sections in ``noHinge_mat`` are skipped; sections with fewer than ``2 * (number of input columns)``
        data points are added to ``noHinge_mat`` and skipped as well.

        Returns
        -------
        The index of the chosen section, or ``'break'`` if no section can be split.
        """
        Vn_all = []  # Vn of every section; -1 marks sections that must not be split
        has_candidate = False

        for i in range(len(self.data)):
            if i in self.noHinge_mat:  # this data area is excluded
                Vn_all.append(-1)
                continue

            # to not split elements with too few data points these are sorted out here
            if len(self.data[i]['Y']) < 2 * np.shape(self.data[i]['X'])[1]:
                self.noHinge_mat[i] = self.data[i]
                Vn_all.append(-1)
                continue

            Vn_all.append(self.Vn[i])
            has_candidate = True

        if not has_candidate:
            # if no element can be split, the HingingHyperplaneAlgorithm is aborted
            return 'break'

        # otherwise the element with the maximum Vn is selected for the next data split
        return np.argmax(np.asarray(Vn_all))

    def addNewSplit(self, decision, output):
        """
        Function to extend all data sets with the data split determined by the HFA.

        Section ``decision`` is replaced by the "plus" part of the split and the "minus" part is
        inserted right after it, so all following sections move one index up.

        Parameters
        ----------
        decision:   index of the section that was split
        output:     result dict of ``hfa.hingeFindingAlgorithm``
        """

        self.theta[decision] = output['theta_plus']
        self.theta.insert(decision + 1, output['theta_minus'])

        self.data[decision] = output['data_plus']
        self.data.insert(decision + 1, output['data_minus'])

        if not isinstance(self.sign[decision], list):
            self.sign.insert(decision + 1, [self.sign[decision], output['sign_minus']])
            self.sign[decision] = [self.sign[decision], output['sign_plus']]
        else:
            temp = copy.deepcopy(self.sign[decision])
            temp.append(output['sign_minus'])
            self.sign.insert(decision + 1, temp)
            self.sign[decision].append(output['sign_plus'])

        if not isinstance(self.delta[0], list):
            self.delta = [self.delta, self.delta]

        temp = copy.deepcopy(self.delta[decision])
        temp.append(output['delta'])
        self.delta.insert(decision + 1, copy.deepcopy(temp))
        self.delta[decision] = copy.deepcopy(temp)

        self.Vn[decision] = output['Vn_plus']
        self.Vn.insert(decision + 1, output['Vn_minus'])

        self.sum_err[decision] = output['sum_err_plus']
        self.sum_err.insert(decision + 1, output['sum_err_minus'])

        # keep the excluded-section indices aligned with the shifted sections; the split section itself
        # holds new data now and is no longer excluded
        self.noHinge_mat = {
            (key + 1 if key > decision else key): value
            for key, value in self.noHinge_mat.items()
            if key != decision
        }

    def calculate_milp(self, data, theta):
        """Create the MILP representation and compute the corner points of all sections."""
        self.milp = MILP(theta)
        self.milp.cal_corner_multi_dim(data, self.hull, self.numOfHinges, self.constraints, self.con_sign)


def denormalize(data, normVar, theta, delta):
    """
    Function to reverse the normalization.

    Parameters
    ----------
    data:           data['X']: input data (optionally with a leading column of ones, which is not
                    rescaled), data['Y']: output data; both are scaled back in place
    normVar:        scale factors used for the normalization, {'X': one per input column, 'Y': output}.
                    If it has more X entries than data['X'] has columns and the first column of
                    data['X'] is not all ones, its first X entry is dropped - this rebinds
                    normVar['X'] in the caller's dict
    theta:          contains the coefficients [intercept, slope_1, ...] of the linear element on
                    normalized data; it is not modified
    delta:          hinge(s) on normalized data, scaled back in place: either a single hinge vector or a
                    list of hinge vectors; pass an empty list (or None) if there is no hinge

    Returns
    -------
    data:           the data dict, scaled back
    theta_denom:    coefficients for the data before normalization
    delta:          the hinge(s), scaled back
    """
    if len(normVar['X']) > len(data['X'][0, :]) and any(data['X'][:, 0] != 1):
        normVar['X'] = normVar['X'][1:len(normVar['X'])]

    # skip a leading column of ones when data['X'] has one column more than there are scale factors
    offset = 1 if len(data['X'][0, :]) == len(normVar['X']) + 1 else 0

    # len() instead of ``delta != []``: comparing an ndarray with [] is an error in recent NumPy
    has_delta = delta is not None and len(delta) > 0
    hinge_list = has_delta and (isinstance(delta[0], list) or type(delta[0]) is np.ndarray or len(delta) == 1)

    intercept = theta[0]
    # copy: np.asarray would return a view and modify the caller's theta
    slope = np.array(theta[1:], dtype=float)
    for i, scale in enumerate(normVar['X']):
        # rescale column i of the features (shifted by the ones column, if present) by its own factor
        data['X'][:, i + offset] = data['X'][:, i + offset] * scale
        slope[i] = slope[i] * (normVar['Y'] / scale)

        if hinge_list:
            for hinge in delta:
                hinge[i + 1] = hinge[i + 1] / scale
        elif has_delta:
            delta[i + 1] = delta[i + 1] / scale

    data['Y'] = data['Y'] * normVar['Y']
    intercept = intercept * normVar['Y']

    theta_denom = np.zeros(len(theta))
    theta_denom[0] = intercept
    theta_denom[1:] = slope

    return data, theta_denom, delta


def normalizeData(data):
    """
    Function to normalize the data set given in data. Every column of data['X'] and data['Y'] is
    divided by its maximum.

    Parameters
    ----------
    data:        dict with input data ``data['X']`` (2-D, one sample per row) and output data
                 ``data['Y']``; both entries are replaced by their normalized versions (the dict is
                 updated in place)

    Returns
    -------
    data:       the same dict, normalized (float values)
    normVar:    {'X': maximum of each input column, 'Y': maximum of the output}. A maximum of 0 is
                replaced by 1, so such a column is left unscaled instead of becoming inf/NaN.

    """
    # float copy: assigning the quotient back into an integer array would truncate it
    X = np.array(data['X'], dtype=float)
    max_X = np.max(X, axis=0)
    max_X[max_X == 0] = 1.0
    data['X'] = X / max_X

    delta_max_Y = np.max(data['Y'])
    if delta_max_Y == 0:
        delta_max_Y = 1.0
    data['Y'] = data['Y'] / delta_max_Y
    normVar = {'X': max_X, 'Y': delta_max_Y}

    return data, normVar


def linear_function(x, theta):
    """
    Function for calculating the y values of the linear function with the coefficients theta.

    Parameters
    ----------
    theta:      coefficients of the linear element [intercept, slope_1, ..., slope_n]
    x:          2-D input matrix (one sample per row), with or without a leading column of ones

    Returns
    -------
    y:   values of the linear function; a column of ones is prepended to x when x has exactly one
         column fewer than theta has coefficients

    """
    # Decide by shape only: inspecting whether the first column is all ones fails when a real
    # feature happens to be constant 1.
    if x.shape[1] + 1 == len(theta):
        x = np.column_stack((np.ones(x.shape[0]), x))

    return np.dot(x, theta)


def sosqerr(data, theta, *normVar):
    """
    Function to determine the error of a linear element.

    Parameters
    ----------
    theta:      coefficients
    data:       contains the data points; data['X'] may get a column of ones prepended (in place)
    normVar:    optional normalization dict (first extra positional argument); only used to decide
                whether a column of ones has to be prepended to data['X']

    Returns
    -------
    Vn:         sum of squared errors, used to calculate the error measure defined in "On the
                Hinge-Finding Algorithm for Hinging Hyperplanes" by P. Pucar and J. Sjöberg.
    sum_err:    sum of the absolute errors
    """
    X = data['X']
    # ``normVar and`` guards the optional argument: without it normVar[0] raised IndexError
    if (normVar and len(X[0, :]) == 2 and X[0, 0] == 1
            and len(normVar[0]['X']) == 2 and normVar[0]['X'][0] == 1):
        data['X'] = np.column_stack((np.ones(len(X[:, 0])), X))  # necessary to avoid error

    # Calculation of the deviation of the linear model for each data point
    error = linear_function(data['X'], theta) - data['Y']
    squared_err = np.sum(error ** 2)
    abs_err = np.sum(np.abs(error))

    return squared_err, abs_err


class MILP(object):
    """
    Corner points of the linear sections of a hinging hyperplane model and the model values there.

    After ``cal_corner_multi_dim``, ``self.points[i]`` is a dict with ``'X'`` (corner points of
    section ``i``, one per row) and ``'Y'`` (model output at those corner points).
    """

    def __init__(self, theta,
                 points=None, hull_lin_elements=None):
        """
        Parameters
        ----------
        theta:              model coefficients: a list with one coefficient vector per section, or a
                            single vector for a purely linear model
        points:             precomputed corner points (e.g. from ``to_dict``); defaults to an empty list
        hull_lin_elements:  stored and serialized as-is; defaults to an empty dict
        """
        self.theta = theta
        # fresh containers per instance: mutable default arguments would be shared by all instances
        self.points = [] if points is None else points
        self.hull_lin_elements = {} if hull_lin_elements is None else hull_lin_elements

    @classmethod
    def from_dict(cls, init_dict: dict):
        """Rebuild an instance from the output of ``to_dict``."""
        return cls(**init_dict)

    def to_dict(self):
        """Serialize theta, points and hull_lin_elements."""
        return {'theta': self.theta,
                'points': self.points,
                'hull_lin_elements': self.hull_lin_elements}

    def cal_corner_multi_dim(self, data, hull, numOfHinges, model_constraints, model_con_sign):
        """
        Compute the corner points of every linear section and the model values at these corners.

        A section is bounded by the facets of the convex hull of the input data plus the hinges in
        ``model_constraints[i]``. Every combination of ``dim`` boundaries is intersected, intersections
        that violate any boundary of the section are discarded, and the vertices of the convex hull of
        the remaining points are stored as corner points.

        Parameters
        ----------
        data:               dict with the input data ``data['X']`` (only its number of columns is used)
        hull:               CnvHull of the input data (uses ``hull_boundaries``, ``lower_boundaries_opt``
                            and ``upper_boundaries_opt``)
        numOfHinges:        number of hinges; the model has numOfHinges + 1 sections
        model_constraints:  per section, the list of hinge vectors [offset, coefficients...] bounding it
        model_con_sign:     per section, the sign of each of these hinges
        """
        dim = len(data['X'][0, :])

        # bounds of the input space for the optimization problem (lower bounds +1, upper bounds -1)
        ad_boundaries = hull.lower_boundaries_opt.T
        cons_sig_ad_bound = np.ones(len(ad_boundaries[0, :]))

        ad_boundaries = np.column_stack((ad_boundaries, hull.upper_boundaries_opt.T))
        ad_boundaries[0, :] = -1 * ad_boundaries[0, :]

        cons_sig_ad_bound = np.append(cons_sig_ad_bound, -1 * np.ones(len(hull.upper_boundaries_opt[:, 0])))

        # hull facets in [coefficients..., offset] form; corners are kept where facet . [x, 1] <= 0
        hull_boundaries = hull.hull_boundaries.T
        signum_hull_boundaries = -1 * np.ones(len(hull_boundaries[0, :]))

        # constraints / cons_sign: bounds + hinges in [offset, coefficients...] form (MILP formulation;
        # not used by the corner search below). hull_constraints / signum_cons: hull facets + hinges in
        # [coefficients..., offset] form, used for the corner search.
        constraints = {}
        cons_sign = {}
        hull_constraints = {}
        signum_cons = {}
        if numOfHinges > 0:
            for i in range(numOfHinges + 1):
                section_constraints = list(model_constraints[i])
                if section_constraints:
                    # move the offset from the front to the end to match the hull facet layout
                    rolled = [np.roll(constraint, -1) for constraint in section_constraints]
                    constraints[i] = np.column_stack((ad_boundaries, np.asarray(section_constraints).T))
                    hull_constraints[i] = np.column_stack((hull_boundaries, np.asarray(rolled).T))
                else:
                    # section without hinge constraints: it is bounded by the hull only
                    constraints[i] = ad_boundaries
                    hull_constraints[i] = hull_boundaries
                cons_sign[i] = np.append(cons_sig_ad_bound, np.asarray(model_con_sign[i]))
                signum_cons[i] = np.append(signum_hull_boundaries, np.asarray(model_con_sign[i]))
        else:
            constraints[0] = ad_boundaries
            cons_sign[0] = cons_sig_ad_bound
            hull_constraints[0] = hull.hull_boundaries.T
            signum_cons[0] = signum_hull_boundaries

        points = {}
        for i in range(numOfHinges + 1):
            section_bounds = hull_constraints[i]
            number_boundaries = len(signum_cons[i])
            combs = value_comb_dim(number_boundaries, dim)

            # intersection point of every combination of `dim` boundaries (NaN if there is none)
            boundaries_intersects = np.zeros((len(combs), dim))
            for count, comb in enumerate(combs):
                selected = [section_bounds[:, k] for k in comb]
                A_test = np.column_stack([bd[:-1] for bd in selected])
                b_test = np.vstack([-1 * bd[-1] for bd in selected])
                try:
                    boundaries_intersects[count, :] = np.linalg.inv(A_test.T).dot(b_test).T
                except np.linalg.LinAlgError:
                    # singular system, e.g. parallel boundaries
                    boundaries_intersects[count, :] = np.nan

            # keep only the intersections that satisfy every boundary of the section
            hull_X_mesh = np.column_stack((boundaries_intersects, np.ones(len(combs))))
            for j in range(number_boundaries):
                hull_X_mesh = boundary_cut(hull_X_mesh, section_bounds[:, j], signum_cons[i][j])
            hull_X_mesh = hull_X_mesh[:, :-1]

            points[i] = {'X': calculate_corner_points(ConvexHull(hull_X_mesh))}

        # store points in object
        self.points = points
        # calculate Y values
        self.__calc_value_at_point(self.theta)

    def __calc_value_at_point(self, theta):
        """Evaluate the model at the corner points of every section and store it in points[i]['Y']."""
        for i in list(self.points.keys()):
            corners = self.points[i]['X']
            X_ext = np.column_stack((np.ones(len(corners[:, 0])), corners))
            section_theta = theta[i] if isinstance(theta, list) else theta
            self.points[i]['Y'] = np.dot(X_ext, section_theta)


def calculate_corner_points(hull):
    """
    Return the corner points (vertices) of a convex hull.

    Parameters
    ----------
    hull:           object with ``points`` (one point per row) and ``vertices`` (row indices of the
                    corner points), e.g. ``scipy.spatial.ConvexHull``

    Returns
    -------
    cornerPoints:   float array with one corner point per row, in the order of ``hull.vertices``
    """
    # fancy indexing copies the selected rows in one step instead of an element-wise loop
    return np.asarray(np.asarray(hull.points)[hull.vertices], dtype=float)


def value_comb_dim(number_boundaries, dim):
    """
    Return all combinations of ``dim`` distinct boundary indices out of ``range(number_boundaries)``.

    Returns
    -------
    comb:   integer array of shape (number of combinations, dim) in lexicographic order; shape
            (0, dim) if there are fewer than ``dim`` boundaries (callers index ``comb[:, 0]``)
    """
    comb = list(combinations(range(number_boundaries), dim))
    if not comb:
        return np.empty((0, dim), dtype=int)
    return np.asarray(comb)


def boundary_cut(X_mesh, boundary, upper_or_lower):
    """
    Function to check a single boundary condition for the given data area X_mesh.

    Parameters
    ----------
    X_mesh:             candidate points, one per row (in the layout expected by ``boundary``)
    boundary:           boundary vector; the value tested for a row x is ``x . boundary``
    upper_or_lower:     -1 keeps rows with x . boundary <= 0, +1 keeps rows with x . boundary >= 0;
                        any other value returns X_mesh unchanged

    Returns
    -------
    X_mesh:             the rows satisfying the boundary (the products are rounded to 10 decimals
                        before the comparison)

    """
    # Suppress warnings only while filtering (presumably for rows containing NaN) instead of calling
    # warnings.filterwarnings('ignore'), which switched off all warnings of the whole process.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        if upper_or_lower == -1:
            X_mesh = X_mesh[np.round(X_mesh.dot(boundary), 10) <= 0, :]
        elif upper_or_lower == 1:
            X_mesh = X_mesh[np.round(X_mesh.dot(boundary), 10) >= 0, :]

    return X_mesh


class CnvHull(object):
    """Convex hull of the input data plus linear boundary descriptions.

    On construction the hull of ``data['X']`` is computed, but only when the
    input has more than one column. If ``hull_acc > 0``, corner points whose
    removal shrinks the hull volume by less than that relative amount are
    discarded (see ``reduceComplexity``). The facet equations of the resulting
    hull, together with per-column min/max box rows, are stored in
    ``hull_boundaries``.

    For single-column input no hull is built. The scalar min/max of the data
    are stored in ``upper/lower_boundaries`` and ``upper/lower_boundaries_opt``
    instead.
    """

    def __init__(self, data, hull_acc):
        self.hull = None                   # ConvexHull of data['X']
        self.cornerPoints = None           # vertices of self.hull
        self.reduced_hull = None           # hull after dropping low-impact corners
        self.reduced_cornerPoints = None   # vertices of reduced_hull (only set when hull_acc > 0)
        self.deletedData = []              # corner points discarded by reduceComplexity
        self.boundaries = []
        self.upper_boundaries = []
        self.lower_boundaries = []
        # Threshold; reduceComplexity overwrites it with the last accepted loss.
        self.accuracy = hull_acc
        self.upper_boundaries_opt = []
        self.lower_boundaries_opt = []
        self.hull_boundaries = []
        # calculate convex hull
        self.convexHull(data['X'])

    def convexHull(self, data):
        """Build the (optionally reduced) hull of *data* and its boundary rows.

        Parameters
        ----------
        data : numpy.ndarray, shape (n_samples, n_dims)
        """
        if len(data[0]) > 1:  # convex hull only if dimension of input greater 1
            self.hull = ConvexHull(data)
            self.cornerPoints = calculate_corner_points(self.hull)

            if self.accuracy > 0:
                self.reduceComplexity(self.accuracy)
            else:
                self.reduced_hull = self.hull

            # Only the facet-based boundaries are in use. The slope/intercept
            # formulation (__find_boundaries_matrix, __lower_upper_check,
            # __boundaries_opt) is kept but is not called from here.
            self.__find_hull_boundaries(data)

        else:
            self.upper_boundaries = np.max(data)
            self.lower_boundaries = np.min(data)
            self.upper_boundaries_opt = np.max(data)
            self.lower_boundaries_opt = np.min(data)

    @staticmethod
    def _find_row(points, row):
        """Return the index of the first row of *points* equal to *row* in every coordinate.

        Raises IndexError when no row matches.
        """
        return int(np.flatnonzero(np.all(points == row, axis=1))[0])

    def reduceComplexity(self, accuracy):
        """Greedily drop hull corners that contribute little volume.

        Each round removes every current corner point tentatively. For each
        removal it computes the volume loss relative to the volume of the
        ORIGINAL hull. If the smallest loss is below *accuracy*, that corner is
        removed permanently and appended to ``deletedData``. Otherwise the
        loop stops.

        Sets ``reduced_hull``, ``reduced_cornerPoints``, ``deletedData``
        (converted to an array at the end) and ``accuracy`` (the last accepted
        loss).
        """
        volume = self.hull.volume

        self.reduced_cornerPoints = copy.deepcopy(self.cornerPoints)
        # create a hull only with the corner points
        self.reduced_hull = ConvexHull(self.reduced_cornerPoints)

        count = 0
        while True:
            points = self.reduced_hull.points
            # Locate each corner by its full coordinate row. Comparing only its
            # first coordinate against every entry could match another point
            # whose other coordinate happened to have that value.
            rows = [self._find_row(points, corner) for corner in self.reduced_cornerPoints]
            rel_delta_volume = np.array(
                [(volume - ConvexHull(np.delete(points, r, 0)).volume) / volume for r in rows])

            ac = np.min(rel_delta_volume)
            if not ac < accuracy:  # written this way so a NaN loss also stops
                break

            drop = rows[int(np.argmin(rel_delta_volume))]
            self.deletedData.append(points[drop])
            self.reduced_hull = ConvexHull(np.delete(points, drop, 0))
            self.reduced_cornerPoints = calculate_corner_points(self.reduced_hull)
            count += 1
            self.accuracy = ac

        self.deletedData = np.asanyarray(self.deletedData)
        print(count, 'outer points discarded')

    def __find_hull_boundaries(self, data):
        """Collect linear inequality rows describing the data region.

        Starts from ``reduced_hull.equations`` (rows ``[normal, offset]``) and
        appends two box rows for each input column i:
        ``x_i - max_i <= 0`` and ``-x_i + min_i <= 0``.
        The stacked matrix is stored in ``hull_boundaries``, and the same
        object in ``upper_boundaries_opt`` and ``lower_boundaries_opt``.
        """
        facets = self.reduced_hull.equations
        width = facets.shape[1]
        rows = [facets]
        for i in range(data.shape[1]):
            # Add maximum and minimum values for each dimension to boundary conditions
            upper = np.zeros(width)
            upper[i] = 1
            upper[-1] = -np.max(data[:, i])
            lower = np.zeros(width)
            lower[i] = -1
            lower[-1] = np.min(data[:, i])
            rows.extend((upper, lower))
        hull_bounds = np.vstack(rows)

        self.hull_boundaries = hull_bounds
        self.upper_boundaries_opt = hull_bounds
        self.lower_boundaries_opt = hull_bounds

    def __find_boundaries_matrix(self):
        """
        X_end = X_0 * m_0 + X_1 * m_1 + ... + X_(end-1) * m_(end-1) + n
        Calculates the parameters for the boundaries.
        boundaries = [m_0, m_1, ..., m_(end-1), n]

        One row per simplex of ``reduced_hull``, found by solving the linear
        system through the simplex's points.
        """
        simplices = self.reduced_hull.simplices
        points = self.reduced_hull.points

        coeff = []
        for s in simplices:
            p = points[s]
            A = np.column_stack((p[:, 0:-1], np.ones(len(p))))
            # solve() is more stable than forming the explicit inverse.
            coeff.append(np.linalg.solve(A, p[:, -1]))

        self.boundaries = np.asarray(coeff)

    def __lower_upper_check(self, data):
        """Split ``boundaries`` into upper and lower planes relative to the data mean.

        A plane is "lower" when it evaluates at or below the mean of the last
        column at the mean of the other columns, and "upper" otherwise.
        """
        data_mean = np.mean(data, axis=0)

        X1 = data_mean.copy()
        X1[-1] = 1
        X2 = data_mean[-1]

        upper_boundaries = []
        lower_boundaries = []
        for plane in self.boundaries:
            if X1.dot(plane) <= X2:
                lower_boundaries.append(plane)
            else:
                upper_boundaries.append(plane)

        self.upper_boundaries = np.asarray(upper_boundaries)
        self.lower_boundaries = np.asarray(lower_boundaries)

    def __boundaries_opt(self, data):
        """Convert upper/lower planes to optimisation rows and add per-column bounds.

        Upper and lower planes are swapped and rewritten as
        ``[-n, m_0, ..., m_(end-1), -1]``. For every input column, rows built
        from the column min and max are appended. The target list (upper or
        lower) depends on their sign, and a zero bound gets its own form
        instead of dividing by zero.
        """
        # switch boundaries
        lower_boundaries = copy.deepcopy(self.upper_boundaries)
        upper_boundaries = copy.deepcopy(self.lower_boundaries)

        lower_boundaries = np.column_stack((-1 * lower_boundaries[:, -1], lower_boundaries[:, 0:-1],
                                            -1 * np.ones(len(lower_boundaries[:, 0]))))
        upper_boundaries = np.column_stack((-1 * upper_boundaries[:, -1], upper_boundaries[:, 0:-1],
                                            -1 * np.ones(len(upper_boundaries[:, 0]))))

        def bound_row(width, i, coeff, const):
            # Row with `const` in column 0 and `coeff` for input column i.
            row = np.zeros(width)
            row[i + 1] = coeff
            row[0] = const
            return row

        for i in range(data.shape[1]):
            col_min = np.min(data[:, i])
            col_max = np.max(data[:, i])

            # For negative maximum and minimum values, upper and lower must be swapped for the boundary conditions
            if col_min < 0:
                upper_boundaries = np.vstack(
                    (upper_boundaries, bound_row(upper_boundaries.shape[1], i, 1 / col_min, 1)))
            elif col_min == 0:
                lower_boundaries = np.vstack(
                    (lower_boundaries, bound_row(lower_boundaries.shape[1], i, 1, 0)))
            else:
                lower_boundaries = np.vstack(
                    (lower_boundaries, bound_row(lower_boundaries.shape[1], i, 1 / col_min, 1)))

            if col_max < 0:
                lower_boundaries = np.vstack(
                    (lower_boundaries, bound_row(lower_boundaries.shape[1], i, 1 / col_max, 1)))
            elif col_max == 0:
                # Previously a duplicated `< 0` test made this branch unreachable,
                # so a zero maximum fell through to 1 / 0 = inf.
                upper_boundaries = np.vstack(
                    (upper_boundaries, bound_row(upper_boundaries.shape[1], i, 1, 0)))
            else:
                upper_boundaries = np.vstack(
                    (upper_boundaries, bound_row(upper_boundaries.shape[1], i, 1 / col_max, 1)))

        self.upper_boundaries_opt = upper_boundaries
        self.lower_boundaries_opt = lower_boundaries


def _merge_limits(current, values):
    """Combine an axis range with the extent of *values*.

    Parameters
    ----------
    current : tuple
        ``(low, high)`` range accumulated so far. ``None`` marks an unset end.
    values : array_like
        Data plotted along this axis.

    Returns
    -------
    tuple
        ``(min(values), max(values))`` when either end of *current* is unset.
        Otherwise, the union of *current* and that extent.
    """
    low, high = np.min(values), np.max(values)
    # Only None counts as unset. A previous limit of 0 is a real limit; the
    # old truthiness test silently discarded it.
    if current[0] is None or current[1] is None:
        return low, high
    return min(current[0], low), max(current[1], high)


def plot_model(ax, model, dim, x_plot, y_plot, color_scatter, label, plot_scatter=False,
               limits=None):
    """Draw the piecewise-linear sections of *model* on a 3-D axis.

    Parameters
    ----------
    ax : matplotlib 3-D axis
        Receives ``add_collection3d``, ``scatter`` and ``set_x/y/zlim`` calls.
    model : object
        Its ``milp.points`` dict maps a section key to
        ``{'X': corner coordinates, 'Y': model value at each corner}``.
    dim : int
        Number of input dimensions. With 2, each section polygon is drawn
        directly. With 3, each section's outline projected onto input columns
        0 and 1 is drawn. Any other value draws nothing and returns *limits*
        unchanged.
    x_plot, y_plot : numpy.ndarray
        Data inputs (indexed by column) and outputs. They are used for the
        optional scatter and for the axis limits.
    color_scatter, label
        Colour and legend label of the data scatter.
    plot_scatter : bool
        Also scatter the data points when True.
    limits : dict or None
        ``{'x': (low, high), 'y': (...), 'z': (...)}`` from a previous call, or
        None to start fresh. A passed dict is updated in place.

    Returns
    -------
    dict
        The updated limits. Pass them to the next call so several models share
        one view.
    """
    # Build a fresh dict per call. The former dict default was shared between
    # calls and mutated below, so limits leaked from one plot into the next.
    if limits is None:
        limits = {'x': (None, None), 'y': (None, None), 'z': (None, None)}

    milp = model.milp
    if dim not in (2, 3):
        return limits

    # Input columns shown on the plot's x and y axes. For dim == 3 only the
    # first projection of combinations(range(3), 2), i.e. (0, 1), is drawn.
    pl = (0, 1)

    # create colors
    colors = iter(cm.Greys(np.linspace(0.5, 0.8, len(milp.points.keys()))))
    for i in milp.points.keys():
        X = milp.points[i]['X']
        Y = np.asarray(milp.points[i]['Y'])
        if dim == 2:
            x1, x2, y = X[:, 0], X[:, 1], Y
        else:
            # Outline of the section projected onto the plotted input columns.
            proj = X[:, list(pl)]
            conv_hull = ConvexHull(proj)
            corners = conv_hull.points[conv_hull.vertices]
            x1, x2 = corners[:, 0], corners[:, 1]
            # Output value at each outline vertex, matched on both projected
            # coordinates. Matching only the first picked the wrong point
            # when two corners shared an x value.
            y = np.array([Y[np.all(proj == c, axis=1)][0] for c in corners])

        polygon = Poly3DCollection([list(zip(x1, x2, y))], alpha=0.7, linewidths=0.5,
                                   color=next(colors))
        polygon.set_edgecolor('k')
        ax.add_collection3d(polygon)

        # plot corner points
        ax.scatter(x1, x2, y, color='k', s=10, alpha=1)

    if plot_scatter:
        # plot points
        ax.scatter(x_plot[:, pl[0]], x_plot[:, pl[1]], y_plot,
                   color=color_scatter, alpha=0.5, label=label)

    # The z range is merged with the previous z limits. The dim == 3 path
    # previously used the x limits here.
    limits['x'] = _merge_limits(limits['x'], x_plot[:, pl[0]])
    limits['y'] = _merge_limits(limits['y'], x_plot[:, pl[1]])
    limits['z'] = _merge_limits(limits['z'], y_plot)

    ax.set_xlim(limits['x'])
    ax.set_ylim(limits['y'])
    ax.set_zlim(limits['z'])

    return limits