[aGrUM] new method Potential.fillWith(pot,map)

parent 859f886e
...@@ -225,8 +225,7 @@ namespace gum { ...@@ -225,8 +225,7 @@ namespace gum {
/** /**
* @brief copy a Potential data using name of variables and labels (not * @brief copy a Potential data using name of variables and labels (not
* necessarily * necessarily the same variables in the same orders)
* the same variables in the same orders)
* *
* @warning a strict control on names of variables and labels are made * @warning a strict control on names of variables and labels are made
* *
...@@ -235,6 +234,26 @@ namespace gum { ...@@ -235,6 +234,26 @@ namespace gum {
const Potential< GUM_SCALAR >& const Potential< GUM_SCALAR >&
fillWith(const Potential< GUM_SCALAR >& src) const; fillWith(const Potential< GUM_SCALAR >& src) const;
/**
* @brief copy a Potential data using the sequence of names in mapSrc to find
* the corresponding variables.
*
* For instance, to copy the potential P(A,B,C) in Q(D,E,A) with the mapping
* P.A<->Q.E, P.B<->Q.A, P.C<->Q.D (assuming that the corresponding variables
* have the same domain size and the order of labels):
*
* @code
* Q.fillWith(P,{"C","A","B"});
* @endcode
*
* @warning a strict control on names of variables and labels are made
*
* @throw InvalidArgument if the Potential is not compatible with this
* */
const Potential< GUM_SCALAR >&
fillWith(const Potential< GUM_SCALAR >& src,
const std::vector< std::string >& mapSrc) const;
/** /**
* @brief Automatically fills the potential with the values in * @brief Automatically fills the potential with the values in
* v. * v.
......
...@@ -235,12 +235,29 @@ namespace gum { ...@@ -235,12 +235,29 @@ namespace gum {
this->fill(v); this->fill(v);
return *this; return *this;
} }
template < typename GUM_SCALAR > template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >& INLINE const Potential< GUM_SCALAR >&
Potential< GUM_SCALAR >::fillWith(const Potential< GUM_SCALAR >& src) const { Potential< GUM_SCALAR >::fillWith(const Potential< GUM_SCALAR >& src) const {
if (src.domainSize() != this->domainSize()) { if (src.domainSize() != this->domainSize()) {
GUM_ERROR(InvalidArgument, "Potential to copy has not the same dimension."); GUM_ERROR(InvalidArgument, "Potential to copy has not the same dimension.");
} }
gum::Set< std::string > son; // set of names
for (const auto& v : src.variablesSequence()) {
son.insert(v->name());
}
for (const auto& v : this->variablesSequence()) {
if (!son.contains(v->name())) {
GUM_ERROR(InvalidArgument,
"Variable <" << v->name() << "> not present in src.");
}
// we check size, labels and order of labels in the same time
if (v->toString() != src.variable(v->name()).toString()) {
GUM_ERROR(InvalidArgument,
"Variables <" << v->name() << "> are not identical.");
}
}
Instantiation Isrc(src); Instantiation Isrc(src);
Instantiation Idst(*this); Instantiation Idst(*this);
for (Isrc.setFirst(); !Isrc.end(); ++Isrc) { for (Isrc.setFirst(); !Isrc.end(); ++Isrc) {
...@@ -253,6 +270,36 @@ namespace gum { ...@@ -253,6 +270,36 @@ namespace gum {
return *this; return *this;
} }
template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >& Potential< GUM_SCALAR >::fillWith(
const Potential< GUM_SCALAR >& src,
const std::vector< std::string >& mapSrc) const {
if (src.nbrDim() != this->nbrDim()) {
GUM_ERROR(InvalidArgument, "Potential to copy has not the same dimension.");
}
if (src.nbrDim() != mapSrc.size()) {
GUM_ERROR(InvalidArgument,
"Potential and vector have not the same dimension.");
}
Instantiation Isrc;
for (Idx i = 0; i < src.nbrDim(); i++) {
if (src.variable(mapSrc[i]).domainSize() != this->variable(i).domainSize()) {
GUM_ERROR(InvalidArgument,
"Variables " << mapSrc[i] << " (in the argument) and "
<< this->variable(i).name()
<< " have not the same dimension.");
} else {
Isrc.add(src.variable(mapSrc[i]));
}
}
Instantiation Idst(*this);
for (Isrc.setFirst(); !Isrc.end(); ++Isrc, ++Idst) {
this->set(Idst, src.get(Isrc));
}
return *this;
}
template < typename GUM_SCALAR > template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >& Potential< GUM_SCALAR >::sq() const { INLINE const Potential< GUM_SCALAR >& Potential< GUM_SCALAR >::sq() const {
this->apply([](GUM_SCALAR x) { return x * x; }); this->apply([](GUM_SCALAR x) { return x * x; });
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include <agrum/variables/labelizedVariable.h> #include <agrum/variables/labelizedVariable.h>
#include <agrum/multidim/ICIModels/multiDimLogit.h> #include <agrum/multidim/ICIModels/multiDimLogit.h>
#include <agrum/multidim/instantiation.h>
#include <agrum/multidim/implementations/multiDimArray.h> #include <agrum/multidim/implementations/multiDimArray.h>
#include <agrum/multidim/instantiation.h>
#include <agrum/multidim/potential.h> #include <agrum/multidim/potential.h>
namespace gum_tests { namespace gum_tests {
...@@ -501,6 +501,7 @@ namespace gum_tests { ...@@ -501,6 +501,7 @@ namespace gum_tests {
r.fillWith({3, 6, 9, 12, 15, 18, 21, 24, 27}); r.fillWith({3, 6, 9, 12, 15, 18, 21, 24, 27});
TS_ASSERT(pot.reorganize({&b, &c, &a}).extract(I) == r); TS_ASSERT(pot.reorganize({&b, &c, &a}).extract(I) == r);
} }
void testOperatorEqual() { void testOperatorEqual() {
auto a = gum::LabelizedVariable("a", "afoo", 3); auto a = gum::LabelizedVariable("a", "afoo", 3);
auto b = gum::LabelizedVariable("b", "bfoo", 3); auto b = gum::LabelizedVariable("b", "bfoo", 3);
...@@ -961,7 +962,6 @@ namespace gum_tests { ...@@ -961,7 +962,6 @@ namespace gum_tests {
pp.add(ww); pp.add(ww);
pp.add(vv); pp.add(vv);
TS_ASSERT_EQUALS(p.domainSize(), gum::Size(6)); TS_ASSERT_EQUALS(p.domainSize(), gum::Size(6));
TS_ASSERT_EQUALS(pp.domainSize(), gum::Size(6)); TS_ASSERT_EQUALS(pp.domainSize(), gum::Size(6));
...@@ -989,14 +989,38 @@ namespace gum_tests { ...@@ -989,14 +989,38 @@ namespace gum_tests {
gum::Potential< int > bad_p2; gum::Potential< int > bad_p2;
bad_p2.add(vvv); bad_p2.add(vvv);
bad_p2.add(www); bad_p2.add(www);
TS_ASSERT_THROWS(bad_p2.fillWith(p), gum::OutOfBounds); TS_ASSERT_THROWS(bad_p2.fillWith(p), gum::InvalidArgument);
gum::Potential< int > bad_p3; gum::Potential< int > bad_p3;
bad_p3.add(w); bad_p3.add(w);
bad_p3.add(z); bad_p3.add(z);
// TS_GUM_ASSERT_THROWS_NOTHING(bad_p3.fillWith(p)); TS_ASSERT_THROWS(bad_p3.fillWith(p), gum::InvalidArgument);
TS_ASSERT_THROWS(bad_p3.fillWith(p), gum::NotFound);
gum::Potential< int > bad_p4;
gum::LabelizedVariable badv("v", "v", 0);
badv.addLabel("3").addLabel("1");
bad_p4.add(w);
bad_p4.add(badv);
TS_ASSERT_THROWS(bad_p4.fillWith(p), gum::InvalidArgument);
}
void testFillWithPotentialAndMapMethod() {
gum::LabelizedVariable v("v", "v", 2), w("w", "w", 3);
gum::Potential< int > p;
p.add(v);
p.add(w);
gum::LabelizedVariable vv("vv", "vv", 2), ww("ww", "ww", 3);
gum::Potential< int > pp;
pp.add(ww);
pp.add(vv);
TS_ASSERT_EQUALS(p.domainSize(), gum::Size(6));
TS_ASSERT_EQUALS(pp.domainSize(), gum::Size(6));
p.fillWith({1, 2, 3, 4, 5, 6});
TS_GUM_ASSERT_THROWS_NOTHING(pp.fillWith(p, {"w", "v"}));
TS_ASSERT_THROWS(pp.fillWith(p, {"v", "w"}), gum::InvalidArgument);
} }
}; };
} }
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -556,8 +556,23 @@ def _reprPotential(pot, digits=4, withColors=True, varnames=None, asString=False ...@@ -556,8 +556,23 @@ def _reprPotential(pot, digits=4, withColors=True, varnames=None, asString=False
else: else:
return HTML("".join(html)) return HTML("".join(html))
def __isKindOfProba(pot):
"""
check if pot is a joint proba or a CPT
:param pot: the potential
:return: True or False
"""
if abs(pot.sum()-1)<1e-2:
return True
q=pot.margSumOut([pot.variable(0).name()])
if abs(q.max()-1)>1e-2:
return False
if abs(q.min()-1)>1e-2:
return False
return True
def showPotential(pot, digits=4, withColors=True, varnames=None): def showPotential(pot, digits=4, withColors=None, varnames=None):
""" """
show a gum.Potential as a HTML table. show a gum.Potential as a HTML table.
The first dimension is special (horizontal) due to the representation of conditional probability table The first dimension is special (horizontal) due to the representation of conditional probability table
...@@ -568,10 +583,13 @@ def showPotential(pot, digits=4, withColors=True, varnames=None): ...@@ -568,10 +583,13 @@ def showPotential(pot, digits=4, withColors=True, varnames=None):
:param list of strings varnames: the aliases for variables name in the table :param list of strings varnames: the aliases for variables name in the table
:return: the display of the potential :return: the display of the potential
""" """
if withColors is None:
withColors=__isKindOfProba
display(_reprPotential(pot, digits, withColors, varnames, asString=False)) display(_reprPotential(pot, digits, withColors, varnames, asString=False))
def getPotential(pot, digits=4, withColors=True, varnames=None): def getPotential(pot, digits=4, withColors=None, varnames=None):
""" """
return a HTML string of a gum.Potential as a HTML table. return a HTML string of a gum.Potential as a HTML table.
The first dimension is special (horizontal) due to the representation of conditional probability table The first dimension is special (horizontal) due to the representation of conditional probability table
...@@ -582,6 +600,9 @@ def getPotential(pot, digits=4, withColors=True, varnames=None): ...@@ -582,6 +600,9 @@ def getPotential(pot, digits=4, withColors=True, varnames=None):
:param list of strings varnames: the aliases for variables name in the table :param list of strings varnames: the aliases for variables name in the table
:return: the HTML string :return: the HTML string
""" """
if withColors is None:
withColors=__isKindOfProba(pot)
return _reprPotential(pot, digits, withColors, varnames, asString=True) return _reprPotential(pot, digits, withColors, varnames, asString=True)
......
...@@ -30,7 +30,6 @@ CHANGE_THEN_RETURN_SELF(sq) ...@@ -30,7 +30,6 @@ CHANGE_THEN_RETURN_SELF(sq)
CHANGE_THEN_RETURN_SELF(scale) CHANGE_THEN_RETURN_SELF(scale)
CHANGE_THEN_RETURN_SELF(translate) CHANGE_THEN_RETURN_SELF(translate)
CHANGE_THEN_RETURN_SELF(normalizeAsCPT) CHANGE_THEN_RETURN_SELF(normalizeAsCPT)
CHANGE_THEN_RETURN_SELF(fillWith)
CHANGE_THEN_RETURN_SELF(set) CHANGE_THEN_RETURN_SELF(set)
%rename ("$ignore", fullname=1) gum::Potential<double>::margSumOut(const Set<const DiscreteVariable*>& del_vars) const; %rename ("$ignore", fullname=1) gum::Potential<double>::margSumOut(const Set<const DiscreteVariable*>& del_vars) const;
......
...@@ -32,7 +32,7 @@ class BNDatabaseGeneratorTestCase(pyAgrumTestCase): ...@@ -32,7 +32,7 @@ class BNDatabaseGeneratorTestCase(pyAgrumTestCase):
with self.assertRaises(gum.FatalError): with self.assertRaises(gum.FatalError):
dbgen.setVarOrder(["A", "O", "R", "S", "T"]) dbgen.setVarOrder(["A", "O", "R", "S", "T"])
with self.assertRaises(IndexError): with self.assertRaises(gum.NotFound):
dbgen.setVarOrder(["A", "O", "R", "S", "T", "X"]) dbgen.setVarOrder(["A", "O", "R", "S", "T", "X"])
def testDrawSamples(self): def testDrawSamples(self):
......
...@@ -655,7 +655,7 @@ class TestOperators(pyAgrumTestCase): ...@@ -655,7 +655,7 @@ class TestOperators(pyAgrumTestCase):
self.assertNotEqual(p.variable(1), p.variable('v')) self.assertNotEqual(p.variable(1), p.variable('v'))
self.assertNotEqual(p.variable(0), p.variable('w')) self.assertNotEqual(p.variable(0), p.variable('w'))
with self.assertRaises(IndexError): with self.assertRaises(gum.NotFound):
x = p.variable("zz") x = p.variable("zz")
def testFillWithPotential(self): def testFillWithPotential(self):
...@@ -673,6 +673,24 @@ class TestOperators(pyAgrumTestCase): ...@@ -673,6 +673,24 @@ class TestOperators(pyAgrumTestCase):
self.assertAlmostEquals(np.max(pABC.reorganize(['A', 'B', 'C']).toarray() - self.assertAlmostEquals(np.max(pABC.reorganize(['A', 'B', 'C']).toarray() -
pABC2.reorganize(['A', 'B', 'C']).toarray()), 0) pABC2.reorganize(['A', 'B', 'C']).toarray()), 0)
def testFillWithPotentialAndMap(self):
v = gum.LabelizedVariable("v", "v", 2)
w = gum.LabelizedVariable("w", "w", 3)
p = gum.Potential().add(v).add(w)
p.fillWith([1, 2, 3, 4, 5, 6])
vv = gum.LabelizedVariable("vv", "vv", 2)
ww = gum.LabelizedVariable("ww", "ww", 3)
pp = gum.Potential().add(ww).add(vv)
pp.fillWith(p, ["w", "v"])
self.assertAlmostEquals(np.max(p.reorganize(['v', 'w']).toarray() -
pp.reorganize(['vv', 'ww']).toarray()), 0)
vvv = gum.LabelizedVariable("vvv", "vvv", 2)
www = gum.LabelizedVariable("www", "www", 2)
ppp = gum.Potential().add(vvv).add(www)
with self.assertRaises(gum.InvalidArgument):
ppp.fillWith(p, ["w", "v"])
ts = unittest.TestSuite() ts = unittest.TestSuite()
addTests(ts, TestInsertions) addTests(ts, TestInsertions)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment