import warnings
warnings.filterwarnings(action='once')

import torch
from torch.nn import ModuleList, Dropout, Linear
from torch_geometric.nn import CGConv, GCNConv, TransformerConv, AttentiveFP
from torch_geometric.nn import global_add_pool as gap
from torch_geometric.nn import global_mean_pool as gmp
from torch_geometric.nn import global_max_pool as gmaxp

class OneToRuleThemAll(torch.nn.Module):
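    """Configurable model that fuses a knn-graph GNN branch, a molecular-graph GNN
    branch, an AttentiveFP embedding, a molecular fingerprint and extra energy
    features, and predicts a mean and a variance for a single regression target."""
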
    def __init__(self,
                #GNNS 
                gnn_layer_type,
                mol_gnn_layer_type,
                num_gnn_layers,
                num_mol_gnn_layers,
                gnn_output_size,
                mol_gnn_output_size,
                gnn_heads,
                mol_gnn_heads,
                gnn_activation,
                mol_gnn_activation,
                #AFP
                attentionlayer_size,
                embedding_size,
                num_afp_layers,
                num_timesteps,
                afp_dropout,
                #DNN
                dnn_activation,
                num_hl,
                num_gnn_hl,
                num_mol_gnn_hl,
                gnn_hl_size,
                mol_gnn_hl_size,
                hl_size,
                hl_dropout,
                #Pooling
                mol_max_pool = True,
                mol_mean_pool = True,
                mol_add_pool = True,
                max_pool = True,
                mean_pool = True,
                add_pool = True,
                #Architecture
                triangle = False,
                invariant_hl_size = True,
                u_net_type = True,
                mol_u_net_type = True,
                seperated_dnn = False,
                #Input dimensions
                node_dim=78,
                edge_dim=10,
                extra_energies=1
                ):
        super(OneToRuleThemAll, self).__init__()    

        self.gnn_layer_type = gnn_layer_type
        self.mol_gnn_layer_type = mol_gnn_layer_type  
         
        self.num_mol_gnn_layers = num_mol_gnn_layers
        self.num_gnn_layers = num_gnn_layers
        
        self.gnn_output_size = gnn_output_size
        self.mol_gnn_output_size = mol_gnn_output_size
        
        self.gnn_heads = gnn_heads
        self.mol_gnn_heads = mol_gnn_heads
        
        self.gnn_hl_size = gnn_hl_size 
        self.mol_gnn_hl_size = mol_gnn_hl_size 
        self.num_gnn_hl = num_gnn_hl
        self.num_mol_gnn_hl = num_mol_gnn_hl

        self.gnn_activation = gnn_activation 
        self.mol_gnn_activation = mol_gnn_activation
        self.dnn_activation = dnn_activation
        
        self.hl_dropout = Dropout(p=hl_dropout)
        self.num_hl = num_hl
        self.hl_size = hl_size
        
        self.mol_max_pool = mol_max_pool
        self.mol_mean_pool = mol_mean_pool
        self.mol_add_pool = mol_add_pool

        self.max_pool = max_pool
        self.mean_pool = mean_pool
        self.add_pool = add_pool        

        # make sure at least one pooling is applied; default to max pooling if none was selected
        if not mol_max_pool and not mol_mean_pool and not mol_add_pool:
            self.mol_max_pool = True
            
        if not max_pool and not mean_pool and not add_pool:
            self.max_pool = True
        
        self.triangle = triangle
        
        self.u_net_type = u_net_type
        self.mol_u_net_type = mol_u_net_type
        self.seperated_dnn = seperated_dnn
        
        # AttentiveFingerprint
        # ------------------------------------------
        self.afp = AttentiveFP(in_channels=node_dim,
                               hidden_channels=attentionlayer_size,
                               out_channels=embedding_size,
                               edge_dim=edge_dim,
                               num_layers=num_afp_layers,
                               num_timesteps=num_timesteps,
                               dropout=afp_dropout)
        # ------------------------------------------

        # knn part
        # ------------------------------------------
        self.gnn_layers = ModuleList([])
        # supported layer types: CGConv, GCNConv, TransformerConv
        if gnn_layer_type in [CGConv]:
            # CGConv preserves the node feature dimension, so its output size is node_dim
            self.gnn_output_size = node_dim
            self.g1 = gnn_layer_type(node_dim)
            for _ in range(self.num_gnn_layers):
                self.gnn_layers.append(gnn_layer_type(node_dim))  
        elif gnn_layer_type in [GCNConv]:
            self.g1 = gnn_layer_type(in_channels=node_dim,
                                     out_channels=gnn_output_size)            
            for _ in range(self.num_gnn_layers):
                self.gnn_layers.append(gnn_layer_type(in_channels=gnn_output_size,
                                                    out_channels=gnn_output_size))              
        elif gnn_layer_type in [TransformerConv]:
            self.g1 = gnn_layer_type(in_channels=node_dim,
                                     out_channels=gnn_output_size,
                                     heads=gnn_heads,
                                     concat=False)            
            for _ in range(self.num_gnn_layers):
                self.gnn_layers.append(gnn_layer_type(in_channels=gnn_output_size,
                                                    out_channels=gnn_output_size,
                                                    heads=gnn_heads,
                                                    concat=False))
        else:
            raise ValueError(f"Unsupported gnn_layer_type: {gnn_layer_type}")
        # ------------------------------------------

        # mol graph part
        # ------------------------------------------
        self.mol_gnn_layers = ModuleList([])
        
        if mol_gnn_layer_type in [CGConv]:
            self.mol_gnn_output_size = node_dim
            self.mol_g1 = mol_gnn_layer_type(node_dim, edge_dim)
            for _ in range(self.num_mol_gnn_layers):
                self.mol_gnn_layers.append(mol_gnn_layer_type(node_dim, edge_dim))  
        elif mol_gnn_layer_type in [GCNConv]:
            self.mol_g1 = mol_gnn_layer_type(in_channels=node_dim,
                                     out_channels=mol_gnn_output_size)            
            for _ in range(self.num_mol_gnn_layers):
                self.mol_gnn_layers.append(mol_gnn_layer_type(in_channels=mol_gnn_output_size,
                                                    out_channels=mol_gnn_output_size))                           
        elif mol_gnn_layer_type in [TransformerConv]:
            self.mol_g1 = mol_gnn_layer_type(in_channels=node_dim,
                                     out_channels=mol_gnn_output_size,
                                     heads=mol_gnn_heads,
                                     concat=False,
                                     edge_dim=edge_dim)            
            for _ in range(self.num_mol_gnn_layers):
                self.mol_gnn_layers.append(mol_gnn_layer_type(in_channels=mol_gnn_output_size,
                                                    out_channels=mol_gnn_output_size,
                                                    heads=mol_gnn_heads,
                                                    concat=False,
                                                    edge_dim=edge_dim))
        else:
            raise ValueError(f"Unsupported mol_gnn_layer_type: {mol_gnn_layer_type}")
        # ------------------------------------------
        
        # Utils calculations
        # ------------------------------------------
        gnn_poolings = sum(bool(x) for x in [self.max_pool, self.mean_pool, self.add_pool])
        mol_gnn_poolings = sum(bool(x) for x in [self.mol_max_pool, self.mol_mean_pool, self.mol_add_pool])
        # the concatenated readout size depends on the number of poolings,
        # the GNN output size and the GNN type (CGConv keeps the node dimension)
        if gnn_layer_type in [CGConv]:
            total_output_size_gnn = node_dim * gnn_poolings
        else:
            total_output_size_gnn = self.gnn_output_size * gnn_poolings

        if mol_gnn_layer_type in [CGConv]:
            total_output_size_mol_gnn = node_dim * mol_gnn_poolings
        else:
            total_output_size_mol_gnn = self.mol_gnn_output_size * mol_gnn_poolings
    
        # with U-net style readout every layer contributes one readout block,
        # so the concatenated output grows with the number of layers
        if u_net_type:
            total_output_size_gnn *= num_gnn_layers

        if mol_u_net_type:
            total_output_size_mol_gnn *= num_mol_gnn_layers

        # ------------------------------------------
            
        # DNN part
        # ------------------------------------------
        if seperated_dnn:
            # separate pre-processing MLPs for the knn-graph and mol-graph readouts
            self.dnn = ModuleList([])
            self.mol_dnn = ModuleList([])

            self.knng2f = Linear(total_output_size_gnn, gnn_hl_size)
            for _ in range(self.num_gnn_hl):
                self.dnn.append(Linear(gnn_hl_size, gnn_hl_size))

            self.molg2f = Linear(total_output_size_mol_gnn, mol_gnn_hl_size)
            for _ in range(self.num_mol_gnn_hl):
                self.mol_dnn.append(Linear(mol_gnn_hl_size, mol_gnn_hl_size))

            combined_input_size = gnn_hl_size + mol_gnn_hl_size + embedding_size + extra_energies + 1024
        else:
            combined_input_size = (total_output_size_gnn + total_output_size_mol_gnn
                                   + embedding_size + extra_energies + 1024)

        # combined input: knn readout, mol-graph readout, AFP embedding,
        # the 1024-bit molecular fingerprint and the extra energies
        self.comb_dnn = ModuleList([])
        self.g2f = Linear(combined_input_size, hl_size)

        for i in range(self.num_hl):
            if triangle:
                # shrink the hidden layers progressively: hl_size, hl_size/2, hl_size/3, ...
                self.comb_dnn.append(Linear(round(hl_size / (i + 1)), round(hl_size / (i + 2))))
            else:
                self.comb_dnn.append(Linear(hl_size, hl_size))

        # the mu/var heads take the size of the last hidden layer
        final_size = round(hl_size / (self.num_hl + 1)) if triangle else hl_size
        self.mu = Linear(final_size, 1)
        self.var = Linear(final_size, 1)
        # ------------------------------------------

    def forward(self, graph, knngraph):
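        """Return the predicted mean and variance for each molecule in the batch.

        `graph` is the molecular graph batch (expects `x`, `edge_index`, `edge_attr`,
        `batch`, `fp` and `extras`); `knngraph` is the matching knn-graph batch
        (expects `x`, `edge_index` and `batch`)."""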
        # per-molecule auxiliary inputs stored on the data object:
        # `extras` holds the extra energies and `fp` the molecular fingerprint
        extras = graph.extras
        fp = graph.fp

        # knn part
        # ------------------------------------------
        knn_emb = knngraph.x
        knn_emb = self.g1(knn_emb, knngraph.edge_index)

        # collect the pooled readouts in a list and concatenate once at the end;
        # this avoids pre-allocating device-specific empty tensors
        readouts = []

        if self.u_net_type:
            # U-net style readout: pool after every GNN layer and keep all readouts
            for i in range(self.num_gnn_layers):
                knn_emb = self.gnn_activation(self.gnn_layers[i](knn_emb, knngraph.edge_index))
                if self.max_pool:
                    readouts.append(gmaxp(knn_emb, knngraph.batch))
                if self.mean_pool:
                    readouts.append(gmp(knn_emb, knngraph.batch))
                if self.add_pool:
                    readouts.append(gap(knn_emb, knngraph.batch))
        else:
            # pool only after the last GNN layer
            for i in range(self.num_gnn_layers):
                knn_emb = self.gnn_activation(self.gnn_layers[i](knn_emb, knngraph.edge_index))
            if self.max_pool:
                readouts.append(gmaxp(knn_emb, knngraph.batch))
            if self.mean_pool:
                readouts.append(gmp(knn_emb, knngraph.batch))
            if self.add_pool:
                readouts.append(gap(knn_emb, knngraph.batch))

        knn_emb = torch.cat(readouts, dim=1)
        # ------------------------------------------

        # mol gnn part
        # ------------------------------------------
        graph_emb = graph.x
        if self.mol_gnn_layer_type in [GCNConv]:
            graph_emb = self.mol_gnn_activation(self.mol_g1(graph_emb, graph.edge_index))
        else:
            graph_emb = self.mol_gnn_activation(self.mol_g1(graph_emb, graph.edge_index, graph.edge_attr))    

        # same readout scheme as the knn branch
        mol_readouts = []

        if self.mol_u_net_type:
            # U-net style readout: pool after every molecular GNN layer
            for i in range(self.num_mol_gnn_layers):
                if self.mol_gnn_layer_type in [GCNConv]:
                    graph_emb = self.mol_gnn_activation(self.mol_gnn_layers[i](graph_emb, graph.edge_index))
                else:
                    graph_emb = self.mol_gnn_activation(self.mol_gnn_layers[i](graph_emb, graph.edge_index, graph.edge_attr))
                if self.mol_max_pool:
                    mol_readouts.append(gmaxp(graph_emb, graph.batch))
                if self.mol_mean_pool:
                    mol_readouts.append(gmp(graph_emb, graph.batch))
                if self.mol_add_pool:
                    mol_readouts.append(gap(graph_emb, graph.batch))
        else:
            # pool only after the last molecular GNN layer
            for i in range(self.num_mol_gnn_layers):
                if self.mol_gnn_layer_type in [GCNConv]:
                    graph_emb = self.mol_gnn_activation(self.mol_gnn_layers[i](graph_emb, graph.edge_index))
                else:
                    graph_emb = self.mol_gnn_activation(self.mol_gnn_layers[i](graph_emb, graph.edge_index, graph.edge_attr))
            if self.mol_max_pool:
                mol_readouts.append(gmaxp(graph_emb, graph.batch))
            if self.mol_mean_pool:
                mol_readouts.append(gmp(graph_emb, graph.batch))
            if self.mol_add_pool:
                mol_readouts.append(gap(graph_emb, graph.batch))

        graph_emb = torch.cat(mol_readouts, dim=1)
        # ------------------------------------------      
             
        # Attentive Fingerprint
        # ------------------------------------------
        afp = self.afp(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
        # ------------------------------------------

        if self.seperated_dnn:
            knn_emb = self.knng2f(knn_emb)
            for i in range(self.num_gnn_hl):
                knn_emb = self.dnn[i](knn_emb)
                knn_emb = self.dnn_activation(knn_emb)
                
            graph_emb = self.molg2f(graph_emb)
            for i in range(self.num_mol_gnn_hl):
                graph_emb = self.mol_dnn[i](graph_emb)
                graph_emb = self.dnn_activation(graph_emb)
                
        comb_emb = torch.cat([knn_emb, graph_emb, afp, fp, extras], dim=1)
        comb_emb = self.g2f(comb_emb)
        x_all = self.dnn_activation(comb_emb)
            
        for i in range(self.num_hl):
            x_all = self.dnn_activation(self.comb_dnn[i](x_all))
            x_all = self.hl_dropout(x_all)
        
        mu = self.mu(x_all)
        var = torch.exp(self.var(x_all))  # the head predicts log-variance; exp keeps the variance strictly positive
        
        return mu, var
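

# A minimal usage sketch, not part of the original module: it shows one way to
# instantiate the model and (commented out) a hypothetical training step using
# a Gaussian negative log-likelihood on the predicted (mu, var) pair. The
# hyperparameter values, the `loader` and the `graph.y` target are assumptions
# for illustration; `graph` and `knngraph` are expected to be torch_geometric
# batches with 78-dim node features, 10-dim edge features on `graph`, and
# per-molecule `fp` (1024-dim) and `extras` attributes, as the layer sizes
# above imply.
if __name__ == "__main__":
    model = OneToRuleThemAll(
        gnn_layer_type=TransformerConv,
        mol_gnn_layer_type=TransformerConv,
        num_gnn_layers=2,
        num_mol_gnn_layers=2,
        gnn_output_size=64,
        mol_gnn_output_size=64,
        gnn_heads=4,
        mol_gnn_heads=4,
        gnn_activation=torch.nn.functional.relu,
        mol_gnn_activation=torch.nn.functional.relu,
        attentionlayer_size=64,
        embedding_size=64,
        num_afp_layers=2,
        num_timesteps=2,
        afp_dropout=0.1,
        dnn_activation=torch.nn.functional.relu,
        num_hl=2,
        num_gnn_hl=1,
        num_mol_gnn_hl=1,
        gnn_hl_size=128,
        mol_gnn_hl_size=128,
        hl_size=256,
        hl_dropout=0.2,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # hypothetical training step; `loader` would yield (graph, knngraph) pairs
    # with a per-graph regression target `graph.y`
    # for graph, knngraph in loader:
    #     mu, var = model(graph, knngraph)
    #     loss = torch.nn.functional.gaussian_nll_loss(mu, graph.y.view(-1, 1), var)
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()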