[pyAgrum] add a saveBN for o3prm file

parent 70e2357b
......@@ -185,7 +185,7 @@ IMPROVE_BAYESNET_API(gum::BayesNet);
writer.write( name, *self );
};
std::string loadPRM(std::string name, std::string system="",std::string classpath="",PyObject *l=(PyObject*)0)
std::string loadO3PRM(std::string name, std::string system="",std::string classpath="",PyObject *l=(PyObject*)0)
{
std::stringstream stream;
std::vector<PythonLoadListener> py_listener;
......@@ -207,6 +207,11 @@ IMPROVE_BAYESNET_API(gum::BayesNet);
return "";
};
void saveO3PRM(std::string name) {
gum::O3prmBNWriter<GUM_SCALAR> writer;
writer.write( name, *self );
};
std::string loadBIFXML(std::string name, PyObject *l=(PyObject*)0)
{
std::stringstream stream;
......
......@@ -49,7 +49,7 @@ def availableWriteBNExts():
"""
:return: a string which lists all suffixes for supported output BN file formats.
"""
return "bif|dsl|net|bifxml|uai"
return "bif|dsl|net|bifxml|o3prm|uai"
def loadBN(filename,listeners=None,verbose=False,**opts):
"""
......@@ -72,7 +72,7 @@ def loadBN(filename,listeners=None,verbose=False,**opts):
elif extension=="NET":
warns=bn.loadNET(filename,listeners)
elif extension=="O3PRM":
warns=bn.loadPRM(filename,opts.get('system',''),opts.get('classpath',''),listeners)
warns=bn.loadO3PRM(filename,opts.get('system',''),opts.get('classpath',''),listeners)
elif extension=="UAI":
warns=bn.loadUAI(filename,listeners)
else:
......@@ -99,6 +99,8 @@ def saveBN(bn,filename):
bn.saveNET(filename)
elif extension=="UAI":
bn.saveUAI(filename)
elif extension=="O3PRM":
bn.saveO3PRM(filename)
else:
raise Exception("extension "+filename.split('.')[-1]+" unknown. Please use "+availableWriteBNExts())
......
......@@ -9,7 +9,8 @@ interface generator."
%include "docs.i"
#pragma SWIG nowarn=341,342 // The 'using' keyword in type aliasing is not fully supported yet.
#pragma SWIG nowarn=320 // Explicit template instantiation ignored.
#pragma SWIG nowarn=320 // Explicit template instantiation ignored.
%begin %{
#include <cmath>
......@@ -31,8 +32,9 @@ interface generator."
//////////////////////////////////////////////////////////////////
/* declaration of code modifiers for 'pythonification' of aGrUM */
//////////////////////////////////////////////////////////////////
%include "pythonize.i"
%include "exceptions.i"
%include "pythonize.i"
//////////////////////////////////////////////////////////////////
......@@ -97,3 +99,4 @@ InfluenceDiagramInference = InfluenceDiagramInference_double
BNLearner = BNLearner_double
%}
// Explicit template instantiation ignored.
# -*- encoding: UTF-8 -*-
import unittest
import pyAgrum as gum
import numpy
import pyAgrum as gum
from pyAgrumTestSuite import pyAgrumTestCase, addTests
......@@ -223,28 +223,28 @@ class TestFeatures(BayesNetTestCase):
self.assertEquals(bn.maxNonOneParam(), 0.9)
def test_fastBN(self):
bn=gum.fastBN("a->b->c;a->c" )
self.assertEquals(bn.size(),3)
self.assertEquals(bn.sizeArcs(),3)
self.assertEquals(bn.dim(),( 2 - 1 ) + ( 2 * ( 2 - 1 ) ) + ( 2 * 2 * ( 2 - 1 ) ))
bn=gum.fastBN("a->b->c;a->c",3 )
self.assertEquals(bn.size(),3)
self.assertEquals(bn.sizeArcs(),3)
self.assertEquals(bn.dim(),( 3 - 1 ) + ( 3 * ( 3 - 1 ) ) + ( 3 * 3 * ( 3 - 1 ) ))
bn=gum.fastBN("a->b[5]->c;a->c" )
self.assertEquals(bn.size(),3)
self.assertEquals(bn.sizeArcs(),3)
self.assertEquals(bn.dim(),( 2 - 1 ) + ( 2 * ( 5 - 1 ) ) + ( 2 * 5 * ( 2 - 1 ) ))
bn=gum.fastBN("a->b->c;a[1000]->c" )
self.assertEquals(bn.size(),3)
self.assertEquals(bn.sizeArcs(),3)
self.assertEquals(bn.dim(),( 2 - 1 ) + ( 2 * ( 2 - 1 ) ) + ( 2 * 2 * ( 2 - 1 ) ))
bn = gum.fastBN("a->b->c;a->c")
self.assertEquals(bn.size(), 3)
self.assertEquals(bn.sizeArcs(), 3)
self.assertEquals(bn.dim(), (2 - 1) + (2 * (2 - 1)) + (2 * 2 * (2 - 1)))
bn = gum.fastBN("a->b->c;a->c", 3)
self.assertEquals(bn.size(), 3)
self.assertEquals(bn.sizeArcs(), 3)
self.assertEquals(bn.dim(), (3 - 1) + (3 * (3 - 1)) + (3 * 3 * (3 - 1)))
bn = gum.fastBN("a->b[5]->c;a->c")
self.assertEquals(bn.size(), 3)
self.assertEquals(bn.sizeArcs(), 3)
self.assertEquals(bn.dim(), (2 - 1) + (2 * (5 - 1)) + (2 * 5 * (2 - 1)))
bn = gum.fastBN("a->b->c;a[1000]->c")
self.assertEquals(bn.size(), 3)
self.assertEquals(bn.sizeArcs(), 3)
self.assertEquals(bn.dim(), (2 - 1) + (2 * (2 - 1)) + (2 * 2 * (2 - 1)))
with self.assertRaises(gum.InvalidDirectedCycle):
bn=gum.fastBN("a->b->c->a" )
bn = gum.fastBN("a->b->c->a")
def test_minimalCondSet(self):
bn = gum.fastBN("A->C->E->F->G;B->C;B->D->F;H->E")
......@@ -259,7 +259,7 @@ class TestFeatures(BayesNetTestCase):
self.assertEqual(r, set([iB, iC]))
r = bn.minimalCondSet(iA, set([iE, iF, iG]))
self.assertEqual(r, set([iE,iF]))
self.assertEqual(r, set([iE, iF]))
r = bn.minimalCondSet(iA, set([iB, iC, iE, iF, iG]))
self.assertEqual(r, set([iB, iC]))
......@@ -282,18 +282,19 @@ class TestFeatures(BayesNetTestCase):
r = bn.minimalCondSet(iC, set([iC, iE, iF, iG]))
self.assertEqual(r, set([iC]))
#for set of targets
# for set of targets
tous = set([iA, iB, iC, iD, iE, iF, iG, iH])
r = bn.minimalCondSet([iE,iD], tous)
self.assertEqual(r, set([iE,iD]))
r = bn.minimalCondSet([iE, iD], tous)
self.assertEqual(r, set([iE, iD]))
r = bn.minimalCondSet([iE,iD], tous - set([iE]))
r = bn.minimalCondSet([iE, iD], tous - set([iE]))
self.assertEqual(r, set([iC, iD, iH, iF]))
r = bn.minimalCondSet([iE,iD], tous - set([iE,iD]))
r = bn.minimalCondSet([iE, iD], tous - set([iE, iD]))
self.assertEqual(r, set([iB, iC, iH, iF]))
class TestLoadBN(BayesNetTestCase):
def listen(self, percent):
if not percent > 100:
......@@ -425,8 +426,45 @@ class TestLoadBN(BayesNetTestCase):
self.assertEqual(bn.size(), 5)
class TestSaveBN(BayesNetTestCase):
def testReadAfterWrite(self):
bn = gum.BayesNet()
bn.add(gum.RangeVariable("1", "", 0, 1))
bn.add(gum.DiscretizedVariable("2", "").addTick(0.0).addTick(0.5).addTick(1.0))
bn.add(gum.LabelizedVariable("3", "", 2))
bn.add(gum.LabelizedVariable("4", "", 2))
bn.add(gum.LabelizedVariable("5", "", 3))
bn.addArc("1", "3")
bn.addArc("1", "4")
bn.addArc("3", "5")
bn.addArc("4", "5")
bn.addArc("2", "4")
bn.addArc("2", "5")
bn.cpt("1").fillWith([0.2, 0.8])
bn.cpt("2").fillWith([0.3, 0.7])
bn.cpt("3").fillWith([0.1, 0.9, 0.9, 0.1])
bn.cpt("4").fillWith([0.4, 0.6, 0.5, 0.5, 0.5, 0.5, 1.0, 0.0])
bn.cpt("5").fillWith([0.3, 0.6, 0.1, 0.5, 0.5, 0.0, 0.5, 0.5,
0.0, 1.0, 0.0, 0.0, 0.4, 0.6, 0.0, 0.5,
0.5, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 1.0])
gum.saveBN(bn, self.agrumSrcDir("src/testunits/ressources/o3prm/BNO3PRMIO_file.o3prm"))
bn2=gum.loadBN(self.agrumSrcDir("src/testunits/ressources/o3prm/BNO3PRMIO_file.o3prm"),system="BayesNet")
self.assertEquals(bn.dim(), bn2.dim())
self.assertEquals(bn.log10DomainSize(), bn2.log10DomainSize())
for n in bn.names():
self.assertEquals(bn.variable(n).name(),bn2.variable(n).name())
self.assertEquals(bn.variable(n).varType(),bn2.variable(n).varType())
self.assertEquals(bn.variable(n).domainSize(),bn2.variable(n).domainSize())
ts = unittest.TestSuite()
addTests(ts, TestConstructors)
addTests(ts,TestInsertions)
addTests(ts,TestFeatures)
addTests(ts, TestInsertions)
addTests(ts, TestFeatures)
addTests(ts, TestLoadBN)
addTests(ts, TestSaveBN)
......@@ -73,6 +73,7 @@
#include <agrum/BN/io/BIFXML/BIFXMLBNReader.h>
#include <agrum/BN/io/BIFXML/BIFXMLBNWriter.h>
#include <agrum/PRM/o3prm/O3prmBNReader.h>
#include <agrum/PRM/o3prm/O3prmBNWriter.h>
#include <agrum/BN/io/UAI/UAIReader.h>
#include <agrum/BN/io/UAI/UAIWriter.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment