Skip to content

Commit

Permalink
use unyt_arrays for charges and mass instead of lists
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjonesBSU committed Oct 18, 2024
1 parent 5e8635f commit b96a83c
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions gmso/external/convert_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,17 @@ def _parse_particle_information(
site.name if site.atom_type is None else site.atom_type.name
for site in top.sites
]
masses = np.zeros(top.n_sites)
charges = np.zeros(top.n_sites)
for idx, site in enumerate(top.sites):
masses[idx] = (
masses = u.unyt_array(
[
site.mass.to_value(base_units["mass"])
if site.mass
else 1 * base_units["mass"]
)
charges[idx] = site.charge if site.charge else 0 * u.elementary_charge
for site in top.sites
]
)
charges = u.unyt_array(
[site.charge if site.charge else 0 * u.elementary_charge for site in top.sites]
)

unique_types = sorted(list(set(types)))
typeids = np.array([unique_types.index(t) for t in types])
Expand All @@ -302,8 +304,9 @@ def _parse_particle_information(
rigid_ids = [site.molecule.number for site in top.sites]
rigid_ids_set = set(rigid_ids)
n_rigid = len(rigid_ids_set)
rigid_masses = np.zeros(n_rigid)
rigid_xyz = np.zeros((n_rigid, 3))
rigid_charges = np.zeros(n_rigid) * charges.units
rigid_masses = np.zeros(n_rigid) * masses.units
rigid_xyz = np.zeros((n_rigid, 3)) * xyz.units
# Rigid particle type defaults to "R"; add to front of list
# TODO: Can we always use "R" here? What if an atom_type is "R"?
unique_types = ["R"] + unique_types
Expand All @@ -322,8 +325,8 @@ def _parse_particle_information(
# Append rigid center mass and xyz to front
masses = np.concatenate((rigid_masses, masses))
xyz = np.concatenate((rigid_xyz, xyz))
charges = np.concatenate((np.zeros(n_rigid), charges))
rigid_id_tags = np.concatenate((np.arange(n_rigid), np.array(rigid_ids)))
charges = np.concatenate((rigid_charges, charges))
rigid_id_tags = np.concatenate((np.arange(n_rigid), rigid_ids))
else:
n_rigid = 0

Expand Down

0 comments on commit b96a83c

Please sign in to comment.