import numpy as np
import pandas as pd
import warnings
from sklearn.model_selection import train_test_split

# Silence all warnings globally (e.g. pandas chained-assignment noise from
# the in-place column mutations below). NOTE(review): this also hides
# genuinely useful warnings — consider narrowing the filter.
warnings.filterwarnings('ignore')

def calc_ddg(ee, T):
    """Convert an enantiomeric excess into a free-energy difference.

    Derivation:
        er = (1 + ee)/(1 - ee)   (ee as a fraction)
        er = exp(-ddg / RT)  =>  ddg = -ln(er) * R * T

    Parameters
    ----------
    ee : float
        Enantiomeric excess in percent (must be < 100, or the ratio
        diverges).
    T : float
        Temperature in degrees Celsius.

    Returns
    -------
    float
        ddG in kcal/mol (sign convention: positive for positive ee).
    """
    kelvin = float(T) + 273.15
    gas_const = 8.31446261815324/1000  # kJ / (mol K)
    fraction = ee/100
    enantiomer_ratio = (1 + fraction) / (1 - fraction)
    energy_kj = np.log(enantiomer_ratio) * gas_const * kelvin  # kJ / mol
    # 1 kcal = 4.1839954 kJ (calorie convention used throughout this file)
    return energy_kj * (1/4.1839954)

# Contains some legacy code and functions that are no longer used;
# they date from when we were still working with literature data.
# We left them in so that you can use them if you want to.
# We also divided the data into two classes, which has no effect on our current model.
def preprocess_cbs(
                    data_path = '../Database/CBS_09-07-2022.csv',
                    target_column = 'ee',
                    n_classes = 2,
                    thresh_ee = 80,
                    merge_columns = True,
                    ee_clipping = 100,
                    aug_data = False,
                    aug_reactant = False,
                    diff_mode = False,
                    threshs = None,
                    verbose = False,
                    dropping_list = ['Si',
                                     '+',
                                     '-',
                                     'F',
                                     'Br',
                                     'N',
                                     'S'],
                    substr = None
                   ):
    """
    Function to preprocess cbs data.

    Loads the CBS reaction CSV, derives ddG from ee and temperature,
    optionally augments the data with racemic / non-prochiral counterparts,
    one-hot encodes the 'borane' and 'solv' columns, drops reactants whose
    SMILES contain unwanted substrings, and builds list-valued 'cond'
    (condition vector) and 'label' (class one-hot) columns.

    Parameters
    ----------
    data_path : str
        Path to the CBS CSV file. Expected columns include at least
        'ee', 'T', 'r1', 'cat', 'borane', 'solv'.
    target_column : str
        Unused by the current implementation (legacy parameter).
    n_classes : int
        Number of equal-width ee bins used by the FINAL labeling step
        (make_classes), which overwrites any earlier 'label'.
    thresh_ee : float
        Threshold used only when n_classes == 1 to build 'quality'.
    merge_columns : bool
        If True, collapse condition / class columns into list columns.
    ee_clipping : float
        0 < value < 100 keeps rows with ee < value; a negative value
        keeps rows with ee > -value; value == 100 disables clipping.
    aug_data : bool
        If True, append a racemic copy (stereo markers stripped, ee = 0).
    aug_reactant : bool
        If True (and aug_data), also append non-prochiral reactant copies.
    diff_mode : bool
        If True, add 'cat2' with inverted catalyst stereochemistry.
    threshs : list of float or None
        Explicit class thresholds for the legacy labeling code paths.
    verbose : bool
        Print dataframe sizes and ee statistics along the way.
    dropping_list : list of str or None
        Substrings; rows whose 'r1' SMILES contains any are dropped.
    substr : str or None
        If given, keep only rows whose 'r1' equals this SMILES exactly.

    Returns
    -------
    (pandas.DataFrame, int)
        The processed dataframe and the length of the 'cond' vector
        (0 when merge_columns is False).
    """
    target_column = target_column  # no-op; legacy leftover
    df = pd.read_csv(data_path)
    # exp(ee) with ee in percent grows to ~e^100 — legacy feature,
    # not consumed anywhere else in this function.
    df["exp_ee"] = df["ee"].apply(lambda x: np.exp(x))
    
    # Free-energy difference (kcal/mol) from ee (%) and T (deg C).
    # NOTE(review): calc_ddg divides by (1 - ee/100), so ee == 100 -> inf.
    df['ddg'] = df[['ee', 'T']].apply(lambda x: calc_ddg(*x), axis=1)
    
    if verbose:
        print(f'Length of initial dataframe: {len(df)}')
    df["ee"] = df["ee"].apply(lambda x: np.float64(x))
    if verbose:
        print(f'Inital Mean ee: {df["ee"].mean()}')
        print(f'Inital Median ee: {df["ee"].median()}')
    # determine if good or bad

    if aug_data:
        # Augmentation: duplicate every row with stereocenter markers ('@')
        # removed from reactant and catalyst SMILES, labeled racemic (ee = 0).
        df_copy = df.copy()
        df_copy['r1'] = df_copy['r1'].apply(lambda x: x.replace('@', ''))
        df_copy['cat'] = df_copy['cat'].apply(lambda x: x.replace('@', ''))
        df_copy['ee'] = df_copy['ee'].apply(lambda x: 0.0)
        df = pd.concat([df, df_copy], axis=0, ignore_index=True)      
        if aug_reactant:
            def non_pro_chiral(s):
                # Map each known prochiral ketone SMILES to a symmetric
                # (non-prochiral) analogue; unknown reactants get a sentinel
                # string so they can be filtered out just below.
                if s == 'CC(C1=CC=CC=C1)=O':
                    r = 'CCC(CC)=O'
                elif s == 'CC(C(C)(C)C)=O':
                    r = 'O=C(C(C)C)C(C)C'
                elif s == 'CCC(C)=O':
                    r = 'CCC(CC)=O'
                elif s == 'CC(CCCCC)=O':
                    r = 'CCCCC(CCCC)=O'
                elif s == 'CC(C1CCCCC1)=O':
                    r = 'C2CCCCC2C(C1CCCCC1)=O'
                elif s == 'CC(/C=C/C1=CC=CC=C1)=O':
                    r = 'O=C(/C=C/C1=CC=CC=C1)/C=C/C2=CC=CC=C2' 
                elif s == 'O=C(C)C1=CC=CC=C1':
                    r = 'CCC(CC)=O'
                else:
                    r = 'DENNIS-EGG'
                # remove the rest with identifier
                # because there should be 
                # covered enough              
                return r
                
            df_copy_r = df.copy()          
            df_copy_r['r1'] = df_copy_r['r1'].apply(non_pro_chiral)
            # Drop rows whose reactant had no known symmetric analogue.
            df_copy_r = df_copy_r[~df_copy_r["r1"].str.contains('DENNIS-EGG',
                                                                regex=False)]           
            df_copy_r['ee'] = df_copy_r['ee'].apply(lambda x: 0.0)
            
            df = pd.concat([df, df_copy_r],
                           axis=0,
                           ignore_index=True)

    """
    Set number of classes and thresholds below
    """
    num_classes = n_classes

    if num_classes == 1:
        # set this threshold when you have only 2 classes
        # Binary 'quality' flag: 1.0 if ee is above the threshold.
        # NOTE(review): this path creates 'quality' but never 'label',
        # yet the merge_columns branch below reads 'label' unconditionally
        # — confirm whether n_classes == 1 is still a supported combination.
        threshold_ee = thresh_ee
        df["quality"] = df["ee"].apply(lambda x: 1.0 if x > threshold_ee else 0.0)
        
    elif num_classes == 2:
        # Legacy two-class labeling: 'class{i}' / 'class{i+1}' are
        # complementary one-hot columns around each threshold.
        if threshs is not None:
            for i in range(len(threshs)):
                label_str = f"class{i}"
                df[label_str] = df["ee"].apply(lambda x: 1.0 if x > threshs[i] else 0.0)                     
                label_str = f"class{i+1}"
                df[label_str] = df["ee"].apply(lambda x: 1.0 if x <= threshs[i] else 0.0)
                            
    elif num_classes > 2:
        # Legacy multi-class labeling from an explicit (descending) threshold
        # list. NOTE(review): the middle branch uses x >= threshs[i], so a
        # value equal to a threshold activates two adjacent classes — confirm
        # whether that overlap is intended.
        if threshs is not None:
            for i in range(len(threshs) + 1):
                label_str = f"class{i}"
                if i == 0:
                    df[label_str] = df["ee"].apply(lambda x: 1.0 if x > threshs[i] else 0.0)                     
                elif i > 0 and i < len(threshs):
                    df[label_str] = df["ee"].apply(lambda x: 1.0 if x <= threshs[i-1] and x >= threshs[i] else 0.0)
                else:
                    df[label_str] = df["ee"].apply(lambda x: 1.0 if x < threshs[i-1] else 0.0)      
                                             
    #make one hot encoding with the dataframe
    df_encoded = pd.get_dummies(df, prefix=["borane", "solv"], columns=["borane", "solv"], dtype="float32")
    # drop columns we don't want
    #df_droped = df_encoded.drop(columns=["p","er", "ln_ee", "ln_er"])
    df_droped = df_encoded
    print(df_droped.columns)
    if dropping_list is not None:
        # Drop rows whose reactant SMILES contains any unwanted substring.
        # NOTE(review): the boolean mask is built from df_encoded, not from
        # the progressively filtered df_droped; pandas realigns it on the
        # shared index, but filtering on df_droped["r1"] directly would be
        # clearer — confirm intent.
        for s in dropping_list:
            df_droped = df_droped[~df_encoded["r1"].str.contains(s, regex=False)]
        if verbose:
            print(f'Length of final dataframe: {len(df_droped)}')

    #print(df_droped.columns)    
    if merge_columns:
        if num_classes == 1:
            # Everything from column 9 onward becomes the condition string.
            df_droped['cond'] = df_droped.iloc[:,9:].apply(
                lambda x: ','.join(x.astype(str)), axis=1)
        elif num_classes > 1:

            # Positional slices: column 4 plus the columns after the class
            # one-hots form the condition vector; columns 12..12+num_classes
            # hold the class one-hots. Fragile — depends on the exact column
            # order of the CSV plus the columns added above; verify against
            # the database schema before changing either.
            df_droped['cond'] = df_droped.iloc[:, np.r_[4, 12+num_classes:23]].apply(
                lambda x: ','.join(x.astype(str)), axis=1)
            df_droped['label'] = df_droped.iloc[:,12:12+num_classes].apply(
                lambda x: ','.join(x.astype(str)), axis=1)
            
        print(df_droped['cond'])
        print(df_droped['cond'])                
        # Turn the comma-joined strings back into lists of floats.
        df_droped['cond'] = df_droped['cond'].apply(lambda x: x.split(','))

        df_droped['cond'] = df_droped['cond'].apply(lambda x: [float(i) for i in x])
        df_droped['label'] = df_droped['label'].apply(lambda x: x.split(','))
        df_droped['label'] = df_droped['label'].apply(lambda x: [float(i) for i in x])
        # NOTE(review): label-based lookup of row 0; assumes index label 0
        # survived the filtering above (likely after ignore_index concat,
        # but not guaranteed) — .iloc[0] would be safer.
        num_cond = len(df_droped['cond'][0])
    else:
        num_cond = 0
    
    # Clip high-ee rows: keep only ee strictly below the cutoff.
    if ee_clipping < 100 and ee_clipping > 0:
        df_droped = df_droped[df_droped['ee'] < ee_clipping]
        if verbose:
            print(f'Length of new dataframe: {len(df_droped)}')
        df_droped["ee"] = df_droped["ee"].apply(lambda x: np.float64(x))
        if verbose:
            print(f'New Mean ee: {df_droped["ee"].mean()}')
            print(f'New Median ee: {df_droped["ee"].median()}')
        
    # Negative cutoff means the opposite filter: keep ee above |cutoff|.
    if  ee_clipping < 0:
        df_droped = df_droped[df_droped['ee'] > -ee_clipping]
        if verbose:
            print(f'Length of new dataframe: {len(df_droped)}')
        df_droped["ee"] = df_droped["ee"].apply(lambda x: np.float64(x))
        if verbose:
            print(f'New Mean ee: {df_droped["ee"].mean()}')
            print(f'New Median ee: {df_droped["ee"].median()}')
            
    if diff_mode:
        # Invert catalyst stereochemistry by swapping '@@' and '@'
        # ('LOL' is a temporary placeholder so the two passes don't collide).
        df_droped['cat2'] = df_droped['cat'].apply(lambda x: x.replace('@@', 'LOL').replace('@', '@@').replace('LOL', '@'))
        
    if substr != None:
        # Restrict to a single reactant (exact SMILES string match).
        df_droped = df_droped[df_droped['r1'] == substr]

    def make_classes(x):
        # One-hot over n_classes equal-width ee bins; bin i covers
        # (100/n * i, 100/n * (i+1)]. NOTE(review): x == 0 matches no bin
        # (lower bound is strict) and x > 100 matches none either, leaving
        # an all-zero label — confirm that is acceptable downstream.
        class_list = [0.0]*n_classes
        for i in range(n_classes):
            threshhold_down = int(100/n_classes)*(i)
            threshhold_up = int(100/n_classes)*(i+1) # for n = 5, 20, 40, 60, 80
            if x <= threshhold_up and x > threshhold_down:
                class_list[i] = 1.0

        return np.array(class_list)

    df_droped["ee"] = df_droped["ee"].apply(lambda x: np.float64(x))
    # Overwrites any 'label' built earlier from threshs with the
    # equal-width binning above — the threshs machinery is legacy.
    df_droped['label'] = df_droped['ee'].apply(make_classes)
    df_droped['label'] = df_droped['label'].apply(lambda x: [float(i) for i in x])
    return df_droped, num_cond


def my_normalizer(df, column_list, target_column, mode="max", verbose=False):
    """Scale float columns of *df* in place by their min or max value.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to normalize; float columns are modified in place.
    column_list : sequence of str
        Columns to consider for scaling and to return.
    target_column : str
        Column whose extreme value is returned alongside the data
        (so the caller can de-normalize predictions later).
    mode : {"max", "min"}
        Which extreme value to divide by.
    verbose : bool
        Print the scaling factor of each processed column.

    Returns
    -------
    (pandas.DataFrame, float)
        The selected columns and the extreme of ``target_column``.

    Raises
    ------
    ValueError
        If *mode* is neither "min" nor "max" (previously this path
        crashed with an opaque UnboundLocalError).
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    extreme = np.amin if mode == "min" else np.amax
    max_target = extreme(df[target_column])
    if verbose:
        print(max_target)
    for column in column_list:
        # Only scale columns holding floats (np.float64 is a float subclass);
        # non-numeric columns are left untouched.
        if isinstance(df[column].values[0], float):
            max_val = extreme(df[column])
            if verbose:
                print(f'Max value: {max_val} in column: {column}')
            # Skip all-zero columns to avoid division by zero.
            if max_val != 0:
                # Vectorized division replaces the old per-element apply().
                df[column] = df[column] / max_val
    return df[column_list], max_target

def my_standardizer(df, column_list):
    """Z-score the float64 columns of *df* in place.

    Each float64 column is replaced by (x - mean) / std, using the
    population standard deviation (numpy default, ddof=0).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to standardize; float64 columns are modified in place.
    column_list : sequence of str
        Columns to consider for scaling and to return.

    Returns
    -------
    (pandas.DataFrame, float, float)
        The selected columns plus the mean and std of the LAST float64
        column processed (legacy contract, kept for caller compatibility).

    Raises
    ------
    ValueError
        If no column in *column_list* holds float64 values (previously
        this path crashed with an opaque NameError on ``mean``).
    """
    mean = std = None
    for column in column_list:
        # Non-float64 columns are left untouched.
        if isinstance(df[column].values[0], np.float64):
            mean = np.mean(df[column])
            std = np.std(df[column])  # population std (ddof=0)
            # NOTE(review): a constant column (std == 0) still yields
            # inf/nan here, exactly as the original code did.
            df[column] = (df[column] - mean) / std
    if mean is None:
        raise ValueError("my_standardizer: no float64 column in column_list")
    return df[column_list], mean, std

def train_test_cbs(
    dataframe, 
    normalize = True,
    standardize = False,
    random_state = 42,
    test_size = 0.2,
    target_column = 'ee'):
    """Split the preprocessed CBS dataframe into train and test sets.

    Optionally normalizes (divide by column max) or standardizes
    (z-score) the data first; ``normalize`` takes precedence when both
    flags are True, matching the original if/elif order. The split is
    stratified on the 'label' column, and rows with NaNs are dropped
    from each side afterwards.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Output of preprocess_cbs; must contain a 'label' column.
    normalize : bool
        Scale float columns by their max via my_normalizer.
    standardize : bool
        Z-score float64 columns via my_standardizer (only if not normalizing).
    random_state : int
        Seed forwarded to sklearn's train_test_split.
    test_size : float
        Fraction of rows placed in the test set.
    target_column : str
        Column whose max my_normalizer reports (value is discarded here).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        NaN-free train and test frames.
    """
    if normalize:
        df_scaled, max_target = my_normalizer(
            dataframe, dataframe.columns, target_column, mode="max")
    elif standardize:
        df_scaled, mean, std = my_standardizer(dataframe, dataframe.columns)
    else:
        # BUG FIX: this branch previously passed the undefined name
        # `df_scaled` to stratify, raising NameError whenever both
        # scaling flags were False.
        df_scaled = dataframe

    # Single split/cleanup path instead of three duplicated copies.
    train, test = train_test_split(df_scaled,
                                   test_size=test_size,
                                   random_state=random_state,
                                   stratify=df_scaled['label'])
    train_data = train.dropna(axis="rows")
    test_data = test.dropna(axis="rows")
    return train_data, test_data
