[aGrUM] BayesNetFragment API : copy local potentials (instead of only refer to)

parent 5af74613
......@@ -221,7 +221,7 @@ namespace gum {
}
/**
* install a local marginal for a node into the fragment.
* install a local marginal BY COPY for a node into the fragment.
* This function will remove all the arcs from the parents to the node.
* @param id the nodeId
* @param pot the potential
......@@ -230,26 +230,26 @@ namespace gum {
*variable
*(or is not a marginal)
**/
void installMarginal(NodeId id, const Potential< GUM_SCALAR >* pot);
void installMarginal(NodeId id, const Potential< GUM_SCALAR >& pot);
void installMarginal(const std::string& name,
const Potential< GUM_SCALAR >* pot) {
const Potential< GUM_SCALAR >& pot) {
installMarginal(__bn.idFromName(name), pot);
}
/**
* install a local cpt for a node into the fragment.
* install a local cpt BY COPYfor a node into the fragment.
* This function will change the arcs from the parents to the node in order
*to be
* consistent with the new local potential.
* @param id the nodeId
* @param pot the potential<
* @param pot the potential to be copied
*
* @throw NotFound if the id is not in the fragment
* @throw OperationNotAllowed if the potential is not compliant with the
*variable or if a variable in the CPT is not a parent in the referred bn.
**/
void installCPT(NodeId id, const Potential< GUM_SCALAR >* pot);
void installCPT(const std::string& name, const Potential< GUM_SCALAR >* pot) {
void installCPT(NodeId id, const Potential< GUM_SCALAR >& pot);
void installCPT(const std::string& name, const Potential< GUM_SCALAR >& pot) {
installCPT(__bn.idFromName(name), pot);
};
......@@ -291,10 +291,10 @@ namespace gum {
// add an arc
void _installArc(NodeId from, NodeId to);
// install a CPT, create or delete arcs. Checks are made in public methods
// install a CPT BY COPY, create or delete arcs. Checks are made in public methods
// In particular, it is assumed that all the variables in the pot are in the
// fragment
void _installCPT(NodeId id, const Potential< GUM_SCALAR >* pot);
void _installCPT(NodeId id, const Potential< GUM_SCALAR >& pot);
/**
* uninstall a local CPT. Does nothing if no local CPT for this nodeId
......
......@@ -78,7 +78,8 @@ namespace gum {
template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >&
BayesNetFragment< GUM_SCALAR >::cpt(NodeId id) const {
if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
if (!isInstalledNode(id))
GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
if (__localCPTs.exists(id))
return *__localCPTs[id];
......@@ -96,7 +97,8 @@ namespace gum {
template < typename GUM_SCALAR >
INLINE const DiscreteVariable&
BayesNetFragment< GUM_SCALAR >::variable(NodeId id) const {
if (!isInstalledNode(id)) GUM_ERROR(NotFound, id << " is not installed");
if (!isInstalledNode(id))
GUM_ERROR(NotFound, "NodeId " << id << " is not installed");
return __bn.variable(id);
}
......@@ -126,7 +128,7 @@ namespace gum {
template < typename GUM_SCALAR >
INLINE const DiscreteVariable& BayesNetFragment< GUM_SCALAR >::variableFromName(
const std::string& name) const {
NodeId id = __bn.idFromName(name);
NodeId id = idFromName(name);
if (!isInstalledNode(id))
GUM_ERROR(NotFound, "variable " << name << " is not installed");
......@@ -172,8 +174,8 @@ namespace gum {
template < typename GUM_SCALAR >
INLINE void BayesNetFragment< GUM_SCALAR >::uninstallNode(NodeId id) {
if (isInstalledNode(id)) {
this->_dag.eraseNode(id);
uninstallCPT(id);
this->_dag.eraseNode(id);
}
}
......@@ -190,15 +192,15 @@ namespace gum {
template < typename GUM_SCALAR >
void BayesNetFragment< GUM_SCALAR >::_installCPT(
NodeId id, const Potential< GUM_SCALAR >* pot) {
NodeId id, const Potential< GUM_SCALAR >& pot) {
// topology
const auto& parents = this->parents(id);
for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
++node_it) // safe iterator needed here
_uninstallArc(*node_it, id);
for (Idx i = 1; i < pot->nbrDim(); i++) {
NodeId parent = __bn.idFromName(pot->variable(i).name());
for (Idx i = 1; i < pot.nbrDim(); i++) {
NodeId parent = __bn.idFromName(pot.variable(i).name());
if (isInstalledNode(parent)) _installArc(parent, id);
}
......@@ -206,16 +208,16 @@ namespace gum {
// local cpt
if (__localCPTs.exists(id)) _uninstallCPT(id);
__localCPTs.insert(id, pot);
__localCPTs.insert(id, new gum::Potential< GUM_SCALAR >(pot));
}
template < typename GUM_SCALAR >
void BayesNetFragment< GUM_SCALAR >::installCPT(
NodeId id, const Potential< GUM_SCALAR >* pot) {
NodeId id, const Potential< GUM_SCALAR >& pot) {
if (!dag().existsNode(id))
GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment");
if (&(pot->variable(0)) != &(variable(id))) {
if (&(pot.variable(0)) != &(variable(id))) {
GUM_ERROR(OperationNotAllowed,
"The potential is not a marginal for __bn.variable <"
<< variable(id).name() << ">");
......@@ -223,10 +225,10 @@ namespace gum {
const NodeSet& parents = __bn.parents(id);
for (Idx i = 1; i < pot->nbrDim(); i++) {
if (!parents.contains(__bn.idFromName(pot->variable(i).name())))
for (Idx i = 1; i < pot.nbrDim(); i++) {
if (!parents.contains(__bn.idFromName(pot.variable(i).name())))
GUM_ERROR(OperationNotAllowed,
"Variable <" << pot->variable(i).name()
"Variable <" << pot.variable(i).name()
<< "> is not in the parents of node " << id);
}
......@@ -257,16 +259,16 @@ namespace gum {
template < typename GUM_SCALAR >
void BayesNetFragment< GUM_SCALAR >::installMarginal(
NodeId id, const Potential< GUM_SCALAR >* pot) {
NodeId id, const Potential< GUM_SCALAR >& pot) {
if (!isInstalledNode(id)) {
GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment");
}
if (pot->nbrDim() > 1) {
if (pot.nbrDim() > 1) {
GUM_ERROR(OperationNotAllowed, "The potential is not a marginal :" << pot);
}
if (&(pot->variable(0)) != &(__bn.variable(id))) {
if (&(pot.variable(0)) != &(__bn.variable(id))) {
GUM_ERROR(OperationNotAllowed,
"The potential is not a marginal for __bn.variable <"
<< __bn.variable(id).name() << ">");
......
......@@ -94,9 +94,8 @@ namespace gum {
void ImportanceSampling< GUM_SCALAR >::_unsharpenBN(
BayesNetFragment< GUM_SCALAR >* bn, float epsilon) {
for (const auto nod : bn->nodes().asNodeSet()) {
auto p = new Potential< GUM_SCALAR >();
*p = bn->cpt(nod).isNonZeroMap().scale(epsilon) + bn->cpt(nod);
p->normalizeAsCPT();
auto p = bn->cpt(nod).isNonZeroMap().scale(epsilon) + bn->cpt(nod);
p.normalizeAsCPT();
bn->installCPT(nod, p);
}
}
......@@ -106,7 +105,7 @@ namespace gum {
BayesNetFragment< GUM_SCALAR >* bn) {
for (const auto ev : this->hardEvidenceNodes()) {
bn->uninstallCPT(ev);
bn->installCPT(ev, new Potential< GUM_SCALAR >(*this->evidence()[ev]));
bn->installCPT(ev, *(this->evidence()[ev]));
// we keep the variables with hard evidence but alone
// bn->uninstallNode( sid[i] );
}
......
......@@ -146,9 +146,7 @@ namespace gum {
I.chgVal(this->BN().variable(hard), this->hardEvidence()[hard]);
for (const auto& child : this->BN().children(hard)) {
auto p = new gum::Potential< GUM_SCALAR >();
*p = this->BN().cpt(child).extract(I);
__samplingBN->installCPT(child, p);
__samplingBN->installCPT(child, this->BN().cpt(child).extract(I));
}
}
......
......@@ -490,9 +490,9 @@ namespace gum_tests {
gum::BayesNetFragment< double > frag(bn);
frag.installAscendants(bn.idFromName("v6")); // 1->3->6
gum::Potential< double >* newV3 = new gum::Potential< double >();
(*newV3) << bn.variable(bn.idFromName("v3"));
newV3->fillWith({0.0, 1.0});
gum::Potential< double > newV3;
newV3 << bn.variable(bn.idFromName("v3"));
newV3.fillWith({0.0, 1.0});
frag.installMarginal(frag.idFromName("v3"), newV3); // 1 3->6
TS_ASSERT_EQUALS(frag.size(), (gum::Size)3);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)1);
......@@ -533,9 +533,9 @@ namespace gum_tests {
TS_ASSERT(!frag.checkConsistency());
gum::Potential< double >* newV5 = new gum::Potential< double >();
(*newV5) << bn.variable(bn.idFromName("v5"));
newV5->fillWith({0.0, 0.0, 1.0});
gum::Potential< double > newV5;
newV5 << bn.variable(bn.idFromName("v5"));
newV5.fillWith({0.0, 0.0, 1.0});
frag.installMarginal(frag.idFromName("v5"), newV5); // 1-->3-->6 5
TS_ASSERT(frag.checkConsistency());
TS_ASSERT_EQUALS(frag.size(), (gum::Size)4);
......@@ -559,10 +559,10 @@ namespace gum_tests {
TS_ASSERT_EQUALS(frag.size(), (gum::Size)5);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)4);
gum::Potential< double >* newV5bis = new gum::Potential< double >();
(*newV5bis) << bn.variable(bn.idFromName("v5"))
<< bn.variable(bn.idFromName("v2"))
<< bn.variable(bn.idFromName("v3"));
gum::Potential< double > newV5bis;
newV5bis << bn.variable(bn.idFromName("v5"))
<< bn.variable(bn.idFromName("v2"))
<< bn.variable(bn.idFromName("v3"));
frag.installCPT(frag.idFromName("v5"), newV5bis);
TS_ASSERT(frag.checkConsistency());
TS_ASSERT_EQUALS(frag.size(), (gum::Size)5);
......@@ -584,17 +584,16 @@ namespace gum_tests {
TS_ASSERT_EQUALS(frag.size(), (gum::Size)6);
TS_ASSERT_EQUALS(frag.sizeArcs(), (gum::Size)7);
gum::Potential< double >* newV5 = new gum::Potential< double >();
(*newV5) << bn.variable(bn.idFromName("v5"))
<< bn.variable(bn.idFromName("v2"))
<< bn.variable(bn.idFromName("v3"));
gum::Potential< double > newV5;
newV5 << bn.variable(bn.idFromName("v5")) << bn.variable(bn.idFromName("v2"))
<< bn.variable(bn.idFromName("v3"));
const gum::Potential< double >& pot2 = bn2.cpt(bn2.idFromName("v5"));
gum::Instantiation I(pot2);
gum::Instantiation J(*newV5);
gum::Instantiation J(newV5);
for (I.setFirst(), J.setFirst(); !I.end(); ++I, ++J)
newV5->set(J, pot2[I]);
newV5.set(J, pot2[I]);
frag.installCPT(frag.idFromName("v5"), newV5);
TS_ASSERT(frag.checkConsistency());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment