[aGrUM] gum::BayesNetFragment API : add accessors with names (and not only with nodeIds)

parent ab5ab6ad
......@@ -99,26 +99,24 @@ namespace gum {
/// the action to take when a new node is inserted into the graph
/** @param src the object that sent the signal
* @param id the id of the new node inserted into the graph */
virtual void whenNodeAdded(const void* src, NodeId id) noexcept override;
virtual void whenNodeAdded(const void* src, NodeId id) final;
/// the action to take when a node has just been removed from the graph
/** @param src the object that sent the signal
* @param id the id of the node has just been removed from the graph */
virtual void whenNodeDeleted(const void* src, NodeId id) noexcept override;
virtual void whenNodeDeleted(const void* src, NodeId id) final;
/// the action to take when a new arc is inserted into the graph
/** @param src the object that sent the signal
* @param from the id of tail of the new arc inserted into the graph
* @param to the id of head of the new arc inserted into the graph */
virtual void
whenArcAdded(const void* src, NodeId from, NodeId to) noexcept override;
virtual void whenArcAdded(const void* src, NodeId from, NodeId to) final;
/// the action to take when an arc has just been removed from the graph
/** @param src the object that sent the signal
* @param from the id of tail of the arc removed from the graph
* @param to the id of head of the arc removed from the graph */
virtual void
whenArcDeleted(const void* src, NodeId from, NodeId to) noexcept override;
virtual void whenArcDeleted(const void* src, NodeId from, NodeId to) final;
/// @}
/// @name IBayesNet interface
......@@ -129,33 +127,39 @@ namespace gum {
*
* @throw NotFound If no variable's id matches varId.
*/
virtual const Potential< GUM_SCALAR >& cpt(NodeId varId) const override;
virtual const Potential< GUM_SCALAR >& cpt(NodeId varId) const final;
virtual const Potential< GUM_SCALAR >& cpt(const std::string& name) const {
return cpt(idFromName(name));
};
/**
* Returns a constant reference to the VariableNodeMap of this BN
*/
virtual const VariableNodeMap& variableNodeMap() const override;
virtual const VariableNodeMap& variableNodeMap() const final;
/**
* Returns a constant reference over a variabe given it's node id.
*
* @throw NotFound If no variable's id matches varId.
*/
virtual const DiscreteVariable& variable(NodeId id) const override;
virtual const DiscreteVariable& variable(NodeId id) const final;
virtual const DiscreteVariable& variable(const std::string& name) const final {
return variable(idFromName(name));
};
/**
* Return id node from discrete var pointer.
*
* @throw NotFound If no variable matches var.
*/
virtual NodeId nodeId(const DiscreteVariable& var) const override;
virtual NodeId nodeId(const DiscreteVariable& var) const final;
/**
* Getter by name
*
* @throw NotFound if no such name exists in the graph.
*/
virtual NodeId idFromName(const std::string& name) const override;
virtual NodeId idFromName(const std::string& name) const final;
/**
* Getter by name
......@@ -163,13 +167,13 @@ namespace gum {
* @throw NotFound if no such name exists in the graph.
*/
virtual const DiscreteVariable&
variableFromName(const std::string& name) const override;
variableFromName(const std::string& name) const final;
/**
* creates a dot representing the whole referred BN hilighting the fragment.
* @return Returns a dot representation of this fragment
*/
virtual std::string toDot() const override;
virtual std::string toDot() const final;
/// @}
......@@ -179,7 +183,10 @@ namespace gum {
/**
* check if a certain NodeId exists in the fragment
*/
bool isInstalledNode(NodeId id) const noexcept;
bool isInstalledNode(NodeId id) const;
bool isInstalledNode(const std::string& name) const {
return isInstalledNode(idFromName(name));
};
/**
* install a node referenced by its nodeId
......@@ -188,6 +195,9 @@ namespace gum {
* @warning nothing happens if the node is already installed
*/
void installNode(NodeId id);
void installNode(const std::string& name) {
installNode(__bn.idFromName(name));
}
/**
* install a node and all its ascendants
......@@ -196,13 +206,19 @@ namespace gum {
* @warning nothing happens if the node is already installed
*/
void installAscendants(NodeId id);
void installAscendants(const std::string& name) {
installAscendants(__bn.idFromName(name));
}
/**
* uninstall a node referenced by its nodeId
*
* @warning nothing happens if the node is not installed
*/
void uninstallNode(NodeId id) noexcept;
void uninstallNode(NodeId id);
void uninstallNode(const std::string& name) {
uninstallNode(idFromName(name));
}
/**
* install a local marginal for a node into the fragment.
......@@ -215,6 +231,10 @@ namespace gum {
*(or is not a marginal)
**/
void installMarginal(NodeId id, const Potential< GUM_SCALAR >* pot);
void installMarginal(const std::string& name,
const Potential< GUM_SCALAR >* pot) {
installMarginal(__bn.idFromName(name), pot);
}
/**
* install a local cpt for a node into the fragment.
......@@ -226,11 +246,12 @@ namespace gum {
*
* @throw NotFound if the id is not in the fragment
* @throw OperationNotAllowed if the potential is not compliant with the
*variable
*or if
* a variable in the CPT is not a parent in the referred bn.
*variable or if a variable in the CPT is not a parent in the referred bn.
**/
void installCPT(NodeId id, const Potential< GUM_SCALAR >* pot);
void installCPT(const std::string& name, const Potential< GUM_SCALAR >* pot) {
installCPT(__bn.idFromName(name), pot);
};
/**
* uninstall a local CPT.
......@@ -239,7 +260,8 @@ namespace gum {
*is
*not installed.
*/
void uninstallCPT(NodeId id) noexcept;
void uninstallCPT(NodeId id);
void uninstallCPT(const std::string& name) { uninstallCPT(idFromName(name)); }
/**
* returns true if the nodeId's (local or not) cpt is consistent with its
......@@ -248,11 +270,14 @@ namespace gum {
* @throw NotFound if the id is not in the fragment
*/
bool checkConsistency(NodeId id) const;
bool checkConsistency(const std::string& name) const {
return checkConsistency(idFromName(name));
}
/**
* returns true if all nodes in the fragment are consistent
*/
bool checkConsistency() const noexcept;
bool checkConsistency() const;
/// @}
......@@ -261,21 +286,21 @@ namespace gum {
protected:
// remove an arc
void _uninstallArc(NodeId from, NodeId to) noexcept;
void _uninstallArc(NodeId from, NodeId to);
// add an arc
void _installArc(NodeId from, NodeId to) noexcept;
void _installArc(NodeId from, NodeId to);
// install a CPT, create or delete arcs. Checks are made in public methods
// In particular, it is assumed that all the variables in the pot are in the
// fragment
void _installCPT(NodeId id, const Potential< GUM_SCALAR >* pot) noexcept;
void _installCPT(NodeId id, const Potential< GUM_SCALAR >* pot);
/**
* uninstall a local CPT. Does nothing if no local CPT for this nodeId
* No check. No change in the topology. Checks are made in public methods.
*/
void _uninstallCPT(NodeId id) noexcept;
void _uninstallCPT(NodeId id);
};
......
......@@ -51,24 +51,24 @@ namespace gum {
// signals to keep consistency with the referred BayesNet
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeAdded(const void* src,
NodeId id) noexcept {
NodeId id) {
// nothing to do
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeDeleted(const void* src,
NodeId id) noexcept {
NodeId id) {
uninstallNode(id);
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::whenArcAdded(const void* src,
NodeId from,
NodeId to) noexcept {
NodeId to) {
// nothing to do
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::whenArcDeleted(const void* src,
NodeId from,
NodeId to) noexcept {
NodeId to) {
if (dag().existsArc(from, to)) _uninstallArc(from, to);
}
......@@ -137,8 +137,7 @@ namespace gum {
//============================================================
// specific API for BayesNetFragment
template < typename GUM_SCALAR >
INLINE bool BayesNetFragment< GUM_SCALAR >::isInstalledNode(NodeId id) const
noexcept {
INLINE bool BayesNetFragment< GUM_SCALAR >::isInstalledNode(NodeId id) const {
return dag().existsNode(id);
}
......@@ -171,7 +170,7 @@ namespace gum {
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::uninstallNode(NodeId id) noexcept {
INLINE void BayesNetFragment< GUM_SCALAR >::uninstallNode(NodeId id) {
if (isInstalledNode(id)) {
this->_dag.eraseNode(id);
uninstallCPT(id);
......@@ -180,19 +179,18 @@ namespace gum {
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::_uninstallArc(NodeId from,
NodeId to) noexcept {
NodeId to) {
this->_dag.eraseArc(Arc(from, to));
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::_installArc(NodeId from,
NodeId to) noexcept {
INLINE void BayesNetFragment< GUM_SCALAR >::_installArc(NodeId from, NodeId to) {
this->_dag.addArc(from, to);
}
template < typename GUM_SCALAR >
void BayesNetFragment< GUM_SCALAR >::_installCPT(
NodeId id, const Potential< GUM_SCALAR >* pot) noexcept {
NodeId id, const Potential< GUM_SCALAR >* pot) {
// topology
const auto& parents = this->parents(id);
for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
......@@ -236,13 +234,13 @@ namespace gum {
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::_uninstallCPT(NodeId id) noexcept {
INLINE void BayesNetFragment< GUM_SCALAR >::_uninstallCPT(NodeId id) {
delete __localCPTs[id];
__localCPTs.erase(id);
}
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::uninstallCPT(NodeId id) noexcept {
INLINE void BayesNetFragment< GUM_SCALAR >::uninstallCPT(NodeId id) {
if (__localCPTs.exists(id)) {
_uninstallCPT(id);
......@@ -293,7 +291,7 @@ namespace gum {
}
template < typename GUM_SCALAR >
INLINE bool BayesNetFragment< GUM_SCALAR >::checkConsistency() const noexcept {
INLINE bool BayesNetFragment< GUM_SCALAR >::checkConsistency() const {
for (auto node : nodes())
if (!checkConsistency(node)) return false;
......
......@@ -186,6 +186,54 @@ namespace gum_tests {
TS_ASSERT_EQUALS(frag2.sizeArcs(), (gum::Size)6);
}
void testInstallNodesWithVar() {
gum::BayesNet< double > bn;
fill(bn);
gum::BayesNetFragment< double > frag(bn);
// install a node
TS_ASSERT_EQUALS(frag.size(), (gum::Size)0);
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v1"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)1);
// install twice the same node
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v1"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)1);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)0);
// install a non-existing node
TS_ASSERT_THROWS(frag.installNode("v100"), gum::NotFound);
TS_ASSERT_EQUALS(frag.size(), (gum::Size)1);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)0);
// install a second node (without arc)
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v6"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)2);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)0);
// install a third node (and 2 arcs)
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v3"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)3);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)2);
// install ascendants (nothing should happen)
TS_GUM_ASSERT_THROWS_NOTHING(frag.installAscendants("v6"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)3);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)2);
// install ascendants (nothing should happen)
TS_GUM_ASSERT_THROWS_NOTHING(frag.installAscendants("v5"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)6);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)7);
// another test for ascendants
gum::BayesNetFragment< double > frag2(bn);
TS_GUM_ASSERT_THROWS_NOTHING(frag2.installAscendants("v5"));
TS_ASSERT_EQUALS(frag2.size(), (gum::Size)5);
TS_ASSERT_EQUALS(frag2.sizeArcs(), (gum::Size)6);
}
void testUninstallNode() {
gum::BayesNet< double > bn;
fill(bn);
......@@ -203,6 +251,23 @@ namespace gum_tests {
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)0);
}
void testUninstallNodeWithNames() {
gum::BayesNet< double > bn;
fill(bn);
gum::BayesNetFragment< double > frag(bn);
// install ascendants (nothing should happen)
TS_GUM_ASSERT_THROWS_NOTHING(frag.installAscendants("v6"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)3);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)2);
// uninstall node 3 (in the middle)
TS_GUM_ASSERT_THROWS_NOTHING(frag.uninstallNode("v3"));
TS_ASSERT_EQUALS(frag.size(), (gum::Size)2);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)0);
}
void testIBayetNetMethodsWithoutLocalCPTs() {
gum::BayesNet< double > bn;
fill(bn);
......@@ -267,6 +332,69 @@ namespace gum_tests {
TS_ASSERT_EQUALS(frag.variable(order.atPos(2)).name(), "v6");
}
void testIBayetNetMethodsWithoutLocalCPTsWithNames() {
gum::BayesNet< double > bn;
fill(bn);
gum::BayesNetFragment< double > frag(bn);
TS_ASSERT(frag.empty());
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v1"));
TS_ASSERT(!frag.empty());
TS_GUM_ASSERT_THROWS_NOTHING(frag.installNode("v6"));
TS_ASSERT_EQUALS(frag.dag().sizeNodes(), gum::Size(2));
TS_ASSERT_EQUALS(frag.dag().sizeArcs(), gum::Size(0));
TS_ASSERT_EQUALS(frag.size(), gum::Size(2));
TS_ASSERT_EQUALS(frag.dim(), gum::Size((3 - 1) + (2 - 1)));
TS_ASSERT_EQUALS(pow(10, frag.log10DomainSize()), 2 * 3);
TS_GUM_ASSERT_THROWS_NOTHING(frag.installAscendants("v6"));
TS_ASSERT_EQUALS(frag.dag().sizeNodes(), gum::Size(3));
TS_ASSERT_EQUALS(frag.dag().sizeArcs(), gum::Size(2));
TS_ASSERT_EQUALS(frag.size(), gum::Size(3));
TS_ASSERT_EQUALS(frag.dim(),
gum::Size((2 * (3 - 1)) + (2 * (2 - 1)) + (2 - 1)));
TS_ASSERT_DELTA(pow(10, frag.log10DomainSize()), 2 * 2 * 3, 1e-5);
auto I = frag.completeInstantiation();
I.setFirst();
TS_ASSERT_EQUALS(I.toString(), "<v1:0|v3:0|v6:0>");
while (!I.end()) {
float p = bn.cpt("v1")[I] * bn.cpt("v3")[I] * bn.cpt("v6")[I];
TS_ASSERT_DELTA(frag.jointProbability(I), p, 1e-5);
TS_ASSERT_DELTA(frag.log2JointProbability(I), log2(p), 1e-5);
++I;
}
gum::Size count = 0;
for (const auto node : frag.nodes()) {
GUM_UNUSED(node);
count++;
}
TS_ASSERT_EQUALS(count, frag.size());
count = 0;
for (const auto arc : frag.arcs()) {
GUM_UNUSED(arc);
count++;
}
TS_ASSERT_EQUALS(count, frag.sizeArcs());
const auto& order = frag.topologicalOrder();
TS_ASSERT_EQUALS(order.size(), gum::Size(3));
TS_ASSERT_EQUALS(frag.variable(order.atPos(0)).name(), "v1");
TS_ASSERT_EQUALS(frag.variable(order.atPos(1)).name(), "v3");
TS_ASSERT_EQUALS(frag.variable(order.atPos(2)).name(), "v6");
}
void testListeners() {
gum::BayesNet< double > bn;
fill(bn);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment