From 171e720b7b6a4daa2786d11097043d66759fe80b Mon Sep 17 00:00:00 2001 From: Mark Abraham <mark.j.abraham@gmail.com> Date: Mon, 26 Jul 2021 10:31:26 +0200 Subject: [PATCH] Lift atom type lookup out of inner loops Grompp loops over molecule types, looking up force parameters for each interaction from the associated bond types for the system (e.g. from the force field). The atom types for that interaction have to be looked up from the atoms for the molecule type that contains it, but this should be done only once, before considering each bond type as a possible match. The lookups for both A- and B-state parameters are now lifted out of the loops over bond types, simplifying the logic and significantly improving performance. Once that is done, one custom function could be replaced by std::equal. Improved some variable naming Apply 1 suggestion(s) to 1 file(s) --- src/gromacs/gmxpreprocess/toppush.cpp | 177 +++++++++----------------- 1 file changed, 60 insertions(+), 117 deletions(-) diff --git a/src/gromacs/gmxpreprocess/toppush.cpp b/src/gromacs/gmxpreprocess/toppush.cpp index 9aae7a39267..c73e821844e 100644 --- a/src/gromacs/gmxpreprocess/toppush.cpp +++ b/src/gromacs/gmxpreprocess/toppush.cpp @@ -45,6 +45,7 @@ #include <cstring> #include <algorithm> +#include <array> #include <string> #include "gromacs/fileio/warninp.h" @@ -1722,20 +1723,17 @@ static bool default_cmap_params(gmx::ArrayRef<InteractionsOfType> bondtype, /* Returns the number of exact atom type matches, i.e. non wild-card matches, * returns -1 when there are no matches at all. */ -static int natom_match(const InteractionOfType& pi, - int type_i, - int type_j, - int type_k, - int type_l, - const PreprocessingAtomTypes* atypes) +static int findNumberOfDihedralAtomMatches(const InteractionOfType& bondType, + const gmx::ArrayRef<const int> atomTypes) { - if ((pi.ai() == -1 || atypes->bondAtomTypeFromAtomType(type_i) == pi.ai()) - && (pi.aj() == -1 || atypes->bondAtomTypeFromAtomType(type_j) == pi.aj()) - && (pi.ak() == -1 || atypes->bondAtomTypeFromAtomType(type_k) == pi.ak()) - && (pi.al() == -1 || atypes->bondAtomTypeFromAtomType(type_l) == pi.al())) + GMX_RELEASE_ASSERT(atomTypes.size() == 4, "Dihedrals have 4 atom types"); + if ((bondType.ai() == -1 || atomTypes[0] == bondType.ai()) + && (bondType.aj() == -1 || atomTypes[1] == bondType.aj()) + && (bondType.ak() == -1 || atomTypes[2] == bondType.ak()) + && (bondType.al() == -1 || atomTypes[3] == bondType.al())) { - return (pi.ai() == -1 ? 0 : 1) + (pi.aj() == -1 ? 0 : 1) + (pi.ak() == -1 ? 0 : 1) - + (pi.al() == -1 ? 0 : 1); + return (bondType.ai() == -1 ? 0 : 1) + (bondType.aj() == -1 ? 0 : 1) + + (bondType.ak() == -1 ? 0 : 1) + (bondType.al() == -1 ? 0 : 1); } else { @@ -1743,77 +1741,14 @@ static int natom_match(const InteractionOfType& pi, } } -static int findNumberOfDihedralAtomMatches(const InteractionOfType& currentParamFromParameterArray, - const InteractionOfType& parameterToAdd, - const t_atoms* at, - const PreprocessingAtomTypes* atypes, - bool bB) -{ - if (bB) - { - return natom_match(currentParamFromParameterArray, - at->atom[parameterToAdd.ai()].typeB, - at->atom[parameterToAdd.aj()].typeB, - at->atom[parameterToAdd.ak()].typeB, - at->atom[parameterToAdd.al()].typeB, - atypes); - } - else - { - return natom_match(currentParamFromParameterArray, - at->atom[parameterToAdd.ai()].type, - at->atom[parameterToAdd.aj()].type, - at->atom[parameterToAdd.ak()].type, - at->atom[parameterToAdd.al()].type, - atypes); - } -} - -static bool findIfAllParameterAtomsMatch(gmx::ArrayRef<const int> atomsFromParameterArray, - gmx::ArrayRef<const int> atomsFromCurrentParameter, - const t_atoms* at, - const PreprocessingAtomTypes* atypes, - bool bB) -{ - if (atomsFromParameterArray.size() != atomsFromCurrentParameter.size()) - { - return false; - } - else if (bB) - { - for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++) - { - if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].typeB) - != atomsFromParameterArray[i]) - { - return false; - } - } - return true; - } - else - { - for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++) - { - if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].type) - != atomsFromParameterArray[i]) - { - return false; - } - } - return true; - } -} - -static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ftype, - gmx::ArrayRef<InteractionsOfType> bt, - t_atoms* at, - PreprocessingAtomTypes* atypes, - const InteractionOfType& p, - bool bB, - int* nparam_def) +static std::vector<InteractionOfType>::iterator +defaultInteractionsOfType(int ftype, + gmx::ArrayRef<InteractionsOfType> bondType, + const gmx::ArrayRef<const int> atomTypes, + int* nparam_def) { int nparam_found = 0; + if (ftype == F_PDIHS || ftype == F_RBDIHS || ftype == F_IDIHS || ftype == F_PIDIHS) { int nmatch_max = -1; @@ -1821,24 +1756,23 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft /* For dihedrals we allow wildcards. We choose the first type * that has the most real matches, i.e. non-wildcard matches. */ - auto prevPos = bt[ftype].interactionTypes.end(); - auto pos = bt[ftype].interactionTypes.begin(); - while (pos != bt[ftype].interactionTypes.end() && nmatch_max < 4) - { - pos = std::find_if(bt[ftype].interactionTypes.begin(), - bt[ftype].interactionTypes.end(), - [&p, &at, &atypes, &bB, &nmatch_max](const auto& param) { - return (findNumberOfDihedralAtomMatches(param, p, at, atypes, bB) - > nmatch_max); + auto prevPos = bondType[ftype].interactionTypes.end(); + auto pos = bondType[ftype].interactionTypes.begin(); + while (pos != bondType[ftype].interactionTypes.end() && nmatch_max < 4) + { + pos = std::find_if(bondType[ftype].interactionTypes.begin(), + bondType[ftype].interactionTypes.end(), + [&atomTypes, &nmatch_max](const auto& bondType) { + return (findNumberOfDihedralAtomMatches(bondType, atomTypes) > nmatch_max); }); - if (pos != bt[ftype].interactionTypes.end()) + if (pos != bondType[ftype].interactionTypes.end()) { prevPos = pos; - nmatch_max = findNumberOfDihedralAtomMatches(*pos, p, at, atypes, bB); + nmatch_max = findNumberOfDihedralAtomMatches(*pos, atomTypes); } } - if (prevPos != bt[ftype].interactionTypes.end()) + if (prevPos != bondType[ftype].interactionTypes.end()) { nparam_found++; @@ -1854,7 +1788,7 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft }; /* Continue from current iterator position */ auto nextPos = prevPos; - const auto endIter = bt[ftype].interactionTypes.end(); + const auto endIter = bondType[ftype].interactionTypes.end(); safeAdvance(nextPos, 2, endIter); for (; nextPos < endIter && bSame; safeAdvance(nextPos, 2, endIter)) { @@ -1872,14 +1806,13 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft } else /* Not a dihedral */ { - gmx::ArrayRef<const int> atomParam = p.atoms(); - auto found = std::find_if(bt[ftype].interactionTypes.begin(), - bt[ftype].interactionTypes.end(), - [&atomParam, &at, &atypes, &bB](const auto& param) { - return findIfAllParameterAtomsMatch( - param.atoms(), atomParam, at, atypes, bB); - }); - if (found != bt[ftype].interactionTypes.end()) + auto found = std::find_if( + bondType[ftype].interactionTypes.begin(), + bondType[ftype].interactionTypes.end(), + [&atomTypes](const auto& param) { + return std::equal(param.atoms().begin(), param.atoms().end(), atomTypes.begin()); + }); + if (found != bondType[ftype].interactionTypes.end()) { nparam_found = 1; } @@ -1912,10 +1845,10 @@ void push_bond(Directive d, int nral, nral_fmt, nread, ftype; char format[STRLEN]; /* One force parameter more, so we can check if we read too many */ - double cc[MAXFORCEPARAM + 1]; - int aa[MAXATOMLIST + 1]; - bool bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE; - int nparam_defA, nparam_defB; + double cc[MAXFORCEPARAM + 1]; + std::array<int, MAXATOMLIST + 1> aa; + bool bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE; + int nparam_defA, nparam_defB; nparam_defA = nparam_defB = 0; @@ -1979,7 +1912,7 @@ void push_bond(Directive d, } - /* Check for double atoms and atoms out of bounds */ + /* Check for double atoms and atoms out of bounds, then convert to 0-based indexing */ for (int i = 0; (i < nral); i++) { if (aa[i] < 1 || aa[i] > at->nr) @@ -2018,21 +1951,33 @@ void push_bond(Directive d, } } } + + // Convert to 0-based indexing + --aa[i]; } + // These are the atom indices for this interaction + gmx::ArrayRef<int> atomIndices(aa.begin(), aa.begin() + nral); + + // Look up the A-state atom types for this interaction + std::vector<int> atomTypes(atomIndices.size()); + std::transform(atomIndices.begin(), atomIndices.end(), atomTypes.begin(), [at, atypes](const int atomIndex) { + return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].type).value(); + }); + // Look up the B-state atom types for this interaction + std::vector<int> atomTypesB(atomIndices.size()); + std::transform(atomIndices.begin(), atomIndices.end(), atomTypesB.begin(), [at, atypes](const int atomIndex) { + return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].typeB).value(); + }); + /* default force parameters */ - std::vector<int> atoms; - for (int j = 0; (j < nral); j++) - { - atoms.emplace_back(aa[j] - 1); - } /* need to have an empty but initialized param array for some reason */ std::array<real, MAXFORCEPARAM> forceParam = { 0.0 }; /* Get force params for normal and free energy perturbation * studies, as determined by types! */ - InteractionOfType param(atoms, forceParam, ""); + InteractionOfType param(atomIndices, forceParam, ""); std::vector<InteractionOfType>::iterator foundAParameter = bondtype[ftype].interactionTypes.end(); std::vector<InteractionOfType>::iterator foundBParameter = bondtype[ftype].interactionTypes.end(); @@ -2044,8 +1989,7 @@ void push_bond(Directive d, } else { - foundAParameter = - defaultInteractionsOfType(ftype, bondtype, at, atypes, param, FALSE, &nparam_defA); + foundAParameter = defaultInteractionsOfType(ftype, bondtype, atomTypes, &nparam_defA); if (foundAParameter != bondtype[ftype].interactionTypes.end()) { /* Copy the A-state and B-state default parameters. */ @@ -2066,8 +2010,7 @@ void push_bond(Directive d, } else { - foundBParameter = - defaultInteractionsOfType(ftype, bondtype, at, atypes, param, TRUE, &nparam_defB); + foundBParameter = defaultInteractionsOfType(ftype, bondtype, atomTypesB, &nparam_defB); if (foundBParameter != bondtype[ftype].interactionTypes.end()) { /* Copy only the B-state default parameters */ -- GitLab