import torch
from torch_geometric.nn import GATv2Conv, global_add_pool, global_max_pool, global_mean_pool
from torch.nn import Linear, ModuleList
import torch.nn.functional as F
from skopt import Optimizer
from skopt.utils import use_named_args
from torch.optim import Adam
import preprocessing_graph_new
import preprocessing_data_new
import numpy as np
num_classes = 2
target_column = 'ddg'
rel_pred = True  # predict a signed target relative to the absolute configuration (S keeps sign, R flips)

# Load and preprocess the CBS dataset: 2-class setup on the 'ddg' target,
# no augmentation, single threshold at 57.
cbs_df, num_cond = preprocessing_data_new.preprocess_cbs(data_path='../Database/CBS_10-04-2023.csv',
                                                         n_classes=num_classes,
                                                         target_column=target_column, aug_data=False, threshs=[57],
                                                         dropping_list=None, merge_columns=False,
                                                         verbose=True)
from sklearn.model_selection import train_test_split

# Hoisted out of the per-row lambda below: the original recomputed the column
# mean once per row, which is O(n^2) over the dataframe for no benefit.
mean_ee = cbs_df['ee'].mean()
print(mean_ee)
# One-hot class label: [0, 1] for ee <= mean, [1, 0] otherwise.
cbs_df['label2'] = cbs_df['ee'].apply(lambda x: np.array([0.0, 1.0]) if x <= mean_ee else np.array([1.0, 0.0]))
cbs_df['label2'] = cbs_df['label2'].apply(lambda x: [float(i) for i in x])
if rel_pred:
    # Sign convention: if AC column is S keep the value, if R negate it.
    cbs_df['ddg'] = cbs_df.apply(lambda x: x['ddg'] if x['AC'] == 'S' else -x['ddg'], axis=1)
    # same for ee
    cbs_df['ee'] = cbs_df.apply(lambda x: x['ee'] if x['AC'] == 'S' else -x['ee'], axis=1)

# NOTE(review): 'label2' holds Python lists; stratifying on unhashable list
# values is fragile in sklearn — verify this runs, or stratify on a scalar
# class index column instead.
train_data, val_data = train_test_split(cbs_df, test_size=0.2, random_state=1337, stratify=cbs_df['label2'])

def train(model, graph_loader, cond_loader, optimizer, epoch, cfg, loss_fn):
    """Run one training epoch over paired graph/condition loaders.

    The two loaders are iterated in lockstep; `loss_fn` receives the model
    logits and the argmax of the one-hot `label` tensor on the condition
    batch. `epoch` and `cfg` are accepted for interface compatibility but
    are not used inside this function.

    Returns:
        (mean batch loss, mean batch accuracy) over the epoch.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    for batch_graph, batch_cond in zip(graph_loader, cond_loader):
        # Move both batches to the module-level device (PyG moves tensors in place).
        batch_graph.to(device)
        batch_cond.to(device)
        optimizer.zero_grad()
        logits, _ = model(batch_graph, batch_cond)

        # Labels are one-hot; argmax over dim 1 recovers the class indices
        # expected by the loss function.
        target = torch.max(batch_cond.label, 1)[1]
        batch_loss = loss_fn(logits, target)
        total_loss += batch_loss.item()

        total_acc += calc_accuracy(logits, batch_cond.label)

        batch_loss.backward()
        optimizer.step()

    return total_loss / len(graph_loader), total_acc / len(cond_loader)

def val(model, graph_loader, cond_loader, cfg, loss_fn):
    """Evaluate the model over paired graph/condition loaders without gradients.

    Mirrors `train` minus the optimizer step; `cfg` is accepted for interface
    compatibility but unused.

    Returns:
        (mean batch loss, mean batch accuracy) over the validation set.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    with torch.no_grad():
        for batch_graph, batch_cond in zip(graph_loader, cond_loader):
            batch_graph.to(device)
            batch_cond.to(device)
            logits, _ = model(batch_graph, batch_cond)
            # One-hot labels -> class indices for the loss.
            target = torch.max(batch_cond.label, 1)[1]
            total_loss += loss_fn(logits, target).item()
            total_acc += calc_accuracy(logits, batch_cond.label)

    return total_loss / len(graph_loader), total_acc / len(cond_loader)

class xAI_Net(torch.nn.Module):
    """GATv2 graph encoder followed by a dense head with a scalar output.

    A stack of `n_gnn_layers` GATv2 convolutions (edge features of width 11,
    no self-loops) is pooled per graph ('mean', 'max', or sum for any other
    value of `pool_type`) and fed through `n_ffnn_layers` dense layers with
    dropout before a final 1-unit linear projection.

    Note: `lr` is stored on the module but not used by the forward pass.
    """

    def __init__(self, gnn_dim, n_gnn_layers, n_ffnn_neurons, n_ffnn_layers, pool_type, activation_function, lr):
        super().__init__()

        # Input widths are fixed by the preprocessing pipeline:
        # 79 node features, 11 edge features.
        num_features = 79
        edge_dim = 11

        gnn_stack = []
        for layer_idx in range(n_gnn_layers):
            in_dim = num_features if layer_idx == 0 else gnn_dim
            gnn_stack.append(GATv2Conv(in_dim, gnn_dim, edge_dim=edge_dim, add_self_loops=False))
        self.gnn_layers = ModuleList(gnn_stack)

        dense_stack = []
        for layer_idx in range(n_ffnn_layers):
            in_dim = gnn_dim if layer_idx == 0 else n_ffnn_neurons
            dense_stack.append(Linear(in_dim, n_ffnn_neurons))
        self.ffnn_layers = ModuleList(dense_stack)

        self.pool_type = pool_type
        self.activation = activation_function
        self.lr = lr

        # Final scalar head.
        self.mu = Linear(n_ffnn_neurons, 1)

    def forward(self, x, edge_index, batch, edge_weight=None):
        # GNN stack; attention weights are requested but discarded here.
        # NOTE: `edge_weight` is passed positionally as GATv2Conv's edge_attr,
        # so it is expected to carry the 11-dim edge features.
        for gat in self.gnn_layers:
            x, _ = gat(x, edge_index, edge_weight, return_attention_weights=True)
            x = self.activation(x)

        # Graph-level readout; anything other than 'mean'/'max' falls back to sum.
        if self.pool_type == 'mean':
            x = global_mean_pool(x, batch)
        elif self.pool_type == 'max':
            x = global_max_pool(x, batch)
        else:
            x = global_add_pool(x, batch)

        # Dense head with dropout (active only in training mode).
        for dense in self.ffnn_layers:
            x = F.dropout(self.activation(dense(x)), p=0.5, training=self.training)

        return self.mu(x)

# Validation uses a single batch spanning the whole validation split.
bs_val = len(val_data)
val_data = val_data.reset_index()
train_data = train_data.reset_index()
# NOTE(review): `best_parameters` is not defined anywhere above this point,
# so this line raises NameError at runtime. Presumably the result of a
# previous hyperparameter search should be loaded/defined first — confirm
# and define `best_parameters['bs']` before running this script.
train_gloader, _, _ = preprocessing_graph_new.my_graph_only_dataloader(train_data,
                                                                        target_column=target_column,
                                                                        bs=best_parameters['bs'])
val_gloader, _, _ = preprocessing_graph_new.my_graph_only_dataloader(val_data,
                                                                               target_column=target_column,
                                                                               bs=bs_val)

# Objective function for Bayesian Optimization
def objective(**params):
    """Train an xAI_Net with the suggested hyperparameters and return val MSE.

    Fixes over the previous version:
      * the search space includes 'bs' (a loader setting), which is not an
        xAI_Net constructor argument — it is removed before construction;
      * `n_epochs` and `train_loader` were undefined names — the module-level
        `train_gloader` / `val_gloader` are used and the epoch count is a
        local constant;
      * the returned MSE was a hard-coded 0.05, making the whole search a
        no-op — the actual validation loss is computed and returned.

    Args:
        **params: sampled hyperparameters (keys of `space`).

    Returns:
        float: mean validation MSE (the quantity skopt minimizes).
    """
    model_params = dict(params)
    model_params.pop('bs', None)  # loader batch size, not a model argument
    model = xAI_Net(**model_params)
    optimizer = Adam(model.parameters(), lr=params['lr'])
    criterion = torch.nn.MSELoss()
    n_epochs = 10  # TODO(review): expose as a config/search parameter

    model.train()
    for epoch in range(n_epochs):
        for data in train_gloader:
            optimizer.zero_grad()
            # forward signature is (x, edge_index, batch, edge_weight);
            # NOTE(review): assumes the graph batches carry `edge_attr` and a
            # scalar target `y` — confirm against my_graph_only_dataloader.
            output = model(data.x, data.edge_index, data.batch, data.edge_attr)
            loss = criterion(output.view(-1), data.y.view(-1))
            loss.backward()
            optimizer.step()

    # Evaluate on the validation loader without gradients.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_gloader:
            output = model(data.x, data.edge_index, data.batch, data.edge_attr)
            val_loss += criterion(output.view(-1), data.y.view(-1)).item()

    return val_loss / max(len(val_gloader), 1)

# Hyperparameter space for skopt: int/float 2-tuples are interpreted as
# Integer/Real dimensions, lists as Categorical dimensions.
space = {
    'gnn_dim': (32, 512),          # hidden width of each GATv2 layer
    'n_gnn_layers': (1, 5),
    'n_ffnn_neurons': (32, 512),   # hidden width of each dense layer
    'n_ffnn_layers': (1, 5),
    'pool_type': ['mean', 'max', 'add'],
    'activation_function': [F.relu, F.leaky_relu],
    'lr': (1e-4, 1e-2),
    'bs': (4, 32)                  # loader batch size (not a model argument)
}

# Ask/tell Bayesian-optimization driver over a Gaussian-process surrogate.
# Dimensions are passed in the insertion order of `space`, so suggestions
# can be zipped back onto the same keys.
optimizer = Optimizer(
    dimensions=list(space.values()),
    random_state=1,
    base_estimator='GP'
)

for iteration in range(50):  # Number of iterations
    suggestion = optimizer.ask()
    # Map the positional suggestion back onto named hyperparameters.
    named_params = dict(zip(space.keys(), suggestion))
    mse = objective(**named_params)
    optimizer.tell(suggestion, mse)