import torch
import regex as re
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

def mol_objecter(dataframe, column):
    """Convert every SMILES entry of ``dataframe[column]`` with ``Chem.MolFromSmiles``.

    Args:
        dataframe: Any mapping-like object that supports ``dataframe[column]``
            and yields SMILES strings when iterated.
        column: Key of the column holding the SMILES strings.

    Returns:
        A list with one ``Chem.MolFromSmiles`` result per entry, in iteration
        order.
    """
    return [Chem.MolFromSmiles(smiles) for smiles in dataframe[column]]

def one_of_k_encoding(x, allowable_set):
  """One-hot encode ``x`` against ``allowable_set``.

  Returns a list of booleans, one per element of ``allowable_set``, that is
  True where the element equals ``x``.

  Raises:
    Exception: if ``x`` is not contained in ``allowable_set``.
  """
  if x in allowable_set:
    return [x == candidate for candidate in allowable_set]
  raise Exception("input {0} not in allowable set{1}:".format(
      x, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
  """One-hot encode ``x``; values outside the set are encoded as its last element.

  Returns a list of booleans, one per element of ``allowable_set``.
  """
  target = x if x in allowable_set else allowable_set[-1]
  return [target == candidate for candidate in allowable_set]

def make_smiles_adduct(reactant: str, catalyst: str) -> str:
  """Build an adduct SMILES by joining a reactant oxygen and the catalyst boron.

  The function works purely on the SMILES text:

  1. Every ``B`` in ``catalyst`` that is not part of ``Br`` becomes ``[B-]9``.
  2. Double-bonded oxygens in ``reactant`` (``=O`` / ``O=``) become
     ``=[O+]9`` / ``[O+]9=``. Methyl-ester groups written as ``(OC)=O`` and
     ``O=P`` groups are left untouched whenever the reactant contains such an
     ester or any ``P``.
  3. The digits 1-7 in ``catalyst`` are shifted by the highest digit found in
     ``reactant`` so that catalyst ring labels do not clash with reactant ones.
  4. Both parts are joined with ``.``; the shared label 9 then denotes the
     O-B bond.

  Args:
    reactant: SMILES string of the reactant.
    catalyst: SMILES string of the boron-containing catalyst.

  Returns:
    ``"<modified reactant>.<modified catalyst>"``.
  """
  digits = re.findall(r'\d', reactant)
  highest_number_reactant = max(map(int, digits)) if digits else 0

  # Tag every boron with the coordination label; the look-ahead keeps bromine
  # ("Br") from being mistaken for boron.
  catalyst = re.sub(r'B(?!r)', '[B-]9', catalyst)

  # Check whether an ester / phosphorus is present and how the double-bonded
  # oxygens are written.
  n_equal_oxy = reactant.count('=O')
  n_oxy_equal = reactant.count('O=')
  has_ester = '(OC)=O' in reactant
  has_phosp = 'P' in reactant

  if not has_ester and not has_phosp:
    if n_equal_oxy + n_oxy_equal >= 2:
      # Several double-bonded oxygens: only one spelling is rewritten,
      # preferring "=O" when it occurs at all.
      if n_equal_oxy:
        reactant = reactant.replace('=O', '=[O+]9')
      else:
        reactant = reactant.replace('O=', '[O+]9=')
    else:
      reactant = reactant.replace('=O', '=[O+]9')
      reactant = reactant.replace('O=', '[O+]9=')
  else:
    # Mask the ester and O=P groups first so their oxygens are not rewritten.
    reactant = reactant.replace('(OC)=O', 'XXYYZZ')
    reactant = reactant.replace('O=P', 'LLLL')
    reactant = reactant.replace('=O', '=[O+]9')
    reactant = reactant.replace('O=', '[O+]9=')
    reactant = reactant.replace('XXYYZZ', '(OC)=O')
    reactant = reactant.replace('LLLL', 'O=P')

  def _shift_label(match):
    # SMILES needs a '%' prefix for two-digit ring-closure labels.
    label = highest_number_reactant + int(match.group())
    return str(label) if label < 10 else f'%{label}'

  # Shift the catalyst labels in ONE pass. Chained str.replace calls would
  # re-replace the digits of an already shifted two-digit label
  # (e.g. 7 -> "10" -> "40" when the reactant maximum is 3).
  # NOTE(review): a shifted label can still equal 8 or 9 and then coincide
  # with an unshifted catalyst 8 or the coordination label 9 - confirm the
  # inputs never reach that range.
  catalyst = re.sub(r'[1-7]', _shift_label, catalyst)

  return reactant + '.' + catalyst


def generate_structure_from_mol(mol):
    """Generate a 3D structure for ``mol``.

    The molecule returned by ``Chem.AddHs`` is embedded with
    ``AllChem.EmbedMolecule`` and then passed to
    ``AllChem.UFFOptimizeMolecule``.

    Returns:
        A tuple ``(conformer, coordinates, mol_with_hs)`` where ``conformer``
        is the result of ``GetConformer()`` on the processed molecule,
        ``coordinates`` is its ``GetPositions()`` output as a NumPy array and
        ``mol_with_hs`` is the processed molecule itself.
    """
    mol_with_hs = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol_with_hs)
    AllChem.UFFOptimizeMolecule(mol_with_hs)

    # Default conformer of the molecule that was just embedded.
    conformer = mol_with_hs.GetConformer()
    positions = np.array(conformer.GetPositions())

    return conformer, positions, mol_with_hs

def atom_features(atom,
                  xyz = None,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=True,
                  use_partial_charge=True,
                  ):
  """Build the feature vector of a single RDKit atom.

  Layout of the returned vector, in order:

  * 44 entries: one-hot element symbol (unlisted symbols map to 'Unknown')
  * 11 entries: one-hot degree 0-10
  * 7 entries: one-hot implicit valence 0-6 (larger values map to 6)
  * 2 entries: formal charge, number of radical electrons
  * 5 entries: one-hot hybridization SP, SP2, SP3, SP3D, SP3D2
    (anything else maps to SP3D2)
  * 1 entry: aromaticity flag
  * 5 entries (unless ``explicit_H``): one-hot total hydrogen count 0-4
  * 3 entries (if ``use_chirality``): one-hot CIP code R/S plus the
    ``_ChiralityPossible`` flag
  * 1 entry (if ``use_partial_charge``): the ``_GasteigerCharge`` property

  Args:
    atom: RDKit atom to featurize.
    xyz: Unused by this function.
    bool_id_feat: If True, return only ``np.array([atom_to_id(atom)])``.
      NOTE(review): ``atom_to_id`` is not defined in the visible part of this
      file - confirm it exists before enabling this flag.
    explicit_H: If True, skip the total-hydrogen-count block.
    use_chirality: Append the chirality block.
    use_partial_charge: Append the Gasteiger charge.

  Returns:
    ``np.ndarray`` holding the features described above.

  Raises:
    Exception: from ``one_of_k_encoding`` if the atom degree is not in 0-10.
  """
  if bool_id_feat:
    return np.array([atom_to_id(atom)])

  hybridization = Chem.rdchem.HybridizationType
  results = one_of_k_encoding_unk(
    atom.GetSymbol(),
    [
      'C',
      'N',
      'O',
      'S',
      'F',
      'Si',
      'P',
      'Cl',
      'Br',
      'Mg',
      'Na',
      'Ca',
      'Fe',
      'As',
      'Al',
      'I',
      'B',
      'V',
      'K',
      'Tl',
      'Yb',
      'Sb',
      'Sn',
      'Ag',
      'Pd',
      'Co',
      'Se',
      'Ti',
      'Zn',
      'H',  # H?
      'Li',
      'Ge',
      'Cu',
      'Au',
      'Ni',
      'Cd',
      'In',
      'Mn',
      'Zr',
      'Cr',
      'Pt',
      'Hg',
      'Pb',
      'Unknown'
    ])
  results = results + one_of_k_encoding(atom.GetDegree(),
                                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
  results = results + one_of_k_encoding_unk(atom.GetImplicitValence(),
                                            [0, 1, 2, 3, 4, 5, 6])
  results = results + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
  results = results + one_of_k_encoding_unk(atom.GetHybridization(), [
    hybridization.SP, hybridization.SP2, hybridization.SP3,
    hybridization.SP3D, hybridization.SP3D2
  ])
  results = results + [atom.GetIsAromatic()]

  # In case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs`
  if not explicit_H:
    results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                              [0, 1, 2, 3, 4])

  if use_chirality:
    # Only a missing property is expected here, so catch just KeyError; a
    # bare except would also swallow KeyboardInterrupt / SystemExit.
    try:
      cip_code = atom.GetProp('_CIPCode')
    except KeyError:
      cip_one_hot = [False, False]
    else:
      cip_one_hot = one_of_k_encoding_unk(cip_code, ['R', 'S'])
    results = results + cip_one_hot + [atom.HasProp('_ChiralityPossible')]

  if use_partial_charge:
    try:
      results = results + [float(atom.GetProp('_GasteigerCharge'))]
    except (KeyError, ValueError):
      # Best effort: the charge is left out, so the vector is one entry
      # shorter than for atoms whose charge is available.
      print('Failed to compute GasteigerCharge')

  return np.array(results)

def bond_features(bond, use_chirality=True):
  """Build the feature vector of a single RDKit bond.

  Layout: four bond-type flags (single, double, triple, aromatic), the
  conjugation flag, the in-ring flag, optionally a four-way one-hot of the
  stereo label (unlisted labels map to "STEREOE"), and a trailing 0.

  Args:
    bond: RDKit bond to featurize.
    use_chirality: Whether to include the stereo one-hot block.

  Returns:
    ``np.ndarray`` with the features described above.
  """
  from rdkit import Chem
  bond_kinds = Chem.rdchem.BondType
  current_kind = bond.GetBondType()

  feats = [
      current_kind == kind
      for kind in (bond_kinds.SINGLE, bond_kinds.DOUBLE,
                   bond_kinds.TRIPLE, bond_kinds.AROMATIC)
  ]
  feats.append(bond.GetIsConjugated())
  feats.append(bond.IsInRing())

  if use_chirality:
    stereo_label = str(bond.GetStereo())
    feats.extend(one_of_k_encoding_unk(
        stereo_label,
        ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]))

  # Final slot is always 0 for bonds taken from a molecule.
  feats.append(0)
  return np.array(feats)

def get_bond_pair_complex(mol_1, mol_2):
    """Build the edge index of the complex formed by ``mol_1`` and ``mol_2``.

    Atoms of ``mol_2`` are re-indexed after those of ``mol_1`` (offset = number
    of atoms in ``mol_1``). Every bond contributes two directed edges that are
    stored back to back (begin->end, then end->begin): first all bonds of
    ``mol_1``, then all bonds of ``mol_2``. The last two edges connect the
    boron atom of ``mol_1`` with the selected oxygen atom of ``mol_2``.

    Atom selection:
      * boron: the last atom of ``mol_1`` whose symbol is "B";
      * oxygen: the last SP2 oxygen of ``mol_2`` that has exactly one bond and
        whose bonding partner takes part in fewer than two bonds that involve
        an oxygen atom.

    Args:
        mol_1: RDKit molecule that contains the boron atom.
        mol_2: RDKit molecule that contains the oxygen atom.

    Returns:
        ``[sources, targets]`` - two equally long lists of atom indices.

    Raises:
        ValueError: if no boron atom or no suitable oxygen atom is found.
    """
    bor_index = None
    for atom in mol_1.GetAtoms():
        if atom.GetSymbol() == "B":
            bor_index = atom.GetIdx()
    if bor_index is None:
        raise ValueError("mol_1 does not contain a boron atom")

    carbonyl_index = None
    for atom in mol_2.GetAtoms():
        if not (atom.GetSymbol() == "O" and
                atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2 and
                len(atom.GetBonds()) == 1):
            continue

        # The oxygen has exactly one bond; pick its end atom when that is a
        # carbon, otherwise fall back to the begin atom of the bond.
        only_bond = atom.GetBonds()[0]
        end_atom = only_bond.GetEndAtom()
        if end_atom.GetSymbol() == 'C':
            carbonyl_candidate = end_atom
        else:
            carbonyl_candidate = only_bond.GetBeginAtom()

        # Count the candidate's bonds that touch an oxygen; two or more
        # (e.g. an ester or acid carbon) disqualify this oxygen.
        oxygen_bond_count = sum(
            1 for bond in carbonyl_candidate.GetBonds()
            if (bond.GetBeginAtom().GetSymbol() == 'O' or
                bond.GetEndAtom().GetSymbol() == 'O'))
        if oxygen_bond_count < 2:
            carbonyl_index = atom.GetIdx()
    if carbonyl_index is None:
        raise ValueError("mol_2 does not contain a suitable carbonyl oxygen")

    # Offset by the atom count rather than by the highest bonded index: this
    # also works for a mol_1 without bonds or with a trailing unbonded atom.
    offset = mol_1.GetNumAtoms()

    sources, targets = [], []
    for bond in mol_1.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        sources += [begin, end]
        targets += [end, begin]
    for bond in mol_2.GetBonds():
        begin = offset + bond.GetBeginAtomIdx()
        end = offset + bond.GetEndAtomIdx()
        sources += [begin, end]
        targets += [end, begin]

    # Coordination edge between boron and the selected oxygen, both directions.
    sources += [bor_index, offset + carbonyl_index]
    targets += [offset + carbonyl_index, bor_index]

    return [sources, targets]

def complex2onlygraph(mol_1, mol_2, ee):
    """Turn a pair of molecules into a single graph ``Data`` object.

    Node features are ``atom_features`` of all atoms of ``mol_1`` followed by
    those of ``mol_2``. Edges come from ``get_bond_pair_complex``; each edge
    gets the ``bond_features`` of its bond, and the two final edges (the
    coordination between both molecules) get a feature vector whose last entry
    is 1.

    Args:
        mol_1: RDKit molecule passed as first argument to
            ``get_bond_pair_complex``.
        mol_2: RDKit molecule passed as second argument to
            ``get_bond_pair_complex``.
        ee: Target value stored unchanged on the returned object.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``neg_edges``, ``edge_attr`` and
        ``ee``.
    """
    # Partial charges must exist before atom_features reads _GasteigerCharge.
    AllChem.ComputeGasteigerCharges(mol_1)
    AllChem.ComputeGasteigerCharges(mol_2)

    all_nodes = [atom_features(atom)
                 for mol in (mol_1, mol_2)
                 for atom in mol.GetAtoms()]

    all_edge_index = get_bond_pair_complex(mol_1, mol_2)
    edge_index = torch.tensor(all_edge_index, dtype=torch.long)
    neg_edge_index = negative_sampling(edge_index,
                                       num_nodes=len(all_nodes)).t()

    # get_bond_pair_complex stores the two directions of a bond back to back,
    # so each bond's features are added twice in a row to keep row i of
    # edge_attr aligned with column i of edge_index.
    all_edge_attr = []
    for mol in (mol_1, mol_2):
        for bond in mol.GetBonds():
            feats = bond_features(bond, use_chirality=True)
            all_edge_attr.append(feats)
            all_edge_attr.append(feats)

    # the last bond is the coordination (one row per direction)
    coordination_attr = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
    all_edge_attr.append(coordination_attr)
    all_edge_attr.append(coordination_attr)

    data = Data(x=torch.tensor(np.array(all_nodes), dtype=torch.float),
                edge_index=edge_index,
                neg_edges=neg_edge_index,
                edge_attr=torch.tensor(np.array(all_edge_attr),
                                       dtype=torch.float),
                ee=ee
                )
    return data

def my_graph_only_dataloader(
                              dataframe,
                              bs = 9,
                              num_classes = 2,
                              target_column = 'ee'
                              ):
    """Create graph and target data loaders from a dataframe.

    Args:
        dataframe: Table with the SMILES columns "r1" and "cat" and the target
            column named by ``target_column``.
        bs: Batch size used for both loaders.
        num_classes: Not used by this function.
        target_column: Name of the column holding the target values.

    Returns:
        A tuple ``(graph_loader, target_loader, graph_list)``: an unshuffled
        loader over the complex graphs, an unshuffled loader over ``Data``
        objects that only carry the target, and the list of graphs itself.
    """
    # One (1, 1) float tensor per target value.
    targets = [torch.FloatTensor([value]).view(1, 1)
               for value in dataframe[target_column]]
    data_list = [Data(ee=target) for target in targets]

    # get mols for reactant and catalyst
    r1_mol_list = mol_objecter(dataframe, column="r1")
    cat_mol_list = mol_objecter(dataframe, column="cat")

    # The catalyst is the first molecule of the complex, the reactant the second.
    graph_list = [
        complex2onlygraph(cat_mol, r1_mol, target)
        for cat_mol, r1_mol, target in zip(cat_mol_list, r1_mol_list, targets)
    ]

    graph_loader = DataLoader(graph_list, batch_size=bs, shuffle=False)
    target_loader = DataLoader(data_list, batch_size=bs, shuffle=False)
    return graph_loader, target_loader, graph_list