import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import random
import torch
from rdkit import Chem
from rdkit.Chem.Descriptors import ExactMolWt
# Fix the seeds of the Python, NumPy and PyTorch (CPU and CUDA) random number
# generators to 0 as soon as this module is imported.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)


def my_H_converter(df, column_list):
    """Scale float-valued columns by the Hartree -> kcal/mol factor.

    Every column in ``column_list`` whose first value is a ``float`` is
    multiplied by 627.5096080305927; all other columns are left untouched.
    No unit check is made: any float column listed is scaled.

    Args:
        df: DataFrame containing at least the columns in ``column_list``.
            The converted columns are written back into ``df`` itself.
        column_list: Names of the columns to inspect and to return.

    Returns:
        ``df[column_list]`` after the conversion.
    """
    hartree_kcal_mol = 627.5096080305927
    # An empty frame has no first value to inspect, so nothing is converted
    # (indexing ``values[0]`` would raise IndexError).
    if len(df) > 0:
        for column in column_list:
            if isinstance(df[column].values[0], float):
                df[column] = df[column] * hartree_kcal_mol
    return df[column_list]

def my_og_normalizer(df, column_list, target_column, mode="min"):
    """Divide every float column by its own minimum or maximum.

    Args:
        df: DataFrame to normalise; scaled columns are written back into it.
        column_list: Names of the columns to inspect and to return.
        target_column: Column whose extreme value is returned separately.
        mode: ``"min"`` or ``"max"``; selects which extreme is used as the
            divisor for every column and for ``target_column``.

    Returns:
        A tuple ``(df[column_list], max_target, max_list)`` where
        ``max_target`` is the selected extreme of ``target_column`` and
        ``max_list`` holds one ``[column, divisor]`` pair per column that was
        scaled (columns whose first value is not a ``float`` get no entry).

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """
    extremes = {"min": np.amin, "max": np.amax}
    if mode not in extremes:
        # Without this check the divisor would simply never be assigned.
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    pick_extreme = extremes[mode]

    max_list = []
    max_target = pick_extreme(df[target_column])
    print(max_target)
    for column in column_list:
        if not isinstance(df[column].values[0], float):
            continue
        # NOTE: a divisor of zero is not guarded against in this variant.
        max_val = pick_extreme(df[column])
        print(f"Max Value {max_val} for column: {column}")
        max_list.append([column, max_val])
        df[column] = df[column] / max_val
    return df[column_list], max_target, max_list

def my_normalizer(df, column_list, target_column, mode_energies='min', mode_target="max"):
    """Divide float columns by an extreme value, with a separate mode for the target.

    The target column is divided by its ``mode_target`` extreme.  Every other
    float column is divided by its ``mode_energies`` extreme unless that
    extreme is exactly zero, in which case the column is left unchanged.

    Args:
        df: DataFrame to normalise; scaled columns are written back into it.
        column_list: Names of the columns to inspect and to return.
        target_column: Name of the target column.
        mode_energies: ``"min"`` or ``"max"``; extreme used for non-target
            columns.
        mode_target: ``"min"`` or ``"max"``; extreme used for the target.

    Returns:
        A tuple ``(df[column_list], max_target, max_list)``.  ``max_list``
        holds one ``[column, divisor]`` pair per entry of ``column_list``, in
        the same order; columns that were not scaled are recorded with a
        divisor of ``1.0``.

    Raises:
        ValueError: If either mode is neither ``"min"`` nor ``"max"``.
    """
    extremes = {"min": np.amin, "max": np.amax}
    for label, mode in (("mode_target", mode_target), ("mode_energies", mode_energies)):
        if mode not in extremes:
            # Fail early: otherwise the divisor is left unassigned, or a value
            # from a previous column is silently reused.
            raise ValueError(f"{label} must be 'min' or 'max', got {mode!r}")

    max_list = []
    max_target = extremes[mode_target](df[target_column])
    print(max_target)
    for column in column_list:
        is_target = column == target_column
        if not isinstance(df[column].values[0], float):
            # Recorded with a neutral divisor (also for a non-float target)
            # so that max_list stays aligned with column_list.
            max_list.append([column, 1.0])
            continue

        pick_extreme = extremes[mode_target] if is_target else extremes[mode_energies]
        max_val = pick_extreme(df[column])
        if not is_target and max_val == 0.0:
            # Dividing by zero would destroy the column; keep it as is.
            max_list.append([column, 1.0])
            continue

        print(f"Max Value {max_val} for column: {column}")
        max_list.append([column, max_val])
        df[column] = df[column] / max_val

    return df[column_list], max_target, max_list

def my_standardizer(df, column_list):
    """Standardise ``np.float64`` columns to zero mean and unit deviation.

    Each column in ``column_list`` whose first value is an ``np.float64`` is
    replaced by ``(x - mean) / std`` using ``np.mean`` and ``np.std`` of that
    column.  A column whose deviation is exactly zero is only centred, so it
    does not turn into NaN.

    Args:
        df: DataFrame to standardise; scaled columns are written back into it.
        column_list: Names of the columns to inspect and to return.

    Returns:
        A tuple ``(df[column_list], mean, std)`` where ``mean`` and ``std``
        belong to the last column that was standardised.  If no column was
        standardised they are ``0.0`` and ``1.0`` (the identity transform).
    """
    # Identity defaults so the return statement is valid even when no column
    # qualifies.
    mean, std = 0.0, 1.0
    if len(df) == 0:
        # No first value to inspect on an empty frame.
        return df[column_list], mean, std

    for column in column_list:
        if not isinstance(df[column].values[0], np.float64):
            continue
        mean = np.mean(df[column])
        std = np.std(df[column])
        # A constant column has std == 0; divide by 1 instead of by zero.
        divisor = std if std != 0 else 1.0
        df[column] = (df[column] - mean) / divisor
    return df[column_list], mean, std

def my_dlpno_pipeline(database:str,
                      target_column='inkre',
                      want_normalize=True,
                      want_kcal=True,
                      random_state=None,
                      mw_scaling=True,
                      og_norm=False,
                      remove_sulfur=True,
                      extra_feat=['E_SCF_au',
                                  'E_FSP_au',
                                  'Dipole_Debye',
                                  'DFT_Ex_Eh',
                                  'DFT_Ec_Eh',
                                  'HOMO_ev',
                                  'LUMO_ev',
                                  'ZPVE_Eh',
                                  'DispCorr']):
    """Load a CSV database, clean and scale it, and split it into train/test.

    Steps, in order:
      1. Read ``database`` with ``pd.read_csv``.
      2. Drop rows whose SMILES cannot be turned into an exact molecular
         weight by RDKit.
      3. Drop rows with NaN in ``extra_feat`` or ``E_CCSDt_au``, then drop
         every column that still contains NaN.
      4. Optionally drop rows whose SMILES contains the character ``S``.
      5. Compute ``inkre = E_CCSDt_au - E_FSP_au`` and, if ``mw_scaling`` is
         set, divide it by the exact molecular weight.
      6. Optionally convert with ``my_H_converter`` and/or normalise with
         ``my_og_normalizer`` / ``my_normalizer``.
      7. Split 80/20 with ``train_test_split`` and add an ``extra_feat``
         column holding the per-row list of extra feature values.

    Args:
        database: Path (or anything ``pd.read_csv`` accepts) of the CSV file.
        target_column: Column passed to the normaliser as the target.
        want_normalize: Apply normalisation.
        want_kcal: Apply ``my_H_converter`` before normalisation.
        random_state: Forwarded to ``train_test_split``.
        mw_scaling: Divide ``inkre`` by the exact molecular weight.
        og_norm: Use ``my_og_normalizer`` instead of ``my_normalizer``.
        remove_sulfur: Drop rows whose SMILES contains ``S``.
        extra_feat: Column names collected into the ``extra_feat`` column.
            The default list is never modified by this function.

    Returns:
        ``(train_data, test_data, max_target, max_list)``.  Without
        normalisation ``max_target`` is ``1`` and ``max_list`` is
        ``[1, 1, 1, 1, 1]``.
    """
    df = pd.read_csv(f"{database}")
    print(len(df))
    print('Remove weird SMILES')
    drop_list = []
    for row_label, smiles in df['SMILES'].items():
        try:
            # Only the success/failure of this call matters here.
            ExactMolWt(Chem.MolFromSmiles(smiles))
        except Exception:
            # Not a bare ``except`` so Ctrl-C / SystemExit still propagate.
            print('Drop:')
            print(smiles)
            drop_list.append(row_label)
    df = df.drop(drop_list)
    print(len(df))
    print('Remove bad computations')
    print(df.columns)
    df = df.dropna(subset=extra_feat)
    df = df.dropna(subset=["E_CCSDt_au"])
    df = df.dropna(axis="columns")
    if remove_sulfur:
        # NOTE: plain substring match, so any SMILES containing a capital
        # "S" (e.g. as part of "Si" or "Se") is removed as well.
        df = df[~df['SMILES'].str.contains('S', regex=False)]

    print(len(df))

    print(len(df))

    df["inkre"] = df["E_CCSDt_au"] - df["E_FSP_au"]

    columns = ["SMILES", "E_CCSDt_au", "inkre", 'xyz'] + extra_feat
    if mw_scaling:
        df['mw'] = df['SMILES'].apply(lambda x: ExactMolWt(Chem.MolFromSmiles(x)))
        df["inkre"] = df["inkre"]/df['mw']

    if want_kcal:
        # The converter also writes the converted columns back into ``df``,
        # which is the frame the normalisers below operate on.
        df_scaled = my_H_converter(df, columns)
    else:
        df_scaled = df

    if want_normalize:
        normalizer = my_og_normalizer if og_norm else my_normalizer
        df_scaled, max_target, max_list = normalizer(df, columns, target_column)
    else:
        # Neutral scaling information.  Previously these were left unassigned
        # when both want_normalize and want_kcal were False, so the return
        # statement failed.
        max_target = 1
        max_list = [1,1,1,1,1]

    train, test = train_test_split(df_scaled, test_size=0.2, random_state=random_state)
    train_data = train.dropna(axis="rows")
    test_data = test.dropna(axis="rows")

    # combine extra features
    train_data['extra_feat'] = train_data[extra_feat].values.tolist()
    test_data['extra_feat'] = test_data[extra_feat].values.tolist()

    return train_data, test_data, max_target, max_list


def my_eval_normalizer(df, column_list, max_list):
    """Apply previously computed scaling factors to evaluation data.

    Each float column in ``column_list`` is divided by the factor recorded
    for it in ``max_list``.  Factors are looked up by the column name stored
    in each entry, not by position, so a ``max_list`` that only contains the
    float columns is applied to the right columns too.

    Args:
        df: DataFrame to scale; scaled columns are written back into it.
        column_list: Names of the columns to inspect and to return.
        max_list: Sequence of ``[column_name, factor]`` pairs.

    Returns:
        ``df[column_list]`` after scaling.  Columns without an entry in
        ``max_list`` are returned unchanged.
    """
    if len(df) == 0:
        # Nothing to scale, and no first value to inspect.
        return df[column_list]

    # Index 0 of each entry is the column name, index 1 the factor.
    factor_by_column = {entry[0]: entry[1] for entry in max_list}
    for column in column_list:
        if not isinstance(df[column].values[0], float):
            print(f'{column} not float!')
            continue
        if column in factor_by_column:
            df[column] = df[column] / factor_by_column[column]
    return df[column_list]

def my_eval_dlpno_pipeline(database:str,
                           max_list,
                           want_normalize=True,
                           want_kcal=True,
                           mw_scaling=True, 
                           extra_feat=['E_SCF_au',
                                  'E_FSP_au',
                                  'Dipole_Debye',
                                  'DFT_Ex_Eh',
                                  'DFT_Ec_Eh',
                                  'HOMO_ev',
                                  'LUMO_ev',
                                  'ZPVE_Eh',
                                  'DispCorr']
                           ):
    """Load an evaluation CSV and scale it with previously computed factors.

    Rows with NaN in ``extra_feat`` or ``E_CCSDt_au`` are dropped, then every
    column that still contains NaN is dropped.  ``inkre`` is computed as
    ``E_CCSDt_au - E_FSP_au`` and optionally divided by the exact molecular
    weight.  The data is then optionally converted with ``my_H_converter``
    and/or scaled with ``my_eval_normalizer`` using ``max_list``.

    Args:
        database: Path (or anything ``pd.read_csv`` accepts) of the CSV file.
        max_list: Scaling information forwarded to ``my_eval_normalizer``.
        want_normalize: Apply ``my_eval_normalizer``.
        want_kcal: Apply ``my_H_converter`` first.
        mw_scaling: Divide ``inkre`` by the exact molecular weight.
        extra_feat: Column names collected into the ``extra_feat`` column.
            The default list is never modified by this function.

    Returns:
        The processed DataFrame with an added ``extra_feat`` column holding
        the per-row list of extra feature values.  When both
        ``want_normalize`` and ``want_kcal`` are False the cleaned frame is
        returned with all of its columns.
    """
    df = pd.read_csv(f"{database}")
    df = df.dropna(subset=extra_feat)
    df = df.dropna(subset=["E_CCSDt_au"])
    df = df.dropna(axis="columns")
    df["inkre"] = df["E_CCSDt_au"] - df["E_FSP_au"]
    print(len(df))

    print(len(df))
    columns = ["SMILES", "E_CCSDt_au", "inkre", 'xyz'] + extra_feat 

    if mw_scaling:
        # NOTE(review): SMILES are not validated in this function, so every
        # row must be parseable by RDKit here -- confirm inputs are clean.
        df['mw'] = df['SMILES'].apply(lambda x: ExactMolWt(Chem.MolFromSmiles(x)))
        df["inkre"] = df["inkre"]/df['mw']

    if want_kcal:
        # The converter also writes the converted columns back into ``df``,
        # which is the frame the normaliser below operates on.
        df_scaled = my_H_converter(df, columns)
    else:
        # Previously left unassigned when both flags were False, which made
        # the ``extra_feat`` assignment below fail.
        df_scaled = df

    if want_normalize:
        df_scaled = my_eval_normalizer(df, columns, max_list)

    df_scaled['extra_feat'] = df_scaled[extra_feat].values.tolist()
    return df_scaled