import preprocessing_data_new
from screening_models import xAI_Net, FFNN_Net
import torch
import preprocessing_graph_new
import numpy as np
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from training import EarlyStopping
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from rdkit import Chem
from rdkit.Chem import AllChem

# ---- Data loading & train/val split ----------------------------------------
num_classes = 2
target_column = 'ddg'
rel_pred = True  # if True, sign targets by the catalyst's absolute configuration

# Load and preprocess the CBS dataset (project-specific helper).
cbs_df, num_cond = preprocessing_data_new.preprocess_cbs(data_path='Data/CBS_10-04-2023.csv',
                                                         n_classes=num_classes,
                                                         target_column=target_column, aug_data=False, threshs=[57],
                                                         dropping_list=None, merge_columns=False,
                                                         verbose=True)
print(cbs_df['ee'].mean())

# One-hot label on whether 'ee' is above the dataset mean — used only to
# stratify the train/val split below.
cbs_df['label2'] = cbs_df['ee'].apply(lambda x: np.array([0.0, 1.0]) if x <= cbs_df['ee'].mean() else np.array([1.0, 0.0]))
cbs_df['label2'] = cbs_df['label2'].apply(lambda x: [float(i) for i in x])

if rel_pred:
    # Sign the targets by absolute configuration:
    # AC == 'S' keeps the sign, AC == 'R' flips it (same for ee).
    cbs_df['ddg'] = cbs_df.apply(lambda x: x['ddg'] if x['AC'] == 'S' else -x['ddg'], axis=1)
    cbs_df['ee'] = cbs_df.apply(lambda x: x['ee'] if x['AC'] == 'S' else -x['ee'], axis=1)

torch.manual_seed(1337)

# Single stratified split; the original code repeated this call after
# reset_index(), discarding the re-indexed frames.
train_data, val_data = train_test_split(cbs_df, test_size=0.2, random_state=1337, stratify=cbs_df['label2'])

bs_val = len(val_data)  # evaluate the whole validation set in one batch
val_data = val_data.reset_index()
train_data = train_data.reset_index()

# Graph dataloaders (project-specific helper).
# BUG FIX: the train loaders must come from train_data — the original built
# them from val_data, i.e. trained on the validation set.
t_KeyInt, t_disc = preprocessing_graph_new.my_graph_only_dataloader(train_data,
                                                                    target_column=target_column,
                                                                    bs=16)

v_KeyInt, v_disc = preprocessing_graph_new.my_graph_only_dataloader(val_data,
                                                                    target_column=target_column,
                                                                    bs=bs_val)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn = torch.nn.MSELoss()

def train(model, graph_loader, optimizer, loss_fn):
    """Run one training epoch over `graph_loader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in graph_loader:
        batch.to(device)  # move graph tensors to the active device
        optimizer.zero_grad()
        # Forward pass: node features, connectivity, batch vector, edge features.
        prediction = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
        batch_loss = loss_fn(prediction, batch.ee)
        total_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return total_loss / len(graph_loader)

def val(model, graph_loader, loss_fn):
    """Evaluate `model` on `graph_loader` and return the mean batch loss."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in graph_loader:
            batch.to(device)
            prediction = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
            total_loss += loss_fn(prediction, batch.ee).item()
    return total_loss / len(graph_loader)

print('GNN training')

# Train one GNN per graph representation (key-interaction vs. disconnected).
for loadertype in ['key', 'disc.']:
    if loadertype == 'key':
        t_loader = t_KeyInt
        v_loader = v_KeyInt
    else:
        t_loader = t_disc
        v_loader = v_disc

    model = xAI_Net(gnn_dim=256,
                    n_gnn_layers=3,
                    n_ffnn_neurons=256,
                    n_ffnn_layers=3,
                    pool_type='add',
                    activation_function=F.relu).to(device)

    # Early stopping checkpoints the best model (lowest val loss) to disk.
    early_stopping = EarlyStopping(patience=200,
                                   verbose=False,
                                   path="Trained_Models/MolRep_Screening.pt",
                                   delta=0.0001)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    losses = []
    val_losses = []

    for epoch in range(2000):
        loss = train(model, t_loader, optimizer, loss_fn)
        val_loss = val(model, v_loader, loss_fn)

        losses.append(loss)
        val_losses.append(val_loss)

        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the best checkpoint before evaluating.
    model.load_state_dict(torch.load("Trained_Models/MolRep_Screening.pt"))
    model.eval()

    # v_loader holds the full validation set in a single batch (bs=bs_val).
    for data_graph_val in v_loader:
        with torch.no_grad():
            data_graph_val.to(device)
            mu = model(data_graph_val.x, data_graph_val.edge_index, data_graph_val.batch, data_graph_val.edge_attr)
            mu = mu.detach().cpu().numpy()
            ee = data_graph_val.ee.detach().cpu().numpy()
            print(f'RESULTS--{loadertype}')
            # BUG FIX: sklearn metrics take (y_true, y_pred) — r2_score is not
            # symmetric, so the original order gave a wrong R².  RMSE is the
            # square root of the MSE (the original printed plain MSE).
            print(f'MAE: {mean_absolute_error(ee, mu):.4f}, '
                  f'RMSE: {np.sqrt(mean_squared_error(ee, mu)):.4f}, '
                  f'R2_score: {r2_score(ee, mu):.4f}')
            print('RESULTS')


# NN with FP

from rdkit import DataStructs

def get_morgan_fingerprint(smiles, radius=2, nBits=2048):
    """Return the Morgan fingerprint of a SMILES string as a numpy array."""
    molecule = Chem.MolFromSmiles(smiles)
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(molecule, radius=radius, nBits=nBits)
    # ConvertToNumpyArray resizes the target array to the fingerprint length,
    # so the initial shape of the buffer is irrelevant.
    fingerprint = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bit_vector, fingerprint)
    return fingerprint

# train function for NN
def train(model, loader, optimizer, loss_fn):
    """Run one training epoch over (fingerprint, label) batches; return mean batch loss."""
    model.train()
    total_loss = 0.0
    for features, labels in loader:
        features = features.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        prediction = model(features)
        batch_loss = loss_fn(prediction, labels)
        total_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return total_loss / len(loader)

def val(model, loader, loss_fn):
    """Evaluate `model` on (fingerprint, label) batches; return mean batch loss."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            prediction = model(features)
            total_loss += loss_fn(prediction, labels).item()
    return total_loss / len(loader)

# ---- FFNN baseline on concatenated Morgan fingerprints ----------------------
from torch.utils.data import DataLoader  # hoisted: no need to re-import per radius

for radius in [2, 4]:

    # Fingerprint the reactant (r1) and catalyst (cat) SMILES, then concatenate.
    train_data['r1_fp'] = train_data['r1'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    train_data['cat_fp'] = train_data['cat'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    val_data['r1_fp'] = val_data['r1'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    val_data['cat_fp'] = val_data['cat'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    train_data['fp'] = train_data.apply(lambda x: np.concatenate((x['r1_fp'], x['cat_fp'])), axis=1)
    val_data['fp'] = val_data.apply(lambda x: np.concatenate((x['r1_fp'], x['cat_fp'])), axis=1)

    # Feed-forward regressor on the 2*2048-bit concatenated fingerprint.
    model = FFNN_Net(n_ffnn_neurons=256,
                     n_ffnn_layers=3,
                     activation_function=F.relu,
                     fp_dim=2 * 2048).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()

    # Early stopping checkpoints the best model (lowest val loss) to disk.
    early_stopping = EarlyStopping(patience=200,
                                   verbose=False,
                                   path="Trained_Models/FFNN_Screening.pt",
                                   delta=0.0001)

    # Tensors / loaders: batch size 16 for training, whole set for validation.
    t_fp = torch.tensor(np.vstack(train_data['fp'].values), dtype=torch.float)
    v_fp = torch.tensor(np.vstack(val_data['fp'].values), dtype=torch.float)
    t_label = torch.tensor(np.vstack(train_data['ddg'].values), dtype=torch.float)
    v_label = torch.tensor(np.vstack(val_data['ddg'].values), dtype=torch.float)

    t_fp_dataset = DataLoader(torch.utils.data.TensorDataset(t_fp, t_label), batch_size=16)
    v_fp_dataset = DataLoader(torch.utils.data.TensorDataset(v_fp, v_label), batch_size=bs_val)

    for epoch in range(2000):
        loss = train(model, t_fp_dataset, optimizer, loss_fn)
        val_loss = val(model, v_fp_dataset, loss_fn)

        losses.append(loss)
        val_losses.append(val_loss)

        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the best checkpoint before evaluating.
    model.load_state_dict(torch.load("Trained_Models/FFNN_Screening.pt"))
    model.eval()

    for data_fp, data_label in v_fp_dataset:
        with torch.no_grad():
            data_fp = data_fp.to(device)
            data_label = data_label.to(device)
            mu = model(data_fp).detach().cpu().numpy()
            ee = data_label.detach().cpu().numpy()
            print(f'RESULTS--{radius}-NN')
            # BUG FIX: sklearn metrics take (y_true, y_pred); RMSE = sqrt(MSE)
            # (the original printed plain MSE and a wrong, asymmetric R²).
            print(f'MAE: {mean_absolute_error(ee, mu):.4f}, '
                  f'RMSE: {np.sqrt(mean_squared_error(ee, mu)):.4f}, '
                  f'R2_score: {r2_score(ee, mu):.4f}')
            print('RESULTS')

# GBRT and SVM with FP

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.svm import SVR

for radius in [2, 4]:
    # Fingerprints depend only on the radius, so compute them once per radius
    # (the original recomputed them redundantly for each method).
    train_data['r1_fp'] = train_data['r1'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    train_data['cat_fp'] = train_data['cat'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    val_data['r1_fp'] = val_data['r1'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    val_data['cat_fp'] = val_data['cat'].apply(lambda x: get_morgan_fingerprint(x, radius=radius))
    train_data['fp'] = train_data.apply(lambda x: np.concatenate((x['r1_fp'], x['cat_fp'])), axis=1)
    val_data['fp'] = val_data.apply(lambda x: np.concatenate((x['r1_fp'], x['cat_fp'])), axis=1)

    X_train = np.vstack(train_data['fp'].values)
    X_val = np.vstack(val_data['fp'].values)
    # sklearn regressors expect a 1-D target vector (the original passed an
    # (n, 1) column from np.vstack, triggering a DataConversionWarning).
    y_train = np.asarray(train_data['ddg'].values, dtype=float)
    y_val = np.asarray(val_data['ddg'].values, dtype=float)

    for method in ['GBRT', 'SVR']:
        # BUG FIX: the original compared against 'GradientBoostingRegressor',
        # which never matched 'GBRT', so both rows silently reported SVR.
        if method == 'GBRT':
            model = GradientBoostingRegressor()
        else:
            model = SVR()

        model.fit(X_train, y_train)

        mu = model.predict(X_val)
        ee = y_val

        print(f'RESULTS--{radius}-{method}')
        # BUG FIX: sklearn metrics take (y_true, y_pred); RMSE = sqrt(MSE).
        print(f'MAE: {mean_absolute_error(ee, mu):.4f}, '
              f'RMSE: {np.sqrt(mean_squared_error(ee, mu)):.4f}, '
              f'R2_score: {r2_score(ee, mu):.4f}')
        print('RESULTS')
