import preprocessing_data_new
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import LeaveOneOut
import config
import torch
import pandas as pd 
import preprocessing_graph_new
from sklearn.metrics import r2_score
from datetime import date
from models import Net

def train(model, loader, optimizer, loss_fn, device='cuda'):
    """Run one optimization epoch over ``loader`` and return the mean batch loss.

    The model is put in training mode; each batch is moved to ``device``,
    a forward/backward pass is done and the optimizer steps once per batch.
    """
    model.train()
    total_loss = 0.0
    for data in loader:
        data.to(device)
        optimizer.zero_grad()
        # Model returns (prediction, aux1, aux2); only the prediction is used here.
        out, _, _ = model(data.x, data.edge_index, data.batch, data.edge_attr)
        batch_loss = loss_fn(out, data.ee)
        batch_loss.backward()
        total_loss += batch_loss.item()
        optimizer.step()
    return total_loss / len(loader)

def val(model, loader, loss_fn, device='cuda'):
    """Evaluate ``model`` on ``loader`` and return the mean batch loss.

    Runs in eval mode with gradients disabled; no parameters are updated.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in loader:
            data.to(device)
            # Model returns (prediction, aux1, aux2); only the prediction is scored.
            out, _, _ = model(data.x, data.edge_index, data.batch, data.edge_attr)
            total_loss += loss_fn(out, data.ee).item()
    return total_loss / len(loader)

def get_predictions(model, loader, device='cuda'):
    """Collect flattened predictions and ground-truth targets from ``loader``.

    Returns a pair of plain Python float lists ``(preds, trues)`` in loader
    order, suitable for metrics such as r2_score or plotting.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for data in loader:
            data.to(device)
            out, _, _ = model(data.x, data.edge_index, data.batch, data.edge_attr)
            # Flatten to 1-D and convert to Python floats before accumulating.
            predictions += out.view(-1).tolist()
            targets += data.ee.view(-1).tolist()
    return predictions, targets


def main():
    """Leave-one-out cross-validation of the GNN regressor on the CBS dataset.

    For every LOO split a fresh model is trained with early stopping; a split
    is retried (up to 10 attempts) until the training-set R^2 reaches 0.9.
    Predictions from successful splits are pooled, written to CSV files and
    plotted. Side effects: writes checkpoints, scatter plots and result CSVs
    into ``cv_models/``.
    """
    best_parameters = config.best_parameters
    date_ = date.today()
    failed_runs = []   # LOO run numbers where all attempts failed
    epochs = 5000      # maximum training epochs per attempt
    delta = 1e-5       # minimal val-loss improvement that resets patience
    patience = 200     # early-stopping patience, in epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loss_fn = torch.nn.MSELoss()
    num_classes = 2
    target_column = 'ddg'
    rel_pred = True    # convert targets to signed values based on the AC column
    cbs_df, num_cond = preprocessing_data_new.preprocess_cbs(data_path='../Database/CBS_10-04-2023.csv',
                                                            n_classes=num_classes,
                                                            target_column=target_column, aug_data=False, threshs=[56],
                                                            dropping_list=None, merge_columns=False,
                                                            verbose=True)
    print(cbs_df['ee'].mean())
    # Binary one-hot label: ee at/below the mean -> [0, 1], above -> [1, 0].
    cbs_df['label2'] = cbs_df['ee'].apply(lambda x: np.array([0.0, 1.0]) if x <= cbs_df['ee'].mean() else np.array([1.0, 0.0]))
    cbs_df['label2'] = cbs_df['label2'].apply(lambda x: [float(i) for i in x])

    if rel_pred:
        # Sign convention: AC == 'S' keeps the value, anything else flips it.
        cbs_df['ddg'] = cbs_df.apply(lambda x: x['ddg'] if x['AC'] == 'S' else -x['ddg'], axis=1)
        # same for ee
        cbs_df['ee'] = cbs_df.apply(lambda x: x['ee'] if x['AC'] == 'S' else -x['ee'], axis=1)

    # Perform LOOCV
    loo = LeaveOneOut()

    train_preds = []
    val_preds = []
    train_trues = []
    val_trues = []

    print('----------------')
    print(f'Starting CV')

    run_counter = 1
    for train_index, test_index in loo.split(cbs_df):
        try_counter = 0
        r2_value = 0
        # Retry the split until the training-set fit is good enough (R^2 >= 0.9)
        # or the attempt budget is exhausted.
        while r2_value < 0.9 and try_counter < 10:
            print(f'Starting run: {run_counter}')
            model = Net(gnn_dim=best_parameters['embedding_size'],
                        n_gnn_layers=best_parameters['n_gnn_layers'],
                        n_ffnn_neurons=best_parameters['hl_size'],
                        n_ffnn_layers=best_parameters['n_hl'],
                        pool_type=best_parameters['pooling'],
                        activation_function=best_parameters['activation']
                        ).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

            train_df, val_df = cbs_df.iloc[train_index], cbs_df.iloc[test_index]

            train_gloader, train_dloader, train_graph_list = preprocessing_graph_new.my_graph_only_dataloader(train_df,
                                                                                            target_column=target_column,
                                                                                            bs=best_parameters['bs'])
            val_gloader, val_dloader, val_graph_list = preprocessing_graph_new.my_graph_only_dataloader(val_df,
                                                                                        target_column=target_column,
                                                                                        bs=1)
            best_loss = np.inf
            counter = 0  # fix: initialize patience counter explicitly (NameError if the first epoch never improves, e.g. NaN loss)
            for epoch in range(epochs):  # fix: honour the `epochs` setting instead of a hard-coded literal
                loss = train(model, train_gloader, optimizer, loss_fn, device)
                val_loss = val(model, val_gloader, loss_fn, device)

                if epoch % 5 == 0:
                    print(f"< Epoch : {epoch} ||  Loss : {loss} || <<>> Val-Loss : {val_loss}")

                # Early stopping: keep the checkpoint with the best val loss seen so far.
                if val_loss + delta < best_loss:
                    best_loss = val_loss
                    counter = 0
                    torch.save(model.state_dict(), 'cv_models/model_oli_loocv.pt')  # Save the best model
                else:
                    counter += 1
                    if counter >= patience:
                        print(f'Validation loss did not improve for {patience} epochs. Training stopped.')
                        break
            # Re-create the train loader with bs=1 so predictions come back one sample at a time.
            train_gloader, train_dloader, train_graph_list = preprocessing_graph_new.my_graph_only_dataloader(train_df,
                                                                                            target_column=target_column,
                                                                                            bs=1)
            # Restore the best checkpoint before scoring.
            model.load_state_dict(torch.load(f"cv_models/model_oli_loocv.pt"))

            train_pred, train_true = get_predictions(model, train_gloader, device)
            val_pred, val_true = get_predictions(model, val_gloader, device)

            # fix: sklearn's signature is r2_score(y_true, y_pred) and R^2 is not
            # symmetric — the original call had the arguments swapped.
            r2_value = r2_score(train_true, train_pred)
            try_counter += 1

        # fix: judge by the actual fit quality, not the attempt count — the old
        # `try_counter >= 10` check misclassified a run that succeeded on the
        # 10th attempt as failed.
        if r2_value < 0.9:
            print('Training failed')
            failed_runs.append(run_counter)
        else:
            print('Training successfull')

            train_preds.extend(train_pred)
            train_trues.extend(train_true)

            val_preds.extend(val_pred)
            val_trues.extend(val_true)

            # Per-run actual-vs-predicted scatter plot.
            plt.scatter(train_true, train_pred, label='train', alpha=0.3)
            plt.scatter(val_true, val_pred, label='val')
            plt.plot([0, 4], [0, 4], '--', color='r', label='optimal')
            plt.xlabel("Actual Values")
            plt.ylabel("Predicted Values")
            plt.legend()
            plt.savefig(f"cv_models/scatter_{run_counter}.png")
            plt.close()

        run_counter += 1
    print('----------------')

    data = {'train_true': train_trues,
            'train_pred': train_preds}
    df = pd.DataFrame(data)
    df.to_csv(f'cv_models/train_results_{date_}.csv', index=False)

    data = {'val_true': val_trues,
            'val_pred': val_preds}
    df = pd.DataFrame(data)
    df.to_csv(f'cv_models/val_results_{date_}.csv', index=False)

    # Pooled actual-vs-predicted scatter plot over all successful runs.
    plt.scatter(train_trues, train_preds, label='train', alpha=0.3)
    plt.scatter(val_trues, val_preds, label='val')
    plt.plot([0, 4], [0, 4], '--', color='r', label='optimal')
    plt.xlabel("Actual Values")
    plt.ylabel("Predicted Values")
    plt.legend()
    plt.savefig(f"cv_models/scatter_{date_}.png")

    print(f'Failed runs: {failed_runs}')
# Run the full LOOCV pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()