[aGrUM] add Potential::findAll/argmax/argmin methods

parent e4dce0ce
......@@ -203,6 +203,13 @@ namespace gum {
/// @throw NotFound if all value == 0.0
GUM_SCALAR minNonZero() const;
/// set of instantiation corresponding to the parameter v in the Potential
Set< Instantiation > findAll(GUM_SCALAR v) const;
/// set of instantiation corresponding to the max in the Potential
Set< Instantiation > argmax() const;
/// set of instantiation corresponding to the min in the Potential
Set< Instantiation > argmin() const;
/// entropy of the Potential
GUM_SCALAR entropy() const;
......
......@@ -587,4 +587,29 @@ namespace gum {
return out;
}
// argmax of all elements in this
template < typename GUM_SCALAR >
Set< Instantiation > Potential< GUM_SCALAR >::findAll(GUM_SCALAR v) const {
Instantiation I(*this);
Set< Instantiation > res;
if (static_cast< MultiDimContainer< GUM_SCALAR >* >(this->_content)->empty()) {
return res;
}
for (I.setFirst(); !I.end(); ++I) {
if (this->get(I) == v) res.insert(I);
}
return res;
}
// argmax of all elements in this
template < typename GUM_SCALAR >
INLINE Set< Instantiation > Potential< GUM_SCALAR >::argmax() const {
return findAll(max());
}
// argmin of all elements in this
template < typename GUM_SCALAR >
INLINE Set< Instantiation > Potential< GUM_SCALAR >::argmin() const {
return findAll(min());
}
} /* namespace gum */
......@@ -1022,5 +1022,39 @@ namespace gum_tests {
TS_GUM_ASSERT_THROWS_NOTHING(pp.fillWith(p, {"w", "v"}));
TS_ASSERT_THROWS(pp.fillWith(p, {"v", "w"}), gum::InvalidArgument);
}
private:
void __testval_for_set(const gum::Potential< int >& p,
int val,
const gum::Set< gum::Instantiation > s,
gum::Size expected_size) {
gum::Instantiation ip(p);
TS_ASSERT_EQUALS(s.size(), expected_size);
for (ip.setFirst(); !ip.end(); ++ip) {
if (s.contains(ip)) {
TS_ASSERT_EQUALS(p[ip], val);
} else {
TS_ASSERT_DIFFERS(p[ip], val);
}
}
}
public:
void testArgMaxMinFindAll() {
gum::LabelizedVariable v("v", "v", 2), w("w", "w", 3);
gum::Potential< int > p;
__testval_for_set(p, 4, p.findAll(4), 0);
p.add(v);
p.add(w);
p.fillWith({1, 3, 2, 4, 1, 4});
gum::Instantiation ip(p);
__testval_for_set(p, 3, p.findAll(3), 1);
__testval_for_set(p, 10, p.findAll(10), 0);
__testval_for_set(p, 4, p.argmax(), 2);
__testval_for_set(p, 1, p.argmin(), 2);
}
};
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment