import random
import numpy as np
from numpy.random import seed
import torch
import torch.functional
from rdkit import Chem
from rdkit.Chem.Descriptors import ExactMolWt
import pandas as pd

from models import OneToRuleThemAll
import config
from sklearn.metrics import mean_absolute_error
from data_preprocessing import my_dlpno_pipeline, my_eval_dlpno_pipeline
from graph_preprocessing import my_knnG_dataloader

# Tuned hyperparameters and the per-model data-split seeds, both read from the
# project's config module.
best_parameters = config.hps
random_states = config.random_states

# Dataset variant to evaluate; the alternative value is 'Monomers'.
mode = 'Dimers'
database = f'DLPNO_{mode}_SI.csv'
test_database = f'TestDatabase_{mode}.csv'

device = 'cpu'

# Seed every random number source used by this script with 0. ``seed`` is
# numpy.random.seed imported by name, so NumPy's global RNG is seeded twice
# with the same value, exactly as before.
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 seed):
    _seed_fn(0)
del _seed_fn

# Disable the cuDNN autotuner.
torch.backends.cudnn.benchmark = False

# Dedicated, explicitly seeded torch generator (manual_seed returns the
# generator itself, so construction and seeding can be chained).
g = torch.Generator().manual_seed(0)

def seed_worker(worker_id):
    """Seed Python's and NumPy's global RNGs from torch's initial seed.

    The seed is torch's current initial seed reduced modulo 2**32 so that it
    fits the range accepted by ``np.random.seed``.

    Args:
        worker_id: Not used by the function body. The signature matches what
            a DataLoader ``worker_init_fn`` receives (presumably the intended
            use -- no caller is visible in this file).
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
                                                 
# Names of the hyperparameters that are passed through unchanged: each entry
# is both the key looked up in ``best_parameters`` and the keyword argument
# handed to ``OneToRuleThemAll``. The order matches the original call.
_MODEL_HP_NAMES = (
    'gnn_layer_type',
    'mol_gnn_layer_type',
    'num_gnn_layers',
    'num_mol_gnn_layers',
    'gnn_output_size',
    'mol_gnn_output_size',
    'gnn_heads',
    'mol_gnn_heads',
    'gnn_activation',
    'mol_gnn_activation',

    'attentionlayer_size',
    'embedding_size',
    'num_afp_layers',
    'num_timesteps',
    'afp_dropout',

    'dnn_activation',
    'num_hl',
    'num_gnn_hl',
    'num_mol_gnn_hl',
    'gnn_hl_size',
    'mol_gnn_hl_size',
    'hl_size',
    'hl_dropout',

    'mol_max_pool',
    'mol_mean_pool',
    'mol_add_pool',
    'max_pool',
    'mean_pool',
    'add_pool',

    'triangle',
    'u_net_type',
    'mol_u_net_type',
    'seperated_dnn',
)

# Build the network from the tuned hyperparameters plus the two values that
# are fixed in this script (node_dim and extra_energies), then move it to the
# selected device. A missing hyperparameter raises KeyError here.
model_cc = OneToRuleThemAll(
    **{hp_name: best_parameters[hp_name] for hp_name in _MODEL_HP_NAMES},
    node_dim=39,
    extra_energies=9,
).to(device)

# Evaluate the ten saved models on the test database and collect, for every
# model and molecule, the rescaled prediction, target, deviation and variance.
# The concatenated results are written to a single CSV file at the end.
testing_pred_list = []
testing_true_list = []
testing_error_list = []
testing_smiles_list = []
testing_var_list = []


for rs in range(1, 11):
    print(f'Model {rs}')
    # Only the normalisation values of the training pipeline are needed here;
    # the returned train/validation data are not used for evaluation.
    # NOTE(review): random_states is indexed with 1..10 -- confirm that these
    # keys/indices exist in config.random_states.
    _, _, max_target, max_list = my_dlpno_pipeline(database=database,
                                                   random_state=random_states[rs])
    print(max_list)

    # Checkpoint files are numbered with two digits (..._01.pt to ..._10.pt).
    # torch.load unpickles the file, so only trusted checkpoints may be used.
    model_cc.load_state_dict(torch.load(f"DFT_DLPNO_{mode}_{rs:02d}.pt",
                                        map_location=torch.device('cpu')))

    df_test = my_eval_dlpno_pipeline(test_database,
                                     max_list=max_list)
    # Drop every molecule whose SMILES contains a literal upper-case 'S'.
    df_test = df_test[~df_test['SMILES'].str.contains('S', regex=False)]
    test_smiles = df_test['SMILES'].values.tolist()

    # The batch size equals the number of test molecules, i.e. one batch is
    # requested for the complete test set.
    # NOTE(review): the target column is spelled 'inkre' while the attribute
    # read from the batch below is 'incre' -- confirm that this is intended.
    test_gloader, test_knn_loader, _, _ = my_knnG_dataloader(
        df_test,
        bs=len(df_test),
        xyz_features=True,
        explicit_H=False,
        cut_off_radius=best_parameters['cutoff'],
        target_column='inkre',
    )

    model_cc.eval()
    # Accumulate over all batches so that no batch is lost should the loaders
    # ever yield more than one; the names stay defined for an empty loader.
    preds, targets, variances = [], [], []
    with torch.no_grad():
        for test_graph, test_knn_graph in zip(test_gloader, test_knn_loader):
            test_graph.to(device)
            test_knn_graph.to(device)
            mu, var = model_cc(test_graph, test_knn_graph)
            preds.extend(p.detach().cpu().numpy()[0] for p in mu)
            targets.extend(t.detach().cpu().numpy()[0] for t in test_graph.incre)
            variances.extend(v.detach().cpu().numpy()[0] for v in var)

    # zip() would silently truncate and misalign SMILES and predictions.
    if len(preds) != len(test_smiles):
        raise ValueError(
            f'Model {rs}: got {len(preds)} predictions for '
            f'{len(test_smiles)} test molecules'
        )

    # Rescale with the molecular weight (computed once per molecule) and then
    # with the normalising value returned by the training pipeline.
    mol_weights = [ExactMolWt(Chem.MolFromSmiles(smile)) for smile in test_smiles]
    # x: predicted, y: target, v: second model output.
    x_test = np.array([p * w for p, w in zip(preds, mol_weights)]) * max_target
    y_test = np.array([t * w for t, w in zip(targets, mol_weights)]) * max_target
    # NOTE(review): `var` is rescaled linearly, like the mean. If it is a
    # variance rather than a standard deviation, the factor would have to be
    # squared -- confirm against the model's output definition.
    v_test = np.array([v * w for v, w in zip(variances, mol_weights)]) * max_target

    print(f'MAE: {mean_absolute_error(y_test, x_test)}')
    testing_error_list.extend(x_test - y_test)
    testing_smiles_list.extend(test_smiles)
    testing_pred_list.extend(x_test)
    testing_true_list.extend(y_test)
    testing_var_list.extend(v_test)

# One row per (model, molecule) pair, in the order the models were evaluated.
err_dict = {"Deviations": testing_error_list,
            "SMILES": testing_smiles_list,
            "Pred": testing_pred_list,
            "True": testing_true_list,
            "Var": testing_var_list}
err_df = pd.DataFrame(err_dict, columns=['Deviations', 'SMILES', 'Pred', 'True', 'Var'])
err_df.to_csv(f"MAE_List_testing_dlpno_{mode}.csv", index=False)
