Commit 5562915b authored by Lionel's avatar Lionel

Adding O3PRM write support for BN

parent 56ca7a9a
......@@ -109,7 +109,8 @@ namespace gum {
void __generateBN( prm::PRMSystem<GUM_SCALAR>& system );
static std::string __getVariableName( const std::string& path,
const std::string& type,
const std::string& name );
const std::string& name,
const std::string& toRemove = "" );
static std::string __getEntityName( const std::string& filename );
static std::string __getInstanceName( const std::string& classname );
};
......
......@@ -29,9 +29,18 @@
namespace gum {
template <typename GUM_SCALAR>
INLINE std::string O3prmBNReader<GUM_SCALAR>::__getVariableName(
const std::string& path, const std::string& type, const std::string& name ) {
return path + name; // path ends up with a "."
INLINE std::string
O3prmBNReader<GUM_SCALAR>::__getVariableName( const std::string& path,
const std::string& type,
const std::string& name,
const std::string& toRemove ) {
auto res = path + name; // path ends up with a "."
if ( toRemove != "" ) {
if ( res.substr( 0, toRemove.size() ) == toRemove ) {
res = res.substr( toRemove.size() );
}
}
return res;
}
template <typename GUM_SCALAR>
......@@ -82,22 +91,27 @@ namespace gum {
gum::prm::PRM<GUM_SCALAR>* prm = reader.prm();
__errors = reader.errorsContainer();
if ( errors() == 0 ) {
std::string instanceName = "";
if ( prm->isSystem( __entityName ) ) {
__generateBN( prm->getSystem( __entityName ) );
} else {
if ( prm->isClass( __entityName ) ) {
ParseError warn( false,
"No system '" + __entityName +
"' found but class found. Generating instance.",
__filename,
0 );
ParseError warn(
false,
"No system '" + __entityName +
"' found but class found. Generating unnamed instance.",
__filename,
0 );
__errors.add( warn );
gum::prm::PRMSystem<GUM_SCALAR> s( "S_" + __entityName );
instanceName = __getInstanceName( __entityName );
auto i = new gum::prm::PRMInstance<GUM_SCALAR>(
__getInstanceName( __entityName ), prm->getClass( __entityName ) );
instanceName, prm->getClass( __entityName ) );
s.add( i );
__generateBN( s );
instanceName += "."; // to be removed in __getVariableName
} else {
ParseError err( true,
"Neither system nor class '" + __entityName + "'.",
......
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
/** @file
* @brief Outlined implementation of UAIBNReader
*
* @author Pierre-Henri WUILLEMIN and Christophe GONZALES
*/
#include <agrum/PRM/o3prm/O3prmBNWriter.h>
template class gum::O3prmBNWriter<float>;
template class gum::O3prmBNWriter<double>;
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
/**
* @file
* @brief Definition file for BIF XML exportation class
*
* Writes an bayes net in XML files with BIF format
*
* @author Jean-Christophe MAGNAN and Pierre-Henri WUILLEMIN
*/
#ifndef GUM_O3PRMBNWRITER_H
#define GUM_O3PRMBNWRITER_H
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <agrum/BN/io/BNWriter.h>
#include <agrum/variables/discreteVariable.h>
#include <agrum/variables/discretizedVariable.h>
#include <agrum/variables/rangeVariable.h>
#include <agrum/config.h>
namespace gum {
/**
* @class O3prmBNWriter O3prmBNWriter.h
*<agrum/PRM/o3prm/O3prmBNWriter.h>
* @ingroup bn_io
* @brief Writes an bayes net in a text file with O3PRM format
*
* This class export a bayes net into an text file, using O3PRM format
*
*/
template <typename GUM_SCALAR>
class O3prmBNWriter : public BNWriter<GUM_SCALAR> {
public:
// ==========================================================================
/// @name Constructor & destructor
// ==========================================================================
/// @{
/**
* Default constructor.
*/
O3prmBNWriter();
/**
* Destructor.
*/
virtual ~O3prmBNWriter();
/// @}
/**
* Writes an bayes net in the given ouput stream.
*
* @param output The output stream.
* @param bn The bayes net writen in the stream.
* @throws IOError Raised if an I/O error occurs.
*/
virtual void write( std::ostream& output, const IBayesNet<GUM_SCALAR>& bn );
/**
* Writes an bayes net in the file referenced by filePath.
* If the file doesn't exists, it is created.
* If the file exists, it's content will be erased.
*
* @param filePath The path to the file used to write the bayes net.
* @param bn The bayes net written in the file.
* @throw IOError Raised if an I/O error occurs.
*/
virtual void write( std::string filePath, const IBayesNet<GUM_SCALAR>& bn );
private:
std::string __extractAttribute( const IBayesNet<GUM_SCALAR>& bn, NodeId node );
std::string __extractType(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
template <typename VARTYPE>
std::string __extractDiscretizedType(const VARTYPE* var);
std::string __extractName(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
std::string __extractParents(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
std::string __extractCPT(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
std::string __extractRangeType(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
std::string __extractLabelizedType(const IBayesNet<GUM_SCALAR>& bn, NodeId node);
};
extern template class O3prmBNWriter<float>;
extern template class O3prmBNWriter<double>;
} /* namespace gum */
#include <agrum/PRM/o3prm/O3prmBNWriter_tpl.h>
#endif // GUM_O3PRMBNWRITER_H
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN et Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
#ifndef DOXYGEN_SHOULD_SKIP_THIS
#include <agrum/PRM/o3prm/O3prmBNWriter.h>
namespace gum {
/*
* Default constructor.
*/
template <typename GUM_SCALAR>
INLINE O3prmBNWriter<GUM_SCALAR>::O3prmBNWriter() {
GUM_CONSTRUCTOR( O3prmBNWriter );
}
/*
* Destructor.
*/
template <typename GUM_SCALAR>
INLINE O3prmBNWriter<GUM_SCALAR>::~O3prmBNWriter() {
GUM_DESTRUCTOR( O3prmBNWriter );
}
/*
* Writes a bayes net in the given ouput stream.
*
* @param output The output stream.
* @param bn The bayes net writen in the stream.
* @throws IOError Raised if an I/O error occurs.
*/
template <typename GUM_SCALAR>
INLINE void O3prmBNWriter<GUM_SCALAR>::write( std::ostream& output,
const IBayesNet<GUM_SCALAR>& bn ) {
if ( !output.good() ) {
GUM_ERROR( IOError, "Stream states flags are not all unset." );
}
output << "class BayesNet {" << std::endl;
for ( auto node : bn.nodes() ) {
output << __extractAttribute( bn, node ) << std::endl;
}
output << "}" << std::endl;
output << std::endl;
output.flush();
if ( output.fail() ) {
GUM_ERROR( IOError, "Writing in the ostream failed." );
}
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractAttribute( const IBayesNet<GUM_SCALAR>& bn, NodeId node ) {
std::stringstream str;
str << __extractType(bn, node) << " ";
str << __extractName(bn, node) << " ";
if (bn.dag().parents(node).size() > 0) {
str << "dependson " << __extractParents(bn, node) << " ";
}
str << " {" << __extractCPT(bn, node) << "};" << std::endl;
return str.str();
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractParents(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
std::stringstream str;
auto var =&( bn.variable(node));
for (auto parent: bn.cpt(node).variablesSequence()) {
if (var != parent) {
str << parent->name() << ", ";
}
}
return str.str().substr(0, str.str().size() - 2);
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractCPT(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
Instantiation inst( bn.cpt(node) );
Instantiation jnst;
for ( auto var = inst.variablesSequence().rbegin();
var != inst.variablesSequence().rend();
--var ) {
jnst.add( **var );
}
std::stringstream str;
str << "[" << std::endl;
for (jnst.begin(); !jnst.end(); jnst.inc()) {
inst.setVals( jnst );
str << bn.cpt(node)[inst] << ", ";
}
str << std::endl;
return str.str().substr(0, str.str().size()-3) + "]";
//try {
//auto cpt = std::vector<GUM_SCALAR>();
//auto inst = Instantiation(bn.cpt(node));
//for (inst.setFirst(); ! inst.end(); inst.inc()) {
// cpt.push_back(bn.cpt(node)[inst]);
//}
//std::stringstream str;
//str << "[" << std::endl;
//for (size_t mod = 0; mod < bn.variable(node).domainSize(); ++mod) {
// for (size_t i = mod; i < cpt.size();i += bn.variable(node).domainSize()) {
// str << cpt[i] << ", ";
// }
// str << std::endl;
//}
//return str.str().substr(0, str.str().size()-3) + "]";
//} catch (gum::Exception& e) {
// GUM_SHOWERROR(e);
// throw e;
//}
// const auto& cpt = bn.cpt(node);
// auto inst = Instantiation(cpt);
// auto var = Instantiation();
// var.add(bn.variable(node));
// std::stringstream str;
// str << "[" << std::endl;
// //for (inst.setFirst(); ! inst.end(); inst.inc()) {
// // str << cpt[inst] << ", ";
// //}
// //str << std::endl;
// for (var.setFirst(); ! var.end(); var.inc()) {
// inst.setFirst();
// inst.setVals(var);
// for (;! inst.end(); inst.incOut(var)) {
// str << cpt[inst] << ", ";
// }
// str << std::endl;
// }
// return str.str().substr(0, str.str().size()-3) + "]";
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractType(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
switch (bn.variable(node).varType()) {
case gum::DiscreteVariable::VarType::Discretized:
{
auto double_var = dynamic_cast<const DiscretizedVariable<double>*>(&(bn.variable(node)));
if ( double_var != nullptr ) {
return __extractDiscretizedType<DiscretizedVariable<double>>(double_var);
} else {
auto float_var = dynamic_cast<const DiscretizedVariable<float>*>(&(bn.variable(node)));
if ( float_var != nullptr ) {
return __extractDiscretizedType<DiscretizedVariable<float>>(float_var);
}
}
GUM_ERROR(InvalidArgument, "DiscretizedVariable ticks are neither doubles or floats");
}
case gum::DiscreteVariable::VarType::Range:
{
return __extractRangeType(bn, node);
}
default:
{
return __extractLabelizedType(bn, node);
}
}
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractRangeType(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
const auto& var = static_cast<const RangeVariable&>(bn.variable(node));
std::stringstream str;
str << "int (" << var.minVal() << ", " << var.maxVal() << ")";
return str.str();
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractLabelizedType(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
std::stringstream str;
str << "labels(";
for (auto l: bn.variable(node).labels()) {
str << l << ", ";
}
return str.str().substr(0, str.str().size() - 2) + ")";
}
template <typename GUM_SCALAR>
template <typename VARTYPE>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractDiscretizedType(const VARTYPE* var) {
std::stringstream str;
if ( var->ticks().size() > 3 ) {
str << "real(" << var->ticks()[0];
for (size_t i = 1; i < var->ticks().size() - 1; ++i) {
str << ", " << var->ticks()[i];
}
str << ")";
return str.str();
}
GUM_ERROR( InvalidArgument, "discretized variable does not have enough ticks");
}
template <typename GUM_SCALAR>
INLINE std::string
O3prmBNWriter<GUM_SCALAR>::__extractName(const IBayesNet<GUM_SCALAR>& bn, NodeId node) {
if (!bn.variable(node).name().empty()) {
return bn.variable(node).name();
} else {
std::stringstream str;
str << node;
return str.str();
}
}
/*
* Writes a bayes net in the file referenced by filePath.
* If the file doesn't exists, it is created.
* If the file exists, it's content will be erased.
*
* @param filePath The path to the file used to write the bayes net.
* @param bn The bayes net writen in the file.
* @throw IOError Raised if an I/O error occurs.
*/
template <typename GUM_SCALAR>
INLINE void O3prmBNWriter<GUM_SCALAR>::write( std::string filePath,
const IBayesNet<GUM_SCALAR>& bn ) {
std::ofstream output( filePath.c_str(), std::ios_base::trunc );
write( output, bn );
output.close();
if ( output.fail() ) {
GUM_ERROR( IOError, "Writing in the ostream failed." );
}
}
} /* namespace gum */
#endif // DOXYGEN_SHOULD_SKIP_THIS
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -120,37 +120,37 @@ void __split( const O3Label& value, O3Label& left, O3Label& right) {
}
}
O3Label __setAnonTypeName(O3Class& c, O3Position& pos, O3Type& t) {
std::stringstream name;
name << "__" << c.name();
O3Label __setAnonTypeName(O3Class& c, O3Label& name, O3Position& pos, O3Type& t) {
std::stringstream ss;
ss << "__" << c.name() << "_" << name.label();
for (auto& l: t.labels()) {
name << '_' << l.first.label();
ss << '_' << l.first.label();
}
name << "__";
ss << "__";
t.name().position() = pos;
t.name().label() = name.str();
t.name().label() = ss.str();
return t.name();
}
O3Label __setAnonTypeName(O3Class& c, O3Position& pos, O3IntType& t) {
std::stringstream name;
name << "__" << c.name();
name << "_" << t.start().value() << '_' << t.end().value();
name << "__";
O3Label __setAnonTypeName(O3Class& c, O3Label& name, O3Position& pos, O3IntType& t) {
std::stringstream ss;
ss << "__" << c.name() << "_" << name.label();
ss << "_" << t.start().value() << '_' << t.end().value();
ss << "__";
t.name().position() = pos;
t.name().label() = name.str();
t.name().label() = ss.str();
return t.name();
}
O3Label __setAnonTypeName(O3Class& c, O3Position& pos, O3RealType& t) {
std::stringstream name;
name << "__" << c.name();
O3Label __setAnonTypeName(O3Class& c, O3Label& name, O3Position& pos, O3RealType& t) {
std::stringstream ss;
ss << "__" << c.name() << "_" << name.label();
for (auto& v: t.values()) {
name << '_' << v.value();
ss << '_' << v.value();
}
name << "__";
ss << "__";
t.name().position() = pos;
t.name().label() = name.str();
t.name().label() = ss.str();
return t.name();
}
......@@ -314,27 +314,27 @@ CLASS_ANON_TYPE_ATTR<O3Class& c> =
(. pos.line() = t->line; .)
(. pos.column() = t->col; .)
(. auto type = O3Label(); .)
(. auto name = O3Label(); .)
(
(. auto t = O3Type(); .)
(. t.position() = pos; .)
labels '(' TYPE_VALUE_LIST<t.labels()> ')'
(. type = __setAnonTypeName(c, pos, t); .)
labels '(' TYPE_VALUE_LIST<t.labels()> ')' LABEL_OR_INT<name>
(. type = __setAnonTypeName(c, name, pos, t); .)
(. if ( __ok( n ) ) { __addO3Type( std::move(t) ); } .)
|
(. auto t = O3IntType(); .)
(. t.position() = pos; .)
INT_TYPE_DECLARATION<t.start(), t.end()>
(. type = __setAnonTypeName(c, pos, t); .)
INT_TYPE_DECLARATION<t.start(), t.end()> LABEL_OR_INT<name>
(. type = __setAnonTypeName(c, name, pos, t); .)
(. if ( __ok( n ) ) { __addO3IntType( std::move(t) ); } .)
|
(. auto t = O3RealType(); .)
(. t.position() = pos; .)
REAL_TYPE_DECLARATION<t.values()>
(. type = __setAnonTypeName(c, pos, t); .)
REAL_TYPE_DECLARATION<t.values()> LABEL_OR_INT<name>
(. type = __setAnonTypeName(c, name, pos, t); .)
(. if ( __ok( n ) ) { __addO3RealType( std::move(t) ); } .)
)
(. auto name = O3Label(); .)
LABEL<name>
ATTRIBUTE<type, name, c>
.
......@@ -356,7 +356,7 @@ ARRAY_REFERENCE_SLOT<O3Label& type, O3ReferenceSlotList& refs> =
(. auto name = O3Label(); .)
'['
']'
LABEL<name>
LABEL_OR_INT<name>
';'
(. refs.push_back( O3ReferenceSlot( type, name, isArray ) ); .)
.
......@@ -364,7 +364,7 @@ ARRAY_REFERENCE_SLOT<O3Label& type, O3ReferenceSlotList& refs> =
//________________________
NAMED_CLASS_ELEMENT<O3Label& type, O3Class& c> =
(. auto name = O3Label(); .)
LABEL<name>
LABEL_OR_INT<name>
(
REFERENCE_SLOT<type, name, c>
|
......@@ -1005,8 +1005,13 @@ CAST<std::stringstream& s> =
//________________________
LINK<std::stringstream& s> =
label
(. s << narrow( t->val ); .)
(
label
(. s << narrow( t->val ); .)
|
integer
(. s << narrow( t->val ); .)
)
.
//________________________
......
......@@ -301,11 +301,6 @@ namespace gum {
*/
virtual const std::vector<double>& history() const = 0;
/**
* @brief Configuration transmission.
* @param cfg The configuration to copy.
*/
void copyConfiguration( const IApproximationSchemeConfiguration& cfg );
};
} // namespace gum
......
......@@ -33,12 +33,6 @@
namespace gum {
INLINE
void IApproximationSchemeConfiguration::copyConfiguration(
const IApproximationSchemeConfiguration& cfg ) {
GUM_TRACE( "COPYING CONFIGURATION" );
}
INLINE
std::string
IApproximationSchemeConfiguration::messageApproximationScheme() const {
......
......@@ -178,6 +178,9 @@ namespace gum {
/// from the index to the tick.
/// @throws NotFound
const T_TICKS& tick( Idx i ) const;
/// Return the list of ticks
const std::vector<T_TICKS>& ticks() const;
};
} /* namespace gum */
......
......@@ -36,7 +36,6 @@ namespace gum {
DiscreteVariable::_copy( aDRV );
for ( Idx i = 0; i < aDRV.__ticks_size; ++i ) {
GUM_TRACE_VAR( aDRV.__ticks[i] );
addTick( (T_TICKS)aDRV.__ticks[i] );
}
}
......@@ -296,6 +295,11 @@ namespace gum {
return s.str();
}
template <typename T_TICKS>
INLINE const std::vector<T_TICKS>& DiscretizedVariable<T_TICKS>::ticks() const {
return this->__ticks;
}
} /* namespace gum */
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
......@@ -24,6 +24,7 @@
#include <cxxtest/AgrumTestSuite.h>
#include <cxxtest/testsuite_utils.h>
#include <agrum/variables/discretizedVariable.h>
#include <agrum/PRM/PRM.h>
#include <agrum/PRM/o3prm/O3prm.h>
#include <agrum/PRM/o3prm/O3prmReader.h>
......@@ -2691,6 +2692,7 @@ namespace gum_tests {
TS_ASSERT_EQUALS( output.str(), "|2 col 6| Error : invalid declaration\n" );
TS_ASSERT_EQUALS( prm.classes().size(), (gum::Size)0 );
}
};
} // namespace gum_tests
......@@ -614,6 +614,28 @@ namespace gum_tests {
TS_ASSERT( prm.isType( "fr.agrum.t_degraded" ) );
}
void testRangeType() {
// Arrange
std::stringstream input;
input << "type range int(1, 10);" << std::endl;
std::stringstream output;
gum::prm::PRM<double> prm;
auto factory = gum::prm::o3prm::O3prmReader<double>( prm );
// Act
TS_GUM_ASSERT_THROWS_NOTHING( factory.parseStream( input, output ) );
// Assert
TS_ASSERT_EQUALS( output.str(), "" );
TS_ASSERT_EQUALS( prm.types().size(), (gum::Size)2 );
TS_ASSERT( prm.isType( "range" ) );
const auto& range = prm.type( "range" );
TS_ASSERT_EQUALS( range.variable().labels().size(), (gum::Size)10 );
TS_ASSERT_EQUALS( range.variable().varType(),gum::DiscreteVariable::VarType::Range);
TS_ASSERT_EQUALS( range.variable().labels().at( 0 ), "1" );
TS_ASSERT_EQUALS( range.variable().labels().at( 1 ), "2" );