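"""Contains all evaluation classes and functions for TLM / CTLM measurements.

class LinearRegressionFitting: linear regression (straight-line) fit of 2D x,y data
class RT_Evaluation: evaluation of a total resistance vs. contact spacing plot R_T(d)
class TlmMeasurement: loads the five I-U measurement files of one structure and evaluates them
"""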

import datetime
import itertools
import os

import altair as alt
import numpy as np
import pandas as pd



# Colors: named palette (German color name -> hex code)
rwthcolors = dict((
    ('blau', '#00549F'),
    ('schwarz', '#000000'),
    ('magenta', '#E30066'),
    ('gelb', '#FFED00'),
    ('petrol', '#006165'),
    ('türkis', '#0098A1'),
    ('grün', '#57AB27'),
    ('maigrün', '#BDCD00'),
    ('orange', '#F6A800'),
    ('rot', '#CC071E'),
    ('bordeaux', '#A11035'),
    ('violett', '#612158'),
    ('lila', '#7A6FAC'),
))

# Ordered subset of the palette; entry i colors the series of measurement file i.
colorlist = [rwthcolors[color_name]
             for color_name in ('blau', 'türkis', 'grün', 'orange', 'violett')]






# Class Definitions
class LinearRegressionFitting:
    """Represents a linear regression (straight-line) fit of a 2D x,y data set."""

    def __init__(self, PandaDataFrame):
        """Fit y = slope * x + intercept to the first two columns of PandaDataFrame.

        The first column is used as x, the second column as y.

        Attributes:
            input: the DataFrame passed in.
            x, y: numpy arrays of the first and second column.
            slope, intercept: coefficients of the least-squares line.
            residuals: residual sum of squares as returned by
                ``np.polyfit(..., full=True)``; an empty array when the fit is
                exactly determined (e.g. only two points).
            diagnostics: remaining ``np.polyfit`` outputs (rank, singular values, rcond).
            model: ``np.poly1d`` of the fitted line, i.e. ``y_fit = model(x)``.
            R2: coefficient of determination (R²) of the fit.
            df: DataFrame with columns 'x' and 'f(x)' evaluated at x.
            chart: altair line chart of df.
        """
        values = PandaDataFrame.to_numpy()
        self.input = PandaDataFrame
        self.x = values[:, 0]
        self.y = values[:, 1]

        # https://numpy.org/doc/stable/reference/generated/numpy.polyfit.html
        [self.slope, self.intercept], self.residuals, *self.diagnostics = np.polyfit(
            self.x, self.y, 1, full=True
        )

        # model(x) is the same as writing slope * x + intercept
        self.model = np.poly1d([self.slope, self.intercept])

        self.R2 = self._coefficient_of_determination(
            self.y, self.model(self.x), self.residuals
        )

        self.df = pd.DataFrame({'x': self.x, 'f(x)': self.model(self.x)})
        self.chart = alt.Chart(self.df).mark_line().encode(x="x", y='f(x)')

    @staticmethod
    def _coefficient_of_determination(y, y_fit, residuals=None):
        """Return R² = 1 - SS_res / SS_tot for observations y and fitted values y_fit.

        SS_res is taken from ``residuals[0]`` when polyfit supplied it; polyfit
        returns an empty residual array for exactly determined fits, in which
        case SS_res is computed from y - y_fit instead of raising IndexError.
        Returns NaN when y is constant (SS_tot == 0), where R² is undefined.
        """
        y = np.asarray(y, dtype=float)
        if residuals is not None and np.size(residuals) > 0:
            ss_res = residuals[0]
        else:
            ss_res = np.sum((y - np.asarray(y_fit, dtype=float)) ** 2)
        # total sum of squares: SUM( (yi - y_mean)² )
        ss_tot = np.sum((y - np.average(y)) ** 2)
        if ss_tot == 0:
            return float("nan")
        return 1 - ss_res / ss_tot

class RT_Evaluation(LinearRegressionFitting):
    """Evaluation of a TLM total resistance plot R_T(d) via a linear fit.

    Expects a two-column DataFrame: contact spacing d in µm (first column) and
    total resistance R_T in Ohm (second column), as built by TlmMeasurement.

    Attributes (in addition to those of LinearRegressionFitting):
        df, chart: fitted line re-sampled from d = 0 to the largest spacing.
        contactlenght: contact width in µm.
        Rc: contact resistance in Ohm*mm.
        Rsh: sheet resistance in Ohm/sq.
        LT: transfer length in µm.
        rhoc: specific contact resistivity in Ohm*cm².
    """

    def __init__(self, PandaDataFrame, contactlenght):
        super().__init__(PandaDataFrame)

        # Re-sample the fitted line from d = 0 up to the largest spacing so the
        # plotted line reaches the R_T axis at the intercept (= 2 * Rc in Ohm).
        # np.max instead of x[-1]: does not rely on the spacings being sorted.
        model_df_x = np.linspace(0, np.max(self.x))
        self.df = pd.DataFrame({'x': model_df_x, 'f(x)': self.model(model_df_x)})
        self.chart = alt.Chart(self.df).mark_line().encode(x="x", y='f(x)')

        self.contactlenght = contactlenght

        # Contact resistance Rc [Ohm*mm]
        # = R_T(d=0)/2 [Ohm] * contact length [µm] / 1000
        self.Rc = (self.intercept / 2) * (contactlenght / 1000)

        # Sheet resistance Rsh [Ohm/sq = Ohm]
        # = slope [Ohm/µm] * contact length [µm]
        self.Rsh = self.slope * contactlenght

        # Transfer length LT [µm]: the fitted line reaches R_T = 0 at d = -2*LT,
        # hence LT = R_T(d=0) [Ohm] / slope [Ohm/µm] / 2
        self.LT = self.intercept / self.slope / 2

        # LT = sqrt(rhoc/Rsh):
        # "Semiconductor Material and Device Characterization Third Edition",
        # D. Schroder, p. 140, Eq. 3.21
        # Specific contact resistivity rhoc = LT² * Rsh  [Ohm cm²]
        # (LT² in µm² times 1e-8 gives cm²)
        self.rhoc = self.LT * self.LT * self.Rsh * 1E-4 * 1E-4

        
class TlmMeasurement:
    """TLM/CTLM evaluation of the five I-U measurement files of one structure.

    Each file holds one I-U sweep for one contact spacing. A straight-line fit
    of I over U gives the total resistance R_T = 1 / slope per spacing. R_T(d)
    is then evaluated with RT_Evaluation three ways: using all points (eval0),
    the best fit with one point removed (eval1), and the best fit with two
    points removed (eval2).
    """

    def __init__(self, filelist, distances=(5, 10, 15, 20, 50), contactlenght=50):
        """Load and evaluate a TLM measurement.

        Args:
            filelist: sequence of five file paths named <name>_1 ... <name>_5
                (optionally followed by an extension).
            distances: contact spacings of the TLM in µm, one per file, in file
                order. Must be ints (formatted with ``d`` in the statistics text).
            contactlenght: contact width of the TLM structures in µm.

        Raises:
            Exception: if the file names do not pass importfiles().
        """
        self._creation_date = datetime.datetime.now()
        self.filelist = filelist
        self.path, self.files = self.importfiles(filelist)
        self.distances = distances
        self.contactlenght = contactlenght
        self.df = self.construct_dataframes(self.filelist)
        self._evaluate()
        self.refined = False

    def _evaluate(self):
        """(Re)compute R values, fit statistics and eval0/eval1/eval2 from self.df."""
        self.R, self.lin_reg_charts, self.R_statistics = self.R_from_lin_reg()
        self.RT0 = pd.DataFrame({'d/µm': self.distances, 'R_T/Ohm': self.R})
        self.eval0 = RT_Evaluation(self.RT0, self.contactlenght)
        self.eval1, self.eval2 = self.find_RT1_RT2()

    def importfiles(self, filelist):
        """Validate the measurement file names and split off their directory.

        Expects exactly five paths whose file names share the same prefix and
        end in "_1" ... "_5" (before any extension), in that order.

        Returns:
            (path, filenames): directory of the first file and the list of bare
            file names.

        Raises:
            Exception: wrong number of files, differing prefixes, or wrong numbering.
            ValueError: a file name contains no "_".
        """
        if len(filelist) != 5:
            raise Exception("Files Missing - I need 5 files for TLM or CTLM")
        filenames = [os.path.split(filepath)[1] for filepath in filelist]
        firstname = None
        measurement = []
        for i, filename in enumerate(filenames):
            name, end = filename.rsplit(sep="_", maxsplit=1)
            if i == 0:
                firstname = name
            elif name != firstname:
                print("Files:", filenames)
                raise Exception("Filenames differ")
            measurement.append(end.split(sep=".")[0])
        if measurement != ['1', '2', '3', '4', '5']:
            print("Files:", filenames)
            raise Exception("Not Measurement _1 to _5?")
        path = os.path.split(filelist[0])[0]
        return (path, filenames)

    def construct_dataframes(self, filelist=None):
        """Read each measurement file into a DataFrame (default: self.filelist).

        Files are read as space-separated tables; the first column becomes the
        index ("#"), followed by the columns "U/V", "I/A" and "R/Ohm".
        header=3 together with explicit names makes pandas skip the file's
        preamble; the exact file layout is assumed here — verify against a
        measurement file.
        """
        if filelist is None:
            filelist = self.filelist
        return [
            pd.read_csv(filepath, sep=" ", skip_blank_lines=True, header=3,
                        index_col=0, usecols=[0, 1, 2, 3],
                        names=["#", "U/V", "I/A", "R/Ohm"])
            for filepath in filelist
        ]

    def construct_ui_charts(self, PandaDataFrames=None):
        """Return one interactive I-U scatter chart per DataFrame (default: self.df)."""
        if PandaDataFrames is None:
            PandaDataFrames = self.df
        # Plot the frames that were passed in (not always self.df).
        return [
            alt.Chart(frame).mark_point(color=colorlist[i]).encode(x="U/V", y="I/A").interactive()
            for i, frame in enumerate(PandaDataFrames)
        ]

    def R_from_lin_reg(self, PandaDataFrames=None):
        """Fit each I-U sweep linearly and derive R = 1 / slope (default: self.df).

        Returns:
            (R values, line charts of the fits, statistics text with d, R² and R
            per file).
        """
        if PandaDataFrames is None:
            PandaDataFrames = self.df
        lin_reg_results = []  # R
        lin_reg_charts = []
        statistics_lines = ["Statistics for Fitting of R Values \n    d / µm | R²     | RT / Ohm"]
        for i, frame in enumerate(PandaDataFrames):
            fit = LinearRegressionFitting(frame)
            # x = U, y = I, so the slope is a conductance and R = 1 / slope
            resistance = 1 / fit.slope
            lin_reg_results.append(resistance)
            lin_reg_charts.append(alt.Chart(fit.df).mark_line(color=colorlist[i]).encode(x="x", y='f(x)'))
            statistics_lines.append(f"{self.distances[i] : 10d} | {fit.R2 : >5.4f} | {resistance : 10.4f}")
        return lin_reg_results, lin_reg_charts, "\n".join(statistics_lines)

    def contruct_lin_reg_fit_charts(self):
        """Return the line charts of the per-file I-U fits."""
        return self.lin_reg_charts

    # Correctly spelled alias; the original name is kept for existing callers.
    construct_lin_reg_fit_charts = contruct_lin_reg_fit_charts

    def find_RT1_RT2(self, R=None):
        """Return the best R_T(d) evaluations with one and with two points removed.

        "Best" means the highest coefficient of determination R²; on ties the
        later combination wins. R defaults to self.RT0.
        """
        if R is None:
            R = self.RT0
        return self._best_fit_without(R, 1), self._best_fit_without(R, 2)

    def _best_fit_without(self, R, n_removed):
        """Return the RT_Evaluation with maximal R² over all ways to drop n_removed rows of R."""
        best = None
        best_R2 = 0
        for rows in itertools.combinations(R.index, n_removed):
            evaluation = RT_Evaluation(R.drop(labels=list(rows), axis=0), self.contactlenght)
            if evaluation.R2 >= best_R2:
                best_R2 = evaluation.R2
                best = evaluation
        return best

    def refine_evaluated_range(self, newrange=None):
        """Keep only points with minI < I/A < maxI and re-run the whole evaluation.

        Filtering is applied to the current self.df, so repeated calls narrow
        the range cumulatively. All derived attributes (R, lin_reg_charts,
        R_statistics, RT0, eval0/1/2) are recomputed.

        Args:
            newrange: tuple (minI, maxI) in A.

        Raises:
            Exception: if newrange is not a tuple of two values.
        """
        if not isinstance(newrange, tuple):
            raise Exception("Wrong format for Newrange! Should be tuple (float, float)")
        if len(newrange) != 2:
            raise Exception("Too few or too many values in newrange! Should be tuple (float, float)")
        minI, maxI = newrange
        # column 1 is I/A (column 0 is U/V); keep rows strictly inside the range
        self.df = [
            frame[(frame.iloc[:, 1] > minI) & (frame.iloc[:, 1] < maxI)]
            for frame in self.df
        ]
        self._evaluate()
        self.refined = True

    def results(self, comment=""):
        """Return a plain-text summary table for eval0, eval1 and eval2.

        The table has one row each for all points, best fit with 1 point removed
        and best fit with 2 points removed. It is preceded by the path, the file
        names and the given comment.
        """
        field = self.files[0].rsplit("_")[0]
        format_string = "{:<4}{:7.2f} {:8.4f} {:6.2f} {:8.2f} {:8.2e} {} \n"
        parts = [
            "Path: " + self.path + "\n",
            "Files: {}, {}, {}, {}, {} \n".format(*self.files),
            comment + " \n",
            "Feld Rsh      R²     Rc       LT     rhoc      # removed values \n",
            "-    [Ohm/sq] -      [Ohm mm] [µm]   [Ohm cm²] - \n",
        ]
        for evaluation, removed in ((self.eval0, None), (self.eval1, 1), (self.eval2, 2)):
            parts.append(format_string.format(field,
                                              evaluation.Rsh,
                                              evaluation.R2,
                                              evaluation.Rc,
                                              evaluation.LT,
                                              evaluation.rhoc,
                                              removed))
        return "".join(parts)

    def construct_rt_charts(self):
        """Return the R_T(d) point chart plus the fitted lines of eval0, eval1 and eval2."""
        return [
            alt.Chart(self.RT0).mark_point(color=rwthcolors['blau']).encode(x='d/µm', y='R_T/Ohm').interactive(),
            alt.Chart(self.eval0.df).mark_line(color=rwthcolors['bordeaux']).encode(x='x', y='f(x)'),
            alt.Chart(self.eval1.df).mark_line(color=rwthcolors['violett']).encode(x='x', y='f(x)'),
            alt.Chart(self.eval2.df).mark_line(color=rwthcolors['lila']).encode(x='x', y='f(x)'),
        ]