[aGrUM] new method Potential.fillWith(pot,map)

parent 859f886e
......@@ -225,8 +225,7 @@ namespace gum {
/**
* @brief copy a Potential data using name of variables and labels (not
* necessarily
* the same variables in the same orders)
* necessarily the same variables in the same orders)
*
* @warning a strict control on names of variables and labels are made
*
......@@ -235,6 +234,26 @@ namespace gum {
const Potential< GUM_SCALAR >&
fillWith(const Potential< GUM_SCALAR >& src) const;
/**
* @brief copy a Potential data using the sequence of names in mapSrc to find
* the corresponding variables.
*
* For instance, to copy the potential P(A,B,C) in Q(D,E,A) with the mapping
* P.A<->Q.E, P.B<->Q.A, P.C<->Q.D (assuming that the corresponding variables
* have the same domain size and the order of labels):
*
* @code
* Q.fillWith(P,{"C","A","B"});
* @endcode
*
* @warning a strict control on names of variables and labels are made
*
* @throw InvalidArgument if the Potential is not compatible with this
* */
const Potential< GUM_SCALAR >&
fillWith(const Potential< GUM_SCALAR >& src,
const std::vector< std::string >& mapSrc) const;
/**
* @brief Automatically fills the potential with the values in
* v.
......
......@@ -235,12 +235,29 @@ namespace gum {
this->fill(v);
return *this;
}
template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >&
Potential< GUM_SCALAR >::fillWith(const Potential< GUM_SCALAR >& src) const {
if (src.domainSize() != this->domainSize()) {
GUM_ERROR(InvalidArgument, "Potential to copy has not the same dimension.");
}
gum::Set< std::string > son; // set of names
for (const auto& v : src.variablesSequence()) {
son.insert(v->name());
}
for (const auto& v : this->variablesSequence()) {
if (!son.contains(v->name())) {
GUM_ERROR(InvalidArgument,
"Variable <" << v->name() << "> not present in src.");
}
// we check size, labels and order of labels in the same time
if (v->toString() != src.variable(v->name()).toString()) {
GUM_ERROR(InvalidArgument,
"Variables <" << v->name() << "> are not identical.");
}
}
Instantiation Isrc(src);
Instantiation Idst(*this);
for (Isrc.setFirst(); !Isrc.end(); ++Isrc) {
......@@ -253,6 +270,36 @@ namespace gum {
return *this;
}
template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >& Potential< GUM_SCALAR >::fillWith(
const Potential< GUM_SCALAR >& src,
const std::vector< std::string >& mapSrc) const {
if (src.nbrDim() != this->nbrDim()) {
GUM_ERROR(InvalidArgument, "Potential to copy has not the same dimension.");
}
if (src.nbrDim() != mapSrc.size()) {
GUM_ERROR(InvalidArgument,
"Potential and vector have not the same dimension.");
}
Instantiation Isrc;
for (Idx i = 0; i < src.nbrDim(); i++) {
if (src.variable(mapSrc[i]).domainSize() != this->variable(i).domainSize()) {
GUM_ERROR(InvalidArgument,
"Variables " << mapSrc[i] << " (in the argument) and "
<< this->variable(i).name()
<< " have not the same dimension.");
} else {
Isrc.add(src.variable(mapSrc[i]));
}
}
Instantiation Idst(*this);
for (Isrc.setFirst(); !Isrc.end(); ++Isrc, ++Idst) {
this->set(Idst, src.get(Isrc));
}
return *this;
}
template < typename GUM_SCALAR >
INLINE const Potential< GUM_SCALAR >& Potential< GUM_SCALAR >::sq() const {
this->apply([](GUM_SCALAR x) { return x * x; });
......
......@@ -24,8 +24,8 @@
#include <agrum/variables/labelizedVariable.h>
#include <agrum/multidim/ICIModels/multiDimLogit.h>
#include <agrum/multidim/instantiation.h>
#include <agrum/multidim/implementations/multiDimArray.h>
#include <agrum/multidim/instantiation.h>
#include <agrum/multidim/potential.h>
namespace gum_tests {
......@@ -501,6 +501,7 @@ namespace gum_tests {
r.fillWith({3, 6, 9, 12, 15, 18, 21, 24, 27});
TS_ASSERT(pot.reorganize({&b, &c, &a}).extract(I) == r);
}
void testOperatorEqual() {
auto a = gum::LabelizedVariable("a", "afoo", 3);
auto b = gum::LabelizedVariable("b", "bfoo", 3);
......@@ -961,7 +962,6 @@ namespace gum_tests {
pp.add(ww);
pp.add(vv);
TS_ASSERT_EQUALS(p.domainSize(), gum::Size(6));
TS_ASSERT_EQUALS(pp.domainSize(), gum::Size(6));
......@@ -989,14 +989,38 @@ namespace gum_tests {
gum::Potential< int > bad_p2;
bad_p2.add(vvv);
bad_p2.add(www);
TS_ASSERT_THROWS(bad_p2.fillWith(p), gum::OutOfBounds);
TS_ASSERT_THROWS(bad_p2.fillWith(p), gum::InvalidArgument);
gum::Potential< int > bad_p3;
bad_p3.add(w);
bad_p3.add(z);
// TS_GUM_ASSERT_THROWS_NOTHING(bad_p3.fillWith(p));
TS_ASSERT_THROWS(bad_p3.fillWith(p), gum::NotFound);
TS_ASSERT_THROWS(bad_p3.fillWith(p), gum::InvalidArgument);
gum::Potential< int > bad_p4;
gum::LabelizedVariable badv("v", "v", 0);
badv.addLabel("3").addLabel("1");
bad_p4.add(w);
bad_p4.add(badv);
TS_ASSERT_THROWS(bad_p4.fillWith(p), gum::InvalidArgument);
}
void testFillWithPotentialAndMapMethod() {
gum::LabelizedVariable v("v", "v", 2), w("w", "w", 3);
gum::Potential< int > p;
p.add(v);
p.add(w);
gum::LabelizedVariable vv("vv", "vv", 2), ww("ww", "ww", 3);
gum::Potential< int > pp;
pp.add(ww);
pp.add(vv);
TS_ASSERT_EQUALS(p.domainSize(), gum::Size(6));
TS_ASSERT_EQUALS(pp.domainSize(), gum::Size(6));
p.fillWith({1, 2, 3, 4, 5, 6});
TS_GUM_ASSERT_THROWS_NOTHING(pp.fillWith(p, {"w", "v"}));
TS_ASSERT_THROWS(pp.fillWith(p, {"v", "w"}), gum::InvalidArgument);
}
};
}
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -556,8 +556,23 @@ def _reprPotential(pot, digits=4, withColors=True, varnames=None, asString=False
else:
return HTML("".join(html))
def __isKindOfProba(pot):
"""
check if pot is a joint proba or a CPT
:param pot: the potential
:return: True or False
"""
if abs(pot.sum()-1)<1e-2:
return True
q=pot.margSumOut([pot.variable(0).name()])
if abs(q.max()-1)>1e-2:
return False
if abs(q.min()-1)>1e-2:
return False
return True
def showPotential(pot, digits=4, withColors=True, varnames=None):
def showPotential(pot, digits=4, withColors=None, varnames=None):
"""
show a gum.Potential as a HTML table.
The first dimension is special (horizontal) due to the representation of conditional probability table
......@@ -568,10 +583,13 @@ def showPotential(pot, digits=4, withColors=True, varnames=None):
:param list of strings varnames: the aliases for variables name in the table
:return: the display of the potential
"""
if withColors is None:
withColors=__isKindOfProba
display(_reprPotential(pot, digits, withColors, varnames, asString=False))
def getPotential(pot, digits=4, withColors=True, varnames=None):
def getPotential(pot, digits=4, withColors=None, varnames=None):
"""
return a HTML string of a gum.Potential as a HTML table.
The first dimension is special (horizontal) due to the representation of conditional probability table
......@@ -582,6 +600,9 @@ def getPotential(pot, digits=4, withColors=True, varnames=None):
:param list of strings varnames: the aliases for variables name in the table
:return: the HTML string
"""
if withColors is None:
withColors=__isKindOfProba(pot)
return _reprPotential(pot, digits, withColors, varnames, asString=True)
......
......@@ -30,7 +30,6 @@ CHANGE_THEN_RETURN_SELF(sq)
CHANGE_THEN_RETURN_SELF(scale)
CHANGE_THEN_RETURN_SELF(translate)
CHANGE_THEN_RETURN_SELF(normalizeAsCPT)
CHANGE_THEN_RETURN_SELF(fillWith)
CHANGE_THEN_RETURN_SELF(set)
%rename ("$ignore", fullname=1) gum::Potential<double>::margSumOut(const Set<const DiscreteVariable*>& del_vars) const;
......
......@@ -32,7 +32,7 @@ class BNDatabaseGeneratorTestCase(pyAgrumTestCase):
with self.assertRaises(gum.FatalError):
dbgen.setVarOrder(["A", "O", "R", "S", "T"])
with self.assertRaises(IndexError):
with self.assertRaises(gum.NotFound):
dbgen.setVarOrder(["A", "O", "R", "S", "T", "X"])
def testDrawSamples(self):
......
......@@ -655,7 +655,7 @@ class TestOperators(pyAgrumTestCase):
self.assertNotEqual(p.variable(1), p.variable('v'))
self.assertNotEqual(p.variable(0), p.variable('w'))
with self.assertRaises(IndexError):
with self.assertRaises(gum.NotFound):
x = p.variable("zz")
def testFillWithPotential(self):
......@@ -673,6 +673,24 @@ class TestOperators(pyAgrumTestCase):
self.assertAlmostEquals(np.max(pABC.reorganize(['A', 'B', 'C']).toarray() -
pABC2.reorganize(['A', 'B', 'C']).toarray()), 0)
def testFillWithPotentialAndMap(self):
v = gum.LabelizedVariable("v", "v", 2)
w = gum.LabelizedVariable("w", "w", 3)
p = gum.Potential().add(v).add(w)
p.fillWith([1, 2, 3, 4, 5, 6])
vv = gum.LabelizedVariable("vv", "vv", 2)
ww = gum.LabelizedVariable("ww", "ww", 3)
pp = gum.Potential().add(ww).add(vv)
pp.fillWith(p, ["w", "v"])
self.assertAlmostEquals(np.max(p.reorganize(['v', 'w']).toarray() -
pp.reorganize(['vv', 'ww']).toarray()), 0)
vvv = gum.LabelizedVariable("vvv", "vvv", 2)
www = gum.LabelizedVariable("www", "www", 2)
ppp = gum.Potential().add(vvv).add(www)
with self.assertRaises(gum.InvalidArgument):
ppp.fillWith(p, ["w", "v"])
ts = unittest.TestSuite()
addTests(ts, TestInsertions)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment