Commit 171e720b authored by Mark Abraham's avatar Mark Abraham Committed by Artem Zhmurov
Browse files

Lift atom type lookup out of inner loops

Grompp loops over molecule types, looking up force parameters for each
interaction from the associated bond types for the system (e.g. from
the force field). The atom types for that interaction have to be
looked up from the atoms for the molecule type that contains it, but
this should be done only once, before considering each bond type as a
possible match. The lookups for both A- and B-state parameters are now
lifted out of the loops over bond types, simplifying the logic and
significantly improving performance.

Once that is done, one custom function could be replaced by
std::equal.

Improved some variable naming

Apply 1 suggestion(s) to 1 file(s)
parent b7ddf29e
......@@ -45,6 +45,7 @@
#include <cstring>
#include <algorithm>
#include <array>
#include <string>
#include "gromacs/fileio/warninp.h"
......@@ -1722,20 +1723,17 @@ static bool default_cmap_params(gmx::ArrayRef<InteractionsOfType> bondtype,
/* Returns the number of exact atom type matches, i.e. non wild-card matches,
* returns -1 when there are no matches at all.
*/
static int natom_match(const InteractionOfType& pi,
int type_i,
int type_j,
int type_k,
int type_l,
const PreprocessingAtomTypes* atypes)
static int findNumberOfDihedralAtomMatches(const InteractionOfType& bondType,
const gmx::ArrayRef<const int> atomTypes)
{
if ((pi.ai() == -1 || atypes->bondAtomTypeFromAtomType(type_i) == pi.ai())
&& (pi.aj() == -1 || atypes->bondAtomTypeFromAtomType(type_j) == pi.aj())
&& (pi.ak() == -1 || atypes->bondAtomTypeFromAtomType(type_k) == pi.ak())
&& (pi.al() == -1 || atypes->bondAtomTypeFromAtomType(type_l) == pi.al()))
GMX_RELEASE_ASSERT(atomTypes.size() == 4, "Dihedrals have 4 atom types");
if ((bondType.ai() == -1 || atomTypes[0] == bondType.ai())
&& (bondType.aj() == -1 || atomTypes[1] == bondType.aj())
&& (bondType.ak() == -1 || atomTypes[2] == bondType.ak())
&& (bondType.al() == -1 || atomTypes[3] == bondType.al()))
{
return (pi.ai() == -1 ? 0 : 1) + (pi.aj() == -1 ? 0 : 1) + (pi.ak() == -1 ? 0 : 1)
+ (pi.al() == -1 ? 0 : 1);
return (bondType.ai() == -1 ? 0 : 1) + (bondType.aj() == -1 ? 0 : 1)
+ (bondType.ak() == -1 ? 0 : 1) + (bondType.al() == -1 ? 0 : 1);
}
else
{
......@@ -1743,77 +1741,14 @@ static int natom_match(const InteractionOfType& pi,
}
}
static int findNumberOfDihedralAtomMatches(const InteractionOfType& currentParamFromParameterArray,
const InteractionOfType& parameterToAdd,
const t_atoms* at,
const PreprocessingAtomTypes* atypes,
bool bB)
{
if (bB)
{
return natom_match(currentParamFromParameterArray,
at->atom[parameterToAdd.ai()].typeB,
at->atom[parameterToAdd.aj()].typeB,
at->atom[parameterToAdd.ak()].typeB,
at->atom[parameterToAdd.al()].typeB,
atypes);
}
else
{
return natom_match(currentParamFromParameterArray,
at->atom[parameterToAdd.ai()].type,
at->atom[parameterToAdd.aj()].type,
at->atom[parameterToAdd.ak()].type,
at->atom[parameterToAdd.al()].type,
atypes);
}
}
static bool findIfAllParameterAtomsMatch(gmx::ArrayRef<const int> atomsFromParameterArray,
gmx::ArrayRef<const int> atomsFromCurrentParameter,
const t_atoms* at,
const PreprocessingAtomTypes* atypes,
bool bB)
{
if (atomsFromParameterArray.size() != atomsFromCurrentParameter.size())
{
return false;
}
else if (bB)
{
for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
{
if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].typeB)
!= atomsFromParameterArray[i])
{
return false;
}
}
return true;
}
else
{
for (gmx::index i = 0; i < atomsFromCurrentParameter.ssize(); i++)
{
if (atypes->bondAtomTypeFromAtomType(at->atom[atomsFromCurrentParameter[i]].type)
!= atomsFromParameterArray[i])
{
return false;
}
}
return true;
}
}
static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ftype,
gmx::ArrayRef<InteractionsOfType> bt,
t_atoms* at,
PreprocessingAtomTypes* atypes,
const InteractionOfType& p,
bool bB,
int* nparam_def)
static std::vector<InteractionOfType>::iterator
defaultInteractionsOfType(int ftype,
gmx::ArrayRef<InteractionsOfType> bondType,
const gmx::ArrayRef<const int> atomTypes,
int* nparam_def)
{
int nparam_found = 0;
if (ftype == F_PDIHS || ftype == F_RBDIHS || ftype == F_IDIHS || ftype == F_PIDIHS)
{
int nmatch_max = -1;
......@@ -1821,24 +1756,23 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
/* For dihedrals we allow wildcards. We choose the first type
* that has the most real matches, i.e. non-wildcard matches.
*/
auto prevPos = bt[ftype].interactionTypes.end();
auto pos = bt[ftype].interactionTypes.begin();
while (pos != bt[ftype].interactionTypes.end() && nmatch_max < 4)
{
pos = std::find_if(bt[ftype].interactionTypes.begin(),
bt[ftype].interactionTypes.end(),
[&p, &at, &atypes, &bB, &nmatch_max](const auto& param) {
return (findNumberOfDihedralAtomMatches(param, p, at, atypes, bB)
> nmatch_max);
auto prevPos = bondType[ftype].interactionTypes.end();
auto pos = bondType[ftype].interactionTypes.begin();
while (pos != bondType[ftype].interactionTypes.end() && nmatch_max < 4)
{
pos = std::find_if(bondType[ftype].interactionTypes.begin(),
bondType[ftype].interactionTypes.end(),
[&atomTypes, &nmatch_max](const auto& bondType) {
return (findNumberOfDihedralAtomMatches(bondType, atomTypes) > nmatch_max);
});
if (pos != bt[ftype].interactionTypes.end())
if (pos != bondType[ftype].interactionTypes.end())
{
prevPos = pos;
nmatch_max = findNumberOfDihedralAtomMatches(*pos, p, at, atypes, bB);
nmatch_max = findNumberOfDihedralAtomMatches(*pos, atomTypes);
}
}
if (prevPos != bt[ftype].interactionTypes.end())
if (prevPos != bondType[ftype].interactionTypes.end())
{
nparam_found++;
......@@ -1854,7 +1788,7 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
};
/* Continue from current iterator position */
auto nextPos = prevPos;
const auto endIter = bt[ftype].interactionTypes.end();
const auto endIter = bondType[ftype].interactionTypes.end();
safeAdvance(nextPos, 2, endIter);
for (; nextPos < endIter && bSame; safeAdvance(nextPos, 2, endIter))
{
......@@ -1872,14 +1806,13 @@ static std::vector<InteractionOfType>::iterator defaultInteractionsOfType(int ft
}
else /* Not a dihedral */
{
gmx::ArrayRef<const int> atomParam = p.atoms();
auto found = std::find_if(bt[ftype].interactionTypes.begin(),
bt[ftype].interactionTypes.end(),
[&atomParam, &at, &atypes, &bB](const auto& param) {
return findIfAllParameterAtomsMatch(
param.atoms(), atomParam, at, atypes, bB);
});
if (found != bt[ftype].interactionTypes.end())
auto found = std::find_if(
bondType[ftype].interactionTypes.begin(),
bondType[ftype].interactionTypes.end(),
[&atomTypes](const auto& param) {
return std::equal(param.atoms().begin(), param.atoms().end(), atomTypes.begin());
});
if (found != bondType[ftype].interactionTypes.end())
{
nparam_found = 1;
}
......@@ -1912,10 +1845,10 @@ void push_bond(Directive d,
int nral, nral_fmt, nread, ftype;
char format[STRLEN];
/* One force parameter more, so we can check if we read too many */
double cc[MAXFORCEPARAM + 1];
int aa[MAXATOMLIST + 1];
bool bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
int nparam_defA, nparam_defB;
double cc[MAXFORCEPARAM + 1];
std::array<int, MAXATOMLIST + 1> aa;
bool bFoundA = FALSE, bFoundB = FALSE, bDef, bSwapParity = FALSE;
int nparam_defA, nparam_defB;
nparam_defA = nparam_defB = 0;
......@@ -1979,7 +1912,7 @@ void push_bond(Directive d,
}
/* Check for double atoms and atoms out of bounds */
/* Check for double atoms and atoms out of bounds, then convert to 0-based indexing */
for (int i = 0; (i < nral); i++)
{
if (aa[i] < 1 || aa[i] > at->nr)
......@@ -2018,21 +1951,33 @@ void push_bond(Directive d,
}
}
}
// Convert to 0-based indexing
--aa[i];
}
// These are the atom indices for this interaction
gmx::ArrayRef<int> atomIndices(aa.begin(), aa.begin() + nral);
// Look up the A-state atom types for this interaction
std::vector<int> atomTypes(atomIndices.size());
std::transform(atomIndices.begin(), atomIndices.end(), atomTypes.begin(), [at, atypes](const int atomIndex) {
return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].type).value();
});
// Look up the B-state atom types for this interaction
std::vector<int> atomTypesB(atomIndices.size());
std::transform(atomIndices.begin(), atomIndices.end(), atomTypesB.begin(), [at, atypes](const int atomIndex) {
return atypes->bondAtomTypeFromAtomType(at->atom[atomIndex].typeB).value();
});
/* default force parameters */
std::vector<int> atoms;
for (int j = 0; (j < nral); j++)
{
atoms.emplace_back(aa[j] - 1);
}
/* need to have an empty but initialized param array for some reason */
std::array<real, MAXFORCEPARAM> forceParam = { 0.0 };
/* Get force params for normal and free energy perturbation
* studies, as determined by types!
*/
InteractionOfType param(atoms, forceParam, "");
InteractionOfType param(atomIndices, forceParam, "");
std::vector<InteractionOfType>::iterator foundAParameter = bondtype[ftype].interactionTypes.end();
std::vector<InteractionOfType>::iterator foundBParameter = bondtype[ftype].interactionTypes.end();
......@@ -2044,8 +1989,7 @@ void push_bond(Directive d,
}
else
{
foundAParameter =
defaultInteractionsOfType(ftype, bondtype, at, atypes, param, FALSE, &nparam_defA);
foundAParameter = defaultInteractionsOfType(ftype, bondtype, atomTypes, &nparam_defA);
if (foundAParameter != bondtype[ftype].interactionTypes.end())
{
/* Copy the A-state and B-state default parameters. */
......@@ -2066,8 +2010,7 @@ void push_bond(Directive d,
}
else
{
foundBParameter =
defaultInteractionsOfType(ftype, bondtype, at, atypes, param, TRUE, &nparam_defB);
foundBParameter = defaultInteractionsOfType(ftype, bondtype, atomTypesB, &nparam_defB);
if (foundBParameter != bondtype[ftype].interactionTypes.end())
{
/* Copy only the B-state default parameters */
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment