import preprocessing_data_new
import numpy as np
import matplotlib.pyplot as plt
import config
import torch
import pandas as pd 
import preprocessing_graph_new
from sklearn.metrics import r2_score
from datetime import date
from models import Net

def train(model, loader, optimizer, loss_fn, device='cuda'):
    """Run one training epoch over ``loader`` and return the mean per-batch loss.

    Args:
        model: network called as ``model(x, edge_index, batch, edge_attr)``; only
            the first of the three returned values is used as the prediction.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``edge_attr`` and the target ``ee``, plus a ``to(device)`` method.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        The unweighted average of the per-batch loss values (batches of different
        sizes count equally), or ``nan`` if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        # Keep the returned object: Tensor.to is not in-place, so relying on the
        # batch mutating itself is fragile.
        batch = batch.to(device)
        optimizer.zero_grad()
        pred, _, _ = model(batch.x,
                           batch.edge_index,
                           batch.batch,
                           batch.edge_attr)
        # NOTE(review): if pred is (B, 1) while batch.ee is (B,), MSELoss broadcasts
        # to (B, B) and the loss is wrong -- confirm the model's output shape.
        loss = loss_fn(pred, batch.ee)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        n_batches += 1

    # Counting batches (instead of len(loader)) works for any iterable and avoids
    # a ZeroDivisionError when the loader is empty.
    return total_loss / n_batches if n_batches else float('nan')

def val(model, loader, loss_fn, device='cuda'):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean per-batch loss.

    Args:
        model: network called as ``model(x, edge_index, batch, edge_attr)``; only
            the first of the three returned values is used as the prediction.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``edge_attr`` and the target ``ee``, plus a ``to(device)`` method.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        The unweighted average of the per-batch loss values, or ``nan`` if
        ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            # Keep the returned object: Tensor.to is not in-place.
            batch = batch.to(device)
            pred, _, _ = model(batch.x,
                               batch.edge_index,
                               batch.batch,
                               batch.edge_attr)
            total_loss += loss_fn(pred, batch.ee).item()
            n_batches += 1

    # Avoid ZeroDivisionError for an empty loader; nan never counts as an
    # improvement in a "val_loss + delta < best" comparison.
    return total_loss / n_batches if n_batches else float('nan')

def get_predictions(model, loader, device='cuda'):
    """Collect flattened predictions and targets over ``loader`` without gradients.

    Args:
        model: network called as ``model(x, edge_index, batch, edge_attr)``; only
            the first of the three returned values is used as the prediction.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``edge_attr`` and the target ``ee``, plus a ``to(device)`` method.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(preds, trues)``: two lists of Python floats in loader order. Every
        element of each batch's prediction / ``ee`` tensor contributes one entry.
    """
    model.eval()
    preds = []
    trues = []
    with torch.no_grad():
        for batch in loader:
            # Keep the returned object: Tensor.to is not in-place.
            batch = batch.to(device)
            pred, _, _ = model(batch.x,
                               batch.edge_index,
                               batch.batch,
                               batch.edge_attr)
            # reshape (unlike view) also works on non-contiguous outputs; tolist()
            # copies the values to Python floats on the host.
            preds.extend(pred.reshape(-1).tolist())
            trues.extend(batch.ee.reshape(-1).tolist())

    return preds, trues


def _build_model(best_parameters, device):
    """Return a freshly initialised ``Net`` configured from ``best_parameters`` on ``device``."""
    return Net(gnn_dim=best_parameters['embedding_size'],
               n_gnn_layers=best_parameters['n_gnn_layers'],
               n_ffnn_neurons=best_parameters['hl_size'],
               n_ffnn_layers=best_parameters['n_hl'],
               pool_type=best_parameters['pooling'],
               activation_function=best_parameters['activation'],
               ).to(device)


def _fit_with_early_stopping(model, train_loader, val_loader, loss_fn, device,
                             epochs, patience, delta, checkpoint_path):
    """Train ``model`` with Adam (lr=1e-4) and early stopping, leaving it at its best weights.

    An epoch counts as an improvement when its validation loss beats the best so
    far by more than ``delta``. Training stops after ``patience`` consecutive
    epochs without an improvement, or after ``epochs`` epochs. Each new best
    state is also saved to ``checkpoint_path``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    best_loss = np.inf
    best_state = None
    # Bound up front: it used to exist only after the first improvement, so a NaN
    # loss on the first epoch raised UnboundLocalError.
    counter = 0
    for epoch in range(epochs):
        loss = train(model, train_loader, optimizer, loss_fn, device)
        val_loss = val(model, val_loader, loss_fn, device)

        if epoch % 5 == 0:
            print(f"< Epoch : {epoch} ||  Loss : {loss} || <<>> Val-Loss : {val_loss}")

        if val_loss + delta < best_loss:
            best_loss = val_loss
            counter = 0
            # In-memory copy: reloading from disk could pick up a stale checkpoint
            # written by an earlier fold if this run never improved.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), checkpoint_path)  # Save the best model
        else:
            counter += 1
            if counter >= patience:
                print(f'Validation loss did not improve for {patience} epochs. Training stopped.')
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def _plot_parity(true_train, pred_train, true_val, pred_val, path):
    """Save a predicted-vs-actual scatter plot of train and val values to ``path``.

    The dashed identity line spans the range of all plotted values. It used to be
    fixed to [0, 4], which can miss the negative values the ``rel_pred`` sign flip
    produces. The figure is closed afterwards so the next plot starts on a clean
    canvas.
    """
    values = [*true_train, *pred_train, *true_val, *pred_val]
    lo, hi = (min(values), max(values)) if values else (0, 4)
    plt.scatter(true_train, pred_train, label='train', alpha=0.3)
    plt.scatter(true_val, pred_val, label='val')
    plt.plot([lo, hi], [lo, hi], '--', color='r', label='optimal')
    plt.xlabel("Actual Values")
    plt.ylabel("Predicted Values")
    plt.legend()
    plt.savefig(path)
    plt.close()


def main():
    """Run leave-one-group-out cross-validation of ``Net`` on the CBS dataset.

    The grouping columns are ``'r1'`` and then ``'cat'``. For each one, every
    distinct value is held out in turn as the validation fold. A fresh model is
    trained on all other rows with early stopping. Training is retried (at most
    ``max_tries`` times) until the R² on the training rows reaches
    ``r2_threshold``. Per-fold and per-mode parity plots, plus per-mode CSVs of
    train/val predictions, are written to ``cv_models/``.
    """
    from pathlib import Path

    best_parameters = config.best_parameters
    date_ = date.today()
    failed_runs = []
    epochs = 5000
    delta = 1e-5
    patience = 200
    max_tries = 10
    r2_threshold = 0.9
    out_dir = Path('cv_models')
    checkpoint_path = out_dir / 'model_oli_catsubs.pt'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loss_fn = torch.nn.MSELoss()
    num_classes = 2
    target_column = 'ddg'
    rel_pred = True

    # torch.save / plt.savefig / to_csv all fail if the output directory is missing.
    out_dir.mkdir(parents=True, exist_ok=True)

    cbs_df, _ = preprocessing_data_new.preprocess_cbs(data_path='../Database/CBS_10-04-2023.csv',
                                                      n_classes=num_classes,
                                                      target_column=target_column, aug_data=False, threshs=[56],
                                                      dropping_list=None, merge_columns=False,
                                                      verbose=True)
    # Computed once: the original re-evaluated the column mean for every row.
    ee_mean = cbs_df['ee'].mean()
    print(ee_mean)
    # Two-element float label: [0.0, 1.0] at or below the mean ee, [1.0, 0.0] above it.
    cbs_df['label2'] = cbs_df['ee'].apply(lambda x: [0.0, 1.0] if x <= ee_mean else [1.0, 0.0])

    if rel_pred:
        # Keep ddg/ee unchanged where AC == 'S' and negate them everywhere else.
        # NOTE(review): an earlier comment described the opposite convention
        # (S negated, R kept) -- confirm which sign convention is intended.
        is_s = cbs_df['AC'] == 'S'
        cbs_df['ddg'] = np.where(is_s, cbs_df['ddg'], -cbs_df['ddg'])
        cbs_df['ee'] = np.where(is_s, cbs_df['ee'], -cbs_df['ee'])

    # Distinct values of each grouping column; each one becomes a held-out fold.
    groups = {'r1': cbs_df['r1'].unique(), 'cat': cbs_df['cat'].unique()}
    print(groups['r1'], groups['cat'])
    print('----------------')
    print('Starting CV')
    # Leave-one-group-out CV (one fold per distinct value), not 10-fold.
    for mode, mol_list in groups.items():
        train_preds, train_trues = [], []
        val_preds, val_trues = [], []
        # One counter per mode, numbering the folds. It used to be reset inside the
        # molecule loop, so every per-fold plot overwrote scatter_1_<mode>.png.
        run_counter = 1

        for molecule in mol_list:
            # Boolean masks instead of `.iloc[<index labels>]`: iloc is positional,
            # so the old split selected the wrong rows whenever cbs_df's index was
            # not exactly 0..n-1 (e.g. after rows were dropped in preprocessing).
            held_out = cbs_df[mode] == molecule
            if not held_out.any():
                # e.g. a NaN group value, which never compares equal to itself.
                continue
            train_df, val_df = cbs_df[~held_out], cbs_df[held_out]

            # The fold's data is the same across retries, so build the loaders once.
            train_gloader, _, _ = preprocessing_graph_new.my_graph_only_dataloader(
                train_df, target_column=target_column, bs=best_parameters['bs'])
            train_eval_gloader, _, _ = preprocessing_graph_new.my_graph_only_dataloader(
                train_df, target_column=target_column, bs=1)
            val_gloader, _, _ = preprocessing_graph_new.my_graph_only_dataloader(
                val_df, target_column=target_column, bs=1)

            r2_value = -np.inf
            try_counter = 0
            while r2_value < r2_threshold and try_counter < max_tries:
                print(f'Starting run: {run_counter}')
                model = _build_model(best_parameters, device)
                # NOTE(review): early stopping selects weights using the held-out
                # fold, which is also the fold reported as "val", so the val metrics
                # are optimistically biased. Consider an inner split of train_df.
                _fit_with_early_stopping(model, train_gloader, val_gloader, loss_fn, device,
                                         epochs, patience, delta, checkpoint_path)

                train_pred, train_true = get_predictions(model, train_eval_gloader, device)
                val_pred, val_true = get_predictions(model, val_gloader, device)

                # sklearn's signature is r2_score(y_true, y_pred). The arguments were
                # swapped before, and R² is not symmetric in them.
                r2_value = r2_score(train_true, train_pred)
                try_counter += 1

            if r2_value < r2_threshold:
                print('Training failed')
                failed_runs.append((mode, molecule))
            else:
                print('Training successfull')

            train_preds.extend(train_pred)
            train_trues.extend(train_true)
            val_preds.extend(val_pred)
            val_trues.extend(val_true)

            _plot_parity(train_true, train_pred, val_true, val_pred,
                         out_dir / f"scatter_{run_counter}_{mode}.png")
            run_counter += 1

        print('----------------')

        pd.DataFrame({'train_true': train_trues,
                      'train_pred': train_preds}).to_csv(
            out_dir / f'train_results_{date_}_{mode}.csv', index=False)
        pd.DataFrame({'val_true': val_trues,
                      'val_pred': val_preds}).to_csv(
            out_dir / f'val_results_{date_}_{mode}.csv', index=False)

        # _plot_parity closes the figure; the original left this aggregate plot open,
        # so its points leaked into the next mode's plots.
        _plot_parity(train_trues, train_preds, val_trues, val_preds,
                     out_dir / f"scatter_{date_}_{mode}.png")

        print(f'Failed runs: {failed_runs}')
if __name__ == '__main__':
    # Run the cross-validation only when executed as a script, not on import.
    main()