import torch
from torch_geometric.nn import GATv2Conv, global_add_pool, global_max_pool, global_mean_pool
from torch.nn import Linear, ModuleList
import torch.nn.functional as F

class xAI_Net(torch.nn.Module):
    """GATv2-based graph regression model.

    A stack of GATv2Conv layers produces node embeddings, which are pooled
    to a graph-level vector (mean/max/add) and passed through a feed-forward
    head ending in a single scalar output per graph.
    """

    def __init__(self, gnn_dim, n_gnn_layers, n_ffnn_neurons, n_ffnn_layers, pool_type, activation_function):
        """
        Args:
            gnn_dim: hidden size of every GATv2Conv layer.
            n_gnn_layers: number of stacked GATv2Conv layers.
            n_ffnn_neurons: hidden size of the feed-forward head.
            n_ffnn_layers: number of hidden layers in the head.
            pool_type: 'mean', 'max', or anything else for add pooling.
            activation_function: callable applied after every layer (e.g. F.relu).
        """
        super().__init__()

        # Input node/edge feature sizes are fixed by the dataset featurization.
        num_features = 79
        edge_dim = 11

        # First GNN layer maps raw node features to gnn_dim; the rest are gnn_dim -> gnn_dim.
        self.gnn_layers = ModuleList([
            GATv2Conv(
                num_features if i == 0 else gnn_dim,
                gnn_dim,
                edge_dim=edge_dim,
                add_self_loops=False,
            )
            for i in range(n_gnn_layers)
        ])

        # Feed-forward head: first layer consumes the pooled graph embedding.
        self.ffnn_layers = ModuleList([
            Linear(gnn_dim if i == 0 else n_ffnn_neurons, n_ffnn_neurons)
            for i in range(n_ffnn_layers)
        ])

        self.pool_type = pool_type
        self.activation = activation_function

        # Final scalar regression output.
        self.mu = Linear(n_ffnn_neurons, 1)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Predict one scalar per graph in the batch.

        Args:
            x: node feature matrix, shape [num_nodes, 79].
            edge_index: COO edge indices, shape [2, num_edges].
            batch: graph-assignment vector mapping each node to its graph.
            edge_weight: edge feature matrix passed to GATv2Conv as `edge_attr`
                (the layers are built with edge_dim=11, so this is expected to
                have 11 features per edge despite the parameter name).

        Returns:
            Tensor of shape [num_graphs, 1].
        """
        for conv in self.gnn_layers:
            # NOTE: the original requested return_attention_weights=True and
            # discarded them; calling without it yields the same node features
            # while skipping the attention-weight materialization.
            x = conv(x, edge_index, edge_weight)
            x = self.activation(x)

        # Graph-level readout; anything other than 'mean'/'max' falls back to add.
        if self.pool_type == 'mean':
            x = global_mean_pool(x, batch)
        elif self.pool_type == 'max':
            x = global_max_pool(x, batch)
        else:
            x = global_add_pool(x, batch)

        for lin in self.ffnn_layers:
            x = self.activation(lin(x))
            x = F.dropout(x, p=0.5, training=self.training)

        mu = self.mu(x)
        return mu

class FFNN_Net(torch.nn.Module):
    """Plain feed-forward regression network over fixed-size inputs.

    A stack of Linear layers with a user-supplied activation and dropout
    (p=0.5, active only in training mode), followed by a single-unit
    output layer producing one scalar per sample.
    """

    def __init__(self, n_ffnn_neurons, n_ffnn_layers, activation_function, fp_dim):
        """
        Args:
            n_ffnn_neurons: width of every hidden layer.
            n_ffnn_layers: number of hidden layers.
            activation_function: callable applied after each hidden layer.
            fp_dim: input feature dimension (first layer's fan-in).
        """
        super(FFNN_Net, self).__init__()

        # Layer 0 maps fp_dim -> n_ffnn_neurons; subsequent layers are square.
        widths = [fp_dim] + [n_ffnn_neurons] * n_ffnn_layers
        self.ffnn_layers = ModuleList(
            Linear(fan_in, n_ffnn_neurons) for fan_in in widths[:-1]
        )
        self.activation = activation_function

        # Scalar regression head.
        self.mu = Linear(n_ffnn_neurons, 1)

    def forward(self, x):
        """Run the MLP; returns a tensor of shape [batch, 1]."""
        hidden = x
        for layer in self.ffnn_layers:
            hidden = self.activation(layer(hidden))
            hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.mu(hidden)