import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

def mol_objecter(dataframe, column):
    """Parse every SMILES string in ``dataframe[column]`` with RDKit.

    Args:
      dataframe: object indexable by ``column`` (e.g. a pandas DataFrame).
      column: name of the column holding SMILES strings.

    Returns:
      List of ``Chem.MolFromSmiles`` results, one per row, in row order.
      No validation is done here; failed parses are passed through unchanged.
    """
    return [Chem.MolFromSmiles(smiles) for smiles in dataframe[column]]

def one_of_k_encoding(x, allowable_set):
  """One-hot encode ``x`` against ``allowable_set``.

  Args:
    x: value to encode.
    allowable_set: ordered collection of permitted values.

  Returns:
    List of booleans, one per element of ``allowable_set``, True where the
    element equals ``x``.

  Raises:
    ValueError: if ``x`` is not in ``allowable_set``. This is a subclass of
      ``Exception``, so existing ``except Exception`` handlers still catch it.
  """
  if x not in allowable_set:
    raise ValueError("input {0} not in allowable set{1}:".format(
        x, allowable_set))
  return [x == s for s in allowable_set]

def one_of_k_encoding_unk(x, allowable_set):
  """One-hot encode ``x`` against ``allowable_set``, mapping unknowns to the last slot.

  Args:
    x: value to encode.
    allowable_set: ordered, non-empty collection of permitted values; its
      last element doubles as the "unknown" bucket.

  Returns:
    List of booleans, one per element of ``allowable_set``. A value that is
    not in the set is encoded as if it were ``allowable_set[-1]``.
  """
  key = x if x in allowable_set else allowable_set[-1]
  return [key == candidate for candidate in allowable_set]

def generate_structure_from_mol(mol):
    """Embed ``mol`` in 3D and relax the geometry with MMFF.

    ``mol`` is modified in place: ``AllChem.EmbedMolecule`` attaches a
    conformer and ``AllChem.MMFFOptimizeMolecule`` optimises it. Hydrogens
    are not added beforehand, so only atoms already present in ``mol``
    receive coordinates.

    Args:
      mol: RDKit molecule to embed.

    Returns:
      ``(conformer, coordinates, mol)``: the conformer, a NumPy copy of its
      ``GetPositions()`` result, and the same (now embedded) molecule object.

    Raises:
      ValueError: if embedding fails (negative status from
        ``EmbedMolecule``), so there is no conformer to optimise.
    """
    embed_status = AllChem.EmbedMolecule(mol,
                                         useExpTorsionAnglePrefs=True,
                                         useBasicKnowledge=True,
                                         enforceChirality=True,
                                         )
    if embed_status < 0:
        # Fail loudly here instead of letting the MMFF / GetConformer calls
        # below fail with an unrelated-looking error.
        raise ValueError(
            "3D embedding failed (EmbedMolecule status {0})".format(embed_status))
    # The optimiser's return value is deliberately not checked: whatever
    # geometry it leaves behind (converged or not) is used as-is.
    AllChem.MMFFOptimizeMolecule(mol)
    conformer = mol.GetConformer()
    coordinates = np.array(conformer.GetPositions())

    return conformer, coordinates, mol

def atom_features(atom,
                  xyz = None,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=True,
                  use_partial_charge=True,
                  ):
  """Build the feature vector for a single RDKit atom.

  The vector is the concatenation of:
    * a one-hot element symbol over ['C', 'N', 'O', 'S'] (any other symbol
      maps to 'S'),
    * a one-hot degree over 0..10 (raises if the degree is outside 0..10),
    * a one-hot implicit valence over 0..6 (values outside map to 6),
    * the formal charge and the number of radical electrons,
    * a one-hot hybridisation over SP, SP2, SP3, SP3D, SP3D2 (any other
      hybridisation maps to SP3D2),
    * the aromaticity flag,
    * if ``explicit_H`` is False: a one-hot total H count over 0..4 (values
      outside map to 4),
    * if ``use_chirality``: an 'R'/'S' one-hot of the '_CIPCode' property
      ([False, False] when the property is absent; other codes map to 'S'),
      followed by whether the atom has '_ChiralityPossible',
    * if ``use_partial_charge``: the '_GasteigerCharge' property as a float.

  Args:
    atom: RDKit atom to featurise.
    xyz: currently unused by this function.
    bool_id_feat: if True, return ``np.array([atom_to_id(atom)])`` instead.
    explicit_H: if True, omit the total-H-count one-hot.
    use_chirality: include the CIP / chirality features.
    use_partial_charge: include the Gasteiger partial charge.

  Returns:
    1-D NumPy array of the concatenated features.
  """
  if bool_id_feat:
    # NOTE(review): `atom_to_id` is not defined in the visible part of this
    # file; this branch raises NameError unless it is provided elsewhere.
    return np.array([atom_to_id(atom)])

  hybridization = Chem.rdchem.HybridizationType
  results = (
      # 'F', 'P', 'Cl', 'Br', ... have no slot of their own: unknown symbols
      # fall into the last bucket ('S').
      one_of_k_encoding_unk(atom.GetSymbol(), ['C', 'N', 'O', 'S'])
      + one_of_k_encoding(atom.GetDegree(), list(range(11)))
      + one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
      + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
      + one_of_k_encoding_unk(atom.GetHybridization(), [
          hybridization.SP, hybridization.SP2, hybridization.SP3,
          hybridization.SP3D, hybridization.SP3D2,
      ])
      + [atom.GetIsAromatic()]
  )
  # With explicit hydrogens (e.g. QM8, QM9) avoid calling `GetTotalNumHs`.
  if not explicit_H:
    results += one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
  if use_chirality:
    # Explicit presence check instead of a bare `except:` that would also
    # hide unrelated errors.
    if atom.HasProp('_CIPCode'):
      results += one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
    else:
      results += [False, False]
    results.append(atom.HasProp('_ChiralityPossible'))
  if use_partial_charge:
    # The '_GasteigerCharge' property must already be set on the atom; the
    # callers in this module run AllChem.ComputeGasteigerCharges first.
    results.append(float(atom.GetProp('_GasteigerCharge')))

  return np.array(results)

def bond_features(bond, use_chirality=True):
  """Return the feature vector of an RDKit bond as a NumPy array.

  Layout: one-hot bond type over SINGLE, DOUBLE, TRIPLE, AROMATIC, then the
  conjugation flag and the ring-membership flag. When ``use_chirality`` is
  True, a one-hot of ``str(bond.GetStereo())`` over STEREONONE, STEREOANY,
  STEREOZ, STEREOE is appended (any other label maps to STEREOE).
  """
  from rdkit import Chem
  bond_type = Chem.rdchem.BondType
  kind = bond.GetBondType()
  feats = [kind == candidate for candidate in (bond_type.SINGLE,
                                               bond_type.DOUBLE,
                                               bond_type.TRIPLE,
                                               bond_type.AROMATIC)]
  feats.append(bond.GetIsConjugated())
  feats.append(bond.IsInRing())
  if use_chirality:
    feats.extend(one_of_k_encoding_unk(
        str(bond.GetStereo()),
        ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]))
  return np.array(feats)

def get_bond_pair(mol):
  """Return directed edge lists ``[sources, targets]`` for every bond of ``mol``.

  Each bond contributes two consecutive directed edges, begin->end followed
  by end->begin, so edges ``2*i`` and ``2*i + 1`` both come from the i-th
  bond returned by ``mol.GetBonds()``.
  """
  sources, targets = [], []
  for bond in mol.GetBonds():
    begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    sources.extend((begin, end))
    targets.extend((end, begin))
  return [sources, targets]

def get_edge_index(mol):
    """Return a ``(2, 2 * n_bonds)`` LongTensor of directed edges for ``mol``.

    Every bond is emitted in both directions, begin->end followed by
    end->begin; row 0 holds sources and row 1 holds targets.
    """
    pairs = []
    for bond in mol.GetBonds():
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        pairs.append((a, b))
        pairs.append((b, a))
    sources = [src for src, _ in pairs]
    targets = [dst for _, dst in pairs]
    return torch.tensor([sources, targets], dtype=torch.long)


def mol2graph_noBasis(mol, extras, ccsdt, incre, smile, fp, chirality=True, xyz_features = True, explicit_H=False):
  """Build a PyG ``Data`` graph for one molecule from its RDKit bond graph.

  Args:
    mol: RDKit molecule. Gasteiger charges are computed on it in place and,
      when ``xyz_features`` is True, a 3D conformer is generated for it.
    extras: extra per-molecule features, stored as ``torch.tensor([extras])``.
    ccsdt: stored on the graph unchanged as ``ccsdt``.
    incre: stored on the graph unchanged as ``incre``.
    smile: SMILES string stored on the graph as ``smile``.
    fp: fingerprint, stored as ``torch.tensor([fp])``.
    chirality: whether bond features include the stereo one-hot.
    xyz_features: if True, generate 3D coordinates for ``pos`` (all-zero
      coordinates if generation fails) and drop 'H' atoms from the node
      features; if False, ``pos`` is None and all atoms are featurised.
    explicit_H: forwarded to ``atom_features`` on both paths.

  Returns:
    ``torch_geometric.data.Data`` with ``x``, ``edge_index``, ``edge_attr``,
    ``extras``, ``ccsdt``, ``incre``, ``smile``, ``fp`` and ``pos``.
  """
  atoms = mol.GetAtoms()
  pos = None  # stays None when no 3D coordinates are requested
  if xyz_features:
    try:
      _, xyz, mol = generate_structure_from_mol(mol)
    except Exception:
      # Best effort: if 3D embedding/optimisation fails, keep the molecule
      # with all-zero coordinates instead of dropping it.
      xyz = [[0.0, 0.0, 0.0] for _ in range(len(atoms))]
    pos = torch.tensor(xyz, dtype=torch.float)
  AllChem.ComputeGasteigerCharges(mol)

  if xyz_features:
    # NOTE(review): 'H' atoms are dropped from x here but not from
    # edge_index or pos; this is only consistent when mol has no explicit
    # H atoms — confirm for your inputs.
    node_f = [atom_features(atom, explicit_H=explicit_H)
              for atom in atoms if atom.GetSymbol() != 'H']
  else:
    node_f = [atom_features(atom, explicit_H=explicit_H) for atom in atoms]

  edge_index = get_bond_pair(mol)
  # get_bond_pair emits every bond twice in a row (begin->end, end->begin),
  # so each bond's features are repeated twice to keep edge_attr row k
  # aligned with edge k.
  edge_attr = []
  for bond in mol.GetBonds():
    feats = bond_features(bond, use_chirality=chirality)
    edge_attr.extend((feats, feats))

  data = Data(x=torch.tensor(np.asarray(node_f, dtype=float), dtype=torch.float),
              edge_index=torch.tensor(edge_index, dtype=torch.long),
              edge_attr=torch.tensor(np.asarray(edge_attr, dtype=float), dtype=torch.float),
              extras=torch.tensor([extras, ], dtype=torch.float), ccsdt=ccsdt, incre=incre,
              smile=smile,
              fp=torch.tensor([fp, ], dtype=torch.float),
              pos=pos,
              )
  return data

def clean_xyz(xyz):
    """Parse a packed coordinate string into an ``(n, 3)`` float array.

    ``xyz`` holds one atom record per ``'!'``-separated segment. Inside a
    record, fields may be separated by ``';'`` and/or whitespace: the element
    symbol comes first, then x, y, z (any further fields are ignored).
    Segments of length <= 1 and segments containing the character ``'H'``
    are skipped, which drops hydrogen atoms.

    Args:
      xyz: the packed coordinate string.

    Returns:
      Float NumPy array of shape ``(n, 3)``; ``(0, 3)`` when nothing is kept.

    Raises:
      ValueError: if a kept segment does not have three numeric fields after
        the symbol.
    """
    rows = []
    for record in xyz.split('!'):
        record = record.replace(';', ' ')
        if 'H' in record or len(record) <= 1:
            continue
        # Token-based parsing (rather than slicing off two characters) also
        # handles two-letter symbols such as 'Cl' and repeated whitespace.
        rows.append(record.split()[1:4])
    return np.asarray(rows, dtype=float).reshape(len(rows), 3)

def calc_distances(xyz):
    """Return the pairwise Euclidean distance matrix of a set of points.

    Args:
      xyz: sequence of points (e.g. the array from ``clean_xyz``); only the
        first three components of each point are used.

    Returns:
      Float NumPy array of shape ``(n, n)`` whose ``[i, j]`` entry is the
      distance between points ``i`` and ``j``; ``(0, 0)`` for no points.
    """
    points = np.asarray(xyz, dtype=float)
    if len(points) == 0:
        return np.zeros((0, 0))
    points = points[:, :3]
    # Broadcast to an (n, n, 3) array of coordinate differences instead of
    # looping over every pair in Python.
    diff = points[:, np.newaxis, :] - points[np.newaxis, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))
 
def get_edges(distances, cut_off_radius = 3.0, min_distance = 0.1):
    """Turn a pairwise distance matrix into a directed radius graph.

    An edge ``i -> j`` is emitted for every entry with
    ``min_distance < distances[i][j] < cut_off_radius``. When ``distances``
    is symmetric (as produced by ``calc_distances``), each neighbouring pair
    yields one edge in each direction. The lower bound excludes the zero
    diagonal (self-loops) and near-coincident points.

    Args:
      distances: square matrix of pairwise distances.
      cut_off_radius: exclusive upper distance bound.
      min_distance: exclusive lower distance bound (default 0.1, the value
        previously hard-coded here).

    Returns:
      ``(edges, edge_attr)``: ``edges`` is a ``(2, E)`` float array of
      ``[sources, targets]`` in row-major order; ``edge_attr`` is an
      ``(E, 1)`` float array holding each edge's distance.
    """
    dist = np.asarray(distances, dtype=float)
    if dist.size == 0:
        return np.zeros((2, 0)), np.zeros((0, 1))
    mask = (dist < cut_off_radius) & (dist > min_distance)
    # np.nonzero scans in row-major order, matching a nested i/j loop.
    sources, targets = np.nonzero(mask)
    edges = np.asarray([sources, targets], dtype=float)
    edge_attr = dist[sources, targets].reshape(-1, 1)

    return edges, edge_attr
 
def mol2cluster(mol, xyz, cut_off_radius, chirality=True):
  """Build a distance-cutoff graph for ``mol`` from a packed coordinate string.

  Gasteiger charges are computed on ``mol`` in place. Node features come
  from ``atom_features`` for every non-'H' atom of ``mol``. Edges and
  their single distance attribute come from ``get_edges`` applied to the
  pairwise distances of ``clean_xyz(xyz)``.

  NOTE(review): row i of x and coordinate row i are assumed to describe the
  same atom (heavy-atom order in ``mol`` matches the order in ``xyz``);
  nothing here checks that — confirm against the data source.

  Args:
    mol: RDKit molecule.
    xyz: packed coordinate string understood by ``clean_xyz``.
    cut_off_radius: distance cutoff forwarded to ``get_edges``.
    chirality: forwarded to ``atom_features`` as ``use_chirality``.

  Returns:
    ``torch_geometric.data.Data`` with ``x``, ``edge_index`` and ``edge_attr``.
  """
  AllChem.ComputeGasteigerCharges(mol)
  heavy_atoms = [a for a in mol.GetAtoms() if a.GetSymbol() != 'H']
  node_f = [atom_features(a, use_chirality=chirality) for a in heavy_atoms]
  coords = clean_xyz(xyz)
  e_index, edge_attr = get_edges(calc_distances(coords), cut_off_radius)
  return Data(
      x=torch.tensor(node_f, dtype=torch.float),
      edge_index=torch.tensor(e_index, dtype=torch.long),
      edge_attr=torch.tensor(edge_attr, dtype=torch.float),
  )

def my_knnG_dataloader(dataframe,
                      bs = 16,
                      xyz_features = False,
                      explicit_H = False,
                      target_column = 'inkre',
                      cut_off_radius = 3.0,
                      num_extras=9,
                      ):
  """Build a bond-graph loader and a distance-cutoff-graph loader from a table.

  Columns read from ``dataframe``: 'SMILES', 'xyz', 'extra_feat',
  'DLPNO_E_CCSDt_au' and ``target_column``.

  Args:
    dataframe: table with one molecule per row (e.g. a pandas DataFrame).
    bs: batch size of both loaders (no shuffling).
    xyz_features: forwarded to ``mol2graph_noBasis``.
    explicit_H: forwarded to ``mol2graph_noBasis``.
    target_column: column whose value is stored on each bond graph as
      ``incre`` (as a 1x1 float tensor).
    cut_off_radius: distance cutoff forwarded to ``mol2cluster``.
    num_extras: currently unused; kept for backward compatibility.

  Returns:
    ``(bond_graph_loader, cluster_graph_loader, bond_graphs, cluster_graphs)``.

  Raises:
    ValueError: if a SMILES string cannot be parsed by RDKit.
  """
  # Parse each SMILES once and reuse the Mol for fingerprints and graphs.
  mol_list = mol_objecter(dataframe, column="SMILES")
  xyz_list = dataframe['xyz'].tolist()

  # Fingerprints are computed before any graph is built, while each Mol is
  # still exactly as parsed (graph construction attaches charges and
  # conformers to these same objects).
  records = []
  for mol, extras, ccsdt, inkre, smile in zip(mol_list,
                                              dataframe['extra_feat'],
                                              dataframe['DLPNO_E_CCSDt_au'],
                                              dataframe[target_column],
                                              dataframe['SMILES']):
    if mol is None:
      raise ValueError("RDKit could not parse SMILES: {0}".format(smile))
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 4, nBits=1024, useChirality=True, useFeatures=True)
    records.append((mol,
                    extras,
                    torch.FloatTensor([ccsdt]).view(1, 1),
                    torch.FloatTensor([inkre]).view(1, 1),
                    smile,
                    fp))

  new_graph_list = [mol2cluster(mol, xyz, cut_off_radius, chirality=True)
                    for mol, xyz in zip(mol_list, xyz_list)]
  graph_list = [mol2graph_noBasis(mol, extras, ccsdt, inkre, smile, fp,
                                  xyz_features=xyz_features, explicit_H=explicit_H)
                for mol, extras, ccsdt, inkre, smile, fp in records]
  return (DataLoader(graph_list, batch_size=bs, shuffle=False),
          DataLoader(new_graph_list, batch_size=bs, shuffle=False),
          graph_list,
          new_graph_list)
