-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrdkit_atom_count.py
46 lines (36 loc) · 1.01 KB
/
rdkit_atom_count.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import pandas as pd
import rdkit.Chem as Chem
import pandas_utils as pu
import numpy as np
from dataclasses import dataclass, field
from enum import Enum
@dataclass
class Element:
    """Chemical element identified by its symbol and its atomic number.

    Used as the data mixin of the ``Elements`` enum, whose member values
    are ``(symbol, atomic_number)`` pairs.
    """

    # Element symbol as written in formulas, e.g. "C" or "Cl"; also used to
    # build the per-element count column name ("at_n_<symbol>").
    symbol: str
    # Atomic number, e.g. 6 for carbon; compared against
    # ``atom.GetAtomicNum()`` when counting atoms.
    atomic_number: int
class Elements(Element, Enum):
    """Elements that are counted by default.

    Each member's value is a ``(symbol, atomic_number)`` pair that is
    unpacked into the ``Element`` dataclass fields, so members expose
    ``.symbol`` and ``.atomic_number`` directly. Iterating the enum yields
    the members in the order they are defined here.
    """

    # Common organic elements
    H = "H", 1
    C = "C", 6
    N = "N", 7
    O = "O", 8
    P = "P", 15
    S = "S", 16
    # Halogens
    F = "F", 9
    Cl = "Cl", 17
    Br = "Br", 35
    I = "I", 53
def count_element_atoms_df(
    df, mol_col, elements: list[Element] = Elements
) -> pd.DataFrame:
    """Add one atom-count column per element to ``df``.

    For every entry of ``elements`` a column named ``at_n_<symbol>`` is
    written, holding the number of atoms of that element in each molecule
    (hydrogens are added to the molecules before counting).

    Args:
        df: data frame that receives the count columns; the columns are
            assigned on this object before ``pu.astype_int`` is applied.
        mol_col: iterable of molecules, one per row of ``df`` (it is
            iterated directly, so it is the values, not a column name).
            Null entries are kept as ``None``.
        elements: iterable of objects with ``symbol`` and
            ``atomic_number``; defaults to all ``Elements`` members.

    Returns:
        The data frame as returned by the last ``pu.astype_int`` call
        (``df`` itself when ``elements`` is empty).
    """
    # Add hydrogens to every molecule so that H atoms are counted as well;
    # missing molecules stay None and are handled by count_element.
    hydrogenated = []
    for molecule in mol_col:
        if pu.notnull(molecule):
            hydrogenated.append(Chem.AddHs(molecule))
        else:
            hydrogenated.append(None)

    for elem in elements:
        column_name = f"at_n_{elem.symbol}"
        atomic_number = elem.atomic_number
        counts = [count_element(molecule, atomic_number) for molecule in hydrogenated]
        df[column_name] = counts
        # Convert the freshly written column to an integer dtype.
        df = pu.astype_int(df, column_name)

    return df
def count_element(mol: Chem.rdchem.Mol, number: int):
    """Count the atoms of ``mol`` whose atomic number equals ``number``.

    Args:
        mol: molecule to inspect; may be null (as judged by ``pu.notnull``).
        number: atomic number of the element to count.

    Returns:
        The number of matching atoms yielded by ``mol.GetAtoms()`` as an
        ``int``, or ``np.nan`` when ``mol`` is null.
    """
    if not pu.notnull(mol):
        # Use the canonical lowercase ``np.nan``: it is the same value as the
        # legacy upper-case ``np.NAN`` alias but is available on every NumPy
        # version, whereas the alias is not guaranteed to exist on NumPy 2.x.
        return np.nan
    return sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == number)