Skip to content

Commit

Permalink
Workaround to speed up bond table calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
peterspackman committed Jan 14, 2025
1 parent d9c19d2 commit 02b008d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 28 deletions.
62 changes: 38 additions & 24 deletions src/io/crystalclear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ inline occ::crystal::Crystal loadOccCrystal(const json &json) {
return occ::crystal::Crystal(asym, sg, uc);
}

void loadMetadata(PairInteraction *pair, const json &obj) {
for (auto it = obj.begin(); it != obj.end(); ++it) {
const auto &key = it.key();
const auto &value = it.value();

if (key == "energies" || key == "uc_atom_offsets")
continue;

QString qKey = QString::fromStdString(key);
if (value.is_number_integer()) {
pair->addMetadata(qKey, value.get<int>());
} else if (value.is_number_float()) {
pair->addMetadata(qKey, value.get<double>());
} else if (value.is_boolean()) {
pair->addMetadata(qKey, value.get<bool>());
} else if (value.is_string()) {
QString qValue = QString::fromStdString(value.get<std::string>());
if (qKey.toLower().contains("id")) {
pair->setLabel(qValue);
} else {
pair->addMetadata(qKey, qValue);
}
}
}
}

CrystalStructure *loadCrystalClearJson(const QString &filename) {
auto json = loadJsonDocument(filename);
if (json.is_null())
Expand Down Expand Up @@ -76,34 +102,20 @@ CrystalStructure *loadCrystalClearJson(const QString &filename) {
auto siteEnergies = pairsArray[i];
auto &neighbors = interactions[i];
auto &offsets = atomIndices[i];
neighbors.reserve(siteEnergies.size());
offsets.reserve(siteEnergies.size());

for (int j = 0; j < siteEnergies.size(); ++j) {
auto *pair = new PairInteraction(modelName);
pair_energy::Parameters params;
params.hasPermutationSymmetry = hasPermutationSymmetry;
pair->setParameters(params);

auto dimerObj = siteEnergies[j];
const auto &dimerObj = siteEnergies[j];
pair->setLabel(QString::number(j + 1));
for (auto it = dimerObj.begin(); it != dimerObj.end(); ++it) {
QString key = QString::fromStdString(it.key());
if (key == "energies")
continue;
if (it->is_number_integer()) {
pair->addMetadata(key, it->get<int>());
} else if (it->is_number_float()) {
pair->addMetadata(key, it->get<double>());
} else if (it->is_boolean()) {
pair->addMetadata(key, it->get<bool>());
} else if (it->is_string()) {
QString value = QString::fromStdString(it->get<std::string>());
if (key.toLower().contains("id")) {
pair->setLabel(value);
} else {
pair->addMetadata(key, value);
}
}
}
auto energiesObj = dimerObj["energies"];
loadMetadata(pair, dimerObj);

const auto &energiesObj = dimerObj["energies"];
for (auto it = energiesObj.begin(); it != energiesObj.end(); ++it) {
QString key = QString::fromStdString(it.key());
if (it->is_number()) {
Expand All @@ -113,14 +125,16 @@ CrystalStructure *loadCrystalClearJson(const QString &filename) {
continue;
}
}
auto offsetsObj = dimerObj["uc_atom_offsets"];
const auto &offsetsObj = dimerObj["uc_atom_offsets"];
DimerAtoms d;
auto a = offsetsObj[0];
const auto &a = offsetsObj[0];
d.a.reserve(a.size()); // For DimerAtoms
for (int i = 0; i < a.size(); i++) {
auto idx = a[i];
d.a.push_back(GenericAtomIndex{idx[0], idx[1], idx[2], idx[3]});
}
auto b = offsetsObj[1];
const auto &b = offsetsObj[1];
d.b.reserve(b.size());
for (int i = 0; i < b.size(); i++) {
auto idx = b[i];
d.b.push_back(GenericAtomIndex{idx[0], idx[1], idx[2], idx[3]});
Expand Down
1 change: 0 additions & 1 deletion src/occ/crystal/crystal.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <fmt/core.h>
#include <iostream>
#include <occ/core/element.h>
#include <occ/core/kdtree.h>
#include <occ/core/linear_algebra.h>
Expand Down
16 changes: 13 additions & 3 deletions src/occ/crystal/dimer_mapping_table.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <iostream>
#include <occ/crystal/crystal.h>
#include <occ/crystal/dimer_mapping_table.h>

Expand Down Expand Up @@ -113,8 +112,6 @@ DimerMappingTable::DimerMappingTable(const Crystal &crystal,
m_centroids.col(i) = crystal.to_fractional(uc_mols[i].centroid());
}

std::cout << "Centroids\n" << m_centroids.transpose() << '\n';

const auto &symops = crystal.symmetry_operations();
ankerl::unordered_dense::set<DimerIndex, DimerIndexHash> unique_dimers_set;

Expand Down Expand Up @@ -174,6 +171,7 @@ bool DimerMappingTable::have_dimer(const DimerIndex &dimer) const {
DimerMappingTable
DimerMappingTable::create_atomic_pair_table(const Crystal &crystal,
bool consider_inversion) {

DimerMappingTable table;
table.m_consider_inversion = consider_inversion;
table.m_cell = crystal.unit_cell();
Expand All @@ -186,16 +184,28 @@ DimerMappingTable::create_atomic_pair_table(const Crystal &crystal,
const auto &symops = crystal.symmetry_operations();
ankerl::unordered_dense::set<DimerIndex, DimerIndexHash> unique_dimers_set;

// Get max possible bonding distance from vdw radii
const auto &vdw_radii = crystal.asymmetric_unit().vdw_radii();
double max_vdw = vdw_radii.maxCoeff();
double max_dist = (max_vdw * 2 + 0.6) * (max_vdw * 2 + 0.6);

// For each atom in the unit cell
for (int i = 0; i < uc_atoms.size(); i++) {
Vec3 pos_i = uc_atoms.frac_pos.col(i);
Vec3 cart_pos_i = uc_atoms.cart_pos.col(i);

// Look through expanded slab for possible bonds
for (int j = 0; j < s.frac_pos.cols(); j++) {
if (j % uc_atoms.size() <= i)
continue; // avoid duplicates
// Then in the pair loop:
Vec3 cart_pos_j = s.cart_pos.col(j);
Vec3 pos_diff = cart_pos_j - cart_pos_i;
if (pos_diff.squaredNorm() > max_dist)
continue; // Skip pairs too far apart

Vec3 pos_j = s.frac_pos.col(j);

int uc_idx_j = s.uc_idx(j);
HKL cell_offset{s.hkl(0, j), s.hkl(1, j), s.hkl(2, j)};

Expand Down
4 changes: 4 additions & 0 deletions src/occ/include/occ/crystal/dimer_mapping_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ struct DimerIndex {
return a == other.a && b == other.b;
}

inline bool operator!=(const DimerIndex &other) const {
return !(*this == other);
}

inline bool operator<(const DimerIndex &other) const {
if (a.offset != other.a.offset)
return a.offset < other.a.offset;
Expand Down

0 comments on commit 02b008d

Please sign in to comment.