Commit d6f089d5 authored by Christophe Gonzales's avatar Christophe Gonzales

new learning databases

parent 6850bccd
Pipeline #20194481 failed with stages
in 26 minutes and 37 seconds
......@@ -466,6 +466,14 @@ namespace gum {
*/
class SyntaxError;
/**
* @class gum::NotImplementedYet agrum/core/exceptions.h
* @extends gum::Exception
* Exception : the requested feature is not implemented yet
*/
class NotImplementedYet;
#ifndef DOXYGEN_SHOULD_SKIP_THIS
const std::string __createMsg(const std::string& filename,
const std::string& function,
......@@ -473,6 +481,7 @@ namespace gum {
const std::string& msg);
GUM_MAKE_ERROR(IdError, Exception, "ID error")
GUM_MAKE_ERROR(FatalError, Exception, "Fatal error")
GUM_MAKE_ERROR(NotImplementedYet, Exception, "Not implemented yet")
GUM_MAKE_ERROR(UndefinedIteratorValue, Exception, "Undefined iterator")
GUM_MAKE_ERROR(UndefinedIteratorKey, Exception, "Undefined iterator's key")
GUM_MAKE_ERROR(NullElement, Exception, "Null element")
......
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN and Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
/**
* @file
* @brief C++11 threads convenience utilities for agrum.
* @author Christophe GONZALES and Pierre-Henri WUILLEMIN
*/
// to ease automatic parsers
#include <agrum/config.h>
#include <agrum/core/thread.h>
// include the inlined functions if necessary
#ifdef GUM_NO_INLINE
#include <agrum/core/thread_inl.h>
#endif /* GUM_NO_INLINE */
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN and Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
/**
* @file
* @brief C++11 threads convenience utilities for agrum.
* @author Christophe GONZALES and Pierre-Henri WUILLEMIN
*/
#ifndef GUM_THREAD_H
#define GUM_THREAD_H
#include <thread>
namespace gum {
// convenience helpers built on top of the C++11 <thread> facilities
namespace thread {
/**
* @brief returns the maximum number of threads possible
* @ingroup basicstruct_group
*
* The inline definition (agrum/core/thread_inl.h) obtains this value from
* std::thread::hardware_concurrency().
*
* @return Returns the number of concurrent threads supported by the
* implementation. The value should be considered only a hint.
*/
unsigned int getMaxNumberOfThreads();
} /* namespace thread */
} /* namespace gum */
// include the inlined functions if necessary
#ifndef GUM_NO_INLINE
#include <agrum/core/thread_inl.h>
#endif /* GUM_NO_INLINE */
#endif /* GUM_THREAD_H */
/***************************************************************************
* Copyright (C) 2005 by Pierre-Henri WUILLEMIN and Christophe GONZALES *
* {prenom.nom}_at_lip6.fr *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
/**
* @file
* @brief C++11 threads convenience utilities for agrum.
* @author Christophe GONZALES and Pierre-Henri WUILLEMIN
*/
// to ease automatic parsers
#include <agrum/agrum.h>
#include <agrum/core/thread.h>
namespace gum {
  namespace thread {
    /**
     * returns the maximum number of threads possible
     *
     * The value is the hint given by std::thread::hardware_concurrency().
     * The C++ standard allows that function to return 0 when the number of
     * hardware threads is not well defined or not computable; in that case
     * 1 is returned instead, so that callers sizing a set of threads or
     * splitting work by the result never receive a count of zero.
     *
     * @return a strictly positive number of threads (a hint, not a
     * guarantee)
     */
    INLINE unsigned int getMaxNumberOfThreads () {
      const unsigned int nb_threads = std::thread::hardware_concurrency();
      // 0 means "unknown": fall back to a single thread
      return nb_threads != 0 ? nb_threads : 1u;
    }
  } /* namespace thread */
} /* namespace gum */
This diff is collapsed.
This diff is collapsed.
......@@ -38,12 +38,11 @@
#include <agrum/core/sequence.h>
#include <agrum/graphs/DAG.h>
#include <agrum/learning/database/DBCellTranslators/cellTranslatorCompactIntId.h>
#include <agrum/learning/database/DBCellTranslators/cellTranslatorUniversal.h>
#include <agrum/learning/database/DBRowTranslatorSet.h>
#include <agrum/learning/database/DBTransformCompactInt.h>
#include <agrum/learning/database/databaseFromCSV.h>
#include <agrum/learning/database/filteredRowGenerators/rowGeneratorIdentity.h>
#include <agrum/learning/database/DBTranslator4LabelizedVariable.h>
#include <agrum/learning/database/DBRowGeneratorParser.h>
#include <agrum/learning/database/DBInitializerFromCSV.h>
#include <agrum/learning/database/databaseTable.h>
#include <agrum/learning/database/DBRowGeneratorParser.h>
#include <agrum/learning/scores_and_tests/scoreAIC.h>
#include <agrum/learning/scores_and_tests/scoreBD.h>
......@@ -87,9 +86,6 @@ namespace gum {
namespace learning {
/// reads a file and returns a databaseVectInRam
DatabaseVectInRAM readFile(const std::string& filename);
class BNLearnerListener;
/** @class genericBNLearner
......@@ -120,6 +116,7 @@ namespace gum {
MIIC_THREE_OFF_TWO
};
/// a helper to easily read databases
class Database {
public:
......@@ -130,7 +127,7 @@ namespace gum {
/// default constructor
explicit Database(const std::string& file);
explicit Database(const DatabaseVectInRAM& db);
explicit Database(const DatabaseTable<>& db);
/// default constructor with defined modalities for some variables
/**
......@@ -144,10 +141,12 @@ namespace gum {
* @param check_database If true, the database will be checked.
*
*/
/*
Database(std::string filename,
const NodeProperty< Sequence< std::string > >& modalities,
bool check_database = false);
*/
/// default constructor for the aprioris
/** We must ensure that, when reading the apriori database, if the
* "apriori" rowFilter says that a given variable has value i
......@@ -155,7 +154,7 @@ namespace gum {
* apriori database is the same as in the score/parameter database
* read before creating the apriori. This is compulsory to have
* aprioris that make sense. */
Database(std::string filename, Database& score_database);
Database(const std::string& filename, Database& score_database);
/// default constructor for the aprioris
/** We must ensure that, when reading the apriori database, if the
......@@ -174,10 +173,15 @@ namespace gum {
* of id 1 in the BN will have 3 modalities, the first one being True,
* the second one being False, and the third being Big.
*/
Database(std::string filename,
Database& score_database,
const NodeProperty< Sequence< std::string > >& modalities);
template < typename GUM_SCALAR >
Database( const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& bn);
template < typename GUM_SCALAR >
Database( const std::string& filename,
Database& score_database,
const gum::BayesNet< GUM_SCALAR >& bn);
/// copy constructor
Database(const Database& from);
......@@ -207,11 +211,8 @@ namespace gum {
// ########################################################################
/// @{
/// returns the row filter
DBRowFilter< DatabaseVectInRAM::Handler,
DBRowTranslatorSet< CellTranslatorCompactIntId >,
FilteredRowGeneratorSet< RowGeneratorIdentity > >&
rowFilter();
/// returns the parser for the database
DBRowGeneratorParser<>& parser ();
/// returns the modalities of the variables
std::vector< Size >& modalities() noexcept;
......@@ -225,35 +226,17 @@ namespace gum {
/// returns the variable name corresponding to a given node id
const std::string& nameFromId(NodeId id) const;
/// returns the "raw" translators (needed for the aprioris)
/** We must ensure that, when reading the apriori database, if the
* "apriori" rowFilter says that a given variable has value i
* (given by its fast translator), the corresponding "raw" value in the
* apriori database is the same as in the score/parameter database
* read before creating the apriori. This is compulsory to have
* aprioris that make sense. */
DBRowTranslatorSet< CellTranslatorUniversal >& rawTranslators();
/// returns the internal database table
const DatabaseTable<>& databaseTable () const;
/// @}
protected:
/// the database itself
DatabaseVectInRAM __database;
/// the raw translators
DBRowTranslatorSet< CellTranslatorUniversal > __raw_translators;
DatabaseTable<> __database;
/// the translators used for reading the database
DBRowTranslatorSet< CellTranslatorCompactIntId > __translators;
/// the generators used for reading the database
FilteredRowGeneratorSet< RowGeneratorIdentity > __generators;
/// the filtered row that reads the database
DBRowFilter< DatabaseVectInRAM::Handler,
DBRowTranslatorSet< CellTranslatorCompactIntId >,
FilteredRowGeneratorSet< RowGeneratorIdentity > >*
__row_filter{nullptr};
/// the parser used for reading the database
DBRowGeneratorParser<>* __parser { nullptr };
/// the modalities of the variables
std::vector< Size > __modalities;
......@@ -270,8 +253,21 @@ namespace gum {
/// the minimal number of rows to parse (on average) by thread
Size __min_nb_rows_per_thread{100};
private:
// returns the set of variables as a BN. This is convenient for
// the constructors of apriori Databases
template < typename GUM_SCALAR >
BayesNet<GUM_SCALAR> __BNVars () const;
};
public:
// ##########################################################################
/// @name Constructors / Destructors
......@@ -284,7 +280,7 @@ namespace gum {
* names
*/
genericBNLearner(const std::string& db);
genericBNLearner(const DatabaseVectInRAM& db);
genericBNLearner(const DatabaseTable<>& db);
/**
* read the database file for the score / parameter estimation and var
......@@ -305,10 +301,10 @@ namespace gum {
* as being exactly those of the variables of the BN (as a consequence,
* if we find other values in the database, an exception will be raised
* during learning). */
genericBNLearner(const std::string& filename,
const NodeProperty< Sequence< std::string > >& modalities,
bool parse_database = false);
template < typename GUM_SCALAR >
genericBNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src );
/// copy constructor
genericBNLearner(const genericBNLearner&);
......@@ -586,8 +582,11 @@ namespace gum {
const ApproximationScheme* __current_algorithm{nullptr};
/// reads a file and returns the database it contains
static DatabaseVectInRAM __readFile(const std::string& filename);
static DatabaseTable<> __readFile(const std::string& filename);
/// checks whether the extension of a CSV filename is correct
static void __checkFileName(const std::string& filename);
/// create the apriori used for learning
void __createApriori();
......@@ -905,4 +904,6 @@ namespace gum {
#include <agrum/learning/BNLearnUtils/genericBNLearner_inl.h>
#endif /* GUM_NO_INLINE */
#include <agrum/learning/BNLearnUtils/genericBNLearner_tpl.h>
#endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */
......@@ -34,11 +34,9 @@ namespace gum {
namespace learning {
// returns the row filter
INLINE DBRowFilter< DatabaseVectInRAM::Handler,
DBRowTranslatorSet< CellTranslatorCompactIntId >,
FilteredRowGeneratorSet< RowGeneratorIdentity > >&
genericBNLearner::Database::rowFilter() {
return *__row_filter;
INLINE DBRowGeneratorParser<>&
genericBNLearner::Database::parser () {
return *__parser;
}
// returns the modalities of the variables
......@@ -57,7 +55,7 @@ namespace gum {
genericBNLearner::Database::idFromName(const std::string& var_name) const {
try {
return __name2nodeId.second(const_cast< std::string& >(var_name));
} catch (gum::NotFound&) {
} catch (gum::NotFound) {
GUM_ERROR(MissingVariableInDatabase, "for variable " << var_name);
}
}
......@@ -68,12 +66,14 @@ namespace gum {
return __name2nodeId.first(id);
}
// returns the "raw" translators (needed for the aprioris)
INLINE DBRowTranslatorSet< CellTranslatorUniversal >&
genericBNLearner::Database::rawTranslators() {
return __raw_translators;
/// returns the internal database table
INLINE const DatabaseTable<>&
genericBNLearner::Database::databaseTable () const {
return __database;
}
// ===========================================================================
// returns the node id corresponding to a variable name
......@@ -150,7 +150,7 @@ namespace gum {
GUM_ERROR(OperationNotAllowed, "Must be using the 3off2 algorithm");
}
__mutual_info = new CorrectedMutualInformation<>(
__score_database.rowFilter(), __score_database.modalities());
__score_database.parser(), __score_database.modalities());
__mutual_info->useNML();
}
/// indicate that we wish to use the MDL correction for 3off2
......@@ -159,7 +159,7 @@ namespace gum {
GUM_ERROR(OperationNotAllowed, "Must be using the 3off2 algorithm");
}
__mutual_info = new CorrectedMutualInformation<>(
__score_database.rowFilter(), __score_database.modalities());
__score_database.parser(), __score_database.modalities());
__mutual_info->useMDL();
}
/// indicate that we wish to use the NoCorr correction for 3off2
......@@ -168,7 +168,7 @@ namespace gum {
GUM_ERROR(OperationNotAllowed, "Must be using the 3off2 algorithm");
}
__mutual_info = new CorrectedMutualInformation<>(
__score_database.rowFilter(), __score_database.modalities());
__score_database.parser(), __score_database.modalities());
__mutual_info->useNoCorr();
}
......
namespace gum {
namespace learning {
/// constructor reading a CSV file whose columns are described by a BN
/**
 * The file name is first validated by genericBNLearner::__checkFileName.
 * Then, for each node of bn, the position of the column bearing the name of
 * the node's variable is looked up and a translator for this variable is
 * inserted into the internal database table. Finally the table is filled
 * from the file and the domain sizes, the name/id mapping and the row
 * parser are set up from it.
 *
 * @param filename the name of the CSV file to read
 * @param bn the Bayesian network whose variables are used to create the
 * translators of the database table
 * @throws MissingVariableInDatabase when NotFound is raised while looking
 * up / inserting the translator of a variable of bn (presumably because no
 * column of the file bears the variable's name); the message names the
 * offending variable.
 */
template < typename GUM_SCALAR >
genericBNLearner::Database::Database(
  const std::string& filename,
  const BayesNet< GUM_SCALAR >& bn ) {
  genericBNLearner::__checkFileName( filename );
  DBInitializerFromCSV<> initializer ( filename );

  // assign to each column name in the database its position
  const auto& xvar_names = initializer.variableNames ();
  std::size_t nb_vars = xvar_names.size();
  HashTable<std::string,std::size_t> var_names ( nb_vars );
  for ( std::size_t i = std::size_t(0); i < nb_vars; ++i )
    var_names.insert ( xvar_names[i], i );

  // we use the bn to insert the translators into the database table.
  // The try/catch is placed inside the loop so that the variable at fault
  // is still in scope and can be reported in the error message.
  for ( auto node : bn.dag () ) {
    const Variable& var = bn.variable(node);
    try {
      __database.insertTranslator ( var, var_names[var.name()] );
    }
    catch ( NotFound& ) {
      GUM_ERROR ( MissingVariableInDatabase,
                  "the database does not contain variable " << var.name() );
    }
  }

  // fill the database
  initializer.fillDatabase ( __database );

  // get the domain sizes of the variables
  __modalities = __database.domainSizes ();

  // map each variable name to the index of its column in the table
  nb_vars = __database.nbVariables ();
  for ( std::size_t i = std::size_t(0); i < nb_vars; ++i )
    __name2nodeId.insert ( __database.variable(i).name(), i );

  // create the parser
  __parser = new DBRowGeneratorParser<> ( __database.handler (),
                                          DBRowGeneratorSet<> () );
}
/// constructor taking a file, the score database and a BN
/**
* Only the internal database table is initialized here, from the result of
* genericBNLearner::__readFile(filename, bn); the constructor body is empty.
*
* NOTE(review): this definition looks unfinished -- please confirm:
* - score_database is never used below;
* - the only __readFile declaration visible in genericBNLearner.h takes a
*   single filename argument, so this call presumably fails to compile as
*   soon as the template is instantiated, unless another overload exists;
* - unlike the (filename, bn) constructor, neither __modalities nor
*   __name2nodeId is filled and __parser keeps its nullptr default.
*
* @param filename the name of the file passed to __readFile
* @param score_database the database used for the scores / parameters
* (currently unused in this definition)
* @param bn the Bayesian network passed to __readFile
*/
template < typename GUM_SCALAR >
genericBNLearner::Database::Database(
const std::string& filename,
Database& score_database,
const BayesNet< GUM_SCALAR >& bn)
: __database(genericBNLearner::__readFile(filename,bn)) {
}
/// returns the variables of the database table as an arc-less BN
/**
 * Every variable of the internal database table is added, in column order,
 * to a freshly created Bayesian network. Each variable is downcast to
 * DiscreteVariable by reference, hence std::bad_cast is thrown if a variable
 * of the table is not a DiscreteVariable.
 *
 * @return a BayesNet containing one node per variable of the table
 */
template < typename GUM_SCALAR >
BayesNet<GUM_SCALAR> genericBNLearner::Database::__BNVars () const {
  BayesNet<GUM_SCALAR> vars_bn;

  for ( std::size_t col = 0, nb_cols = __database.nbVariables ();
        col < nb_cols; ++col ) {
    vars_bn.add (
      dynamic_cast<const DiscreteVariable&> ( __database.variable (col) ) );
  }

  return vars_bn;
}
/// constructor reading the score / parameter database with the help of a BN
/**
* The score database is built from filename and bn, i.e., by the
* Database(filename, bn) constructor; nothing else is done explicitly here.
*
* @param filename the name of the database file
* @param bn the Bayesian network forwarded to the Database constructor
*/
template < typename GUM_SCALAR >
genericBNLearner::genericBNLearner(
const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& bn )
: __score_database ( filename, bn ) {
// presumably aGrUM's debug bookkeeping of constructed objects -- confirm
GUM_CONSTRUCTOR(genericBNLearner);
}
}
}
......@@ -66,7 +66,7 @@ namespace gum {
* names
*/
BNLearner(const std::string& filename);
BNLearner(const DatabaseVectInRAM& db);
BNLearner(const DatabaseTable<>& db);
/**
* @brief Read the database file for the score / parameter estimation and
......@@ -91,17 +91,18 @@ namespace gum {
* @param parse_database if true, the modalities specified by the user
* will be considered as a superset of the modalities of the variables.
*/
/*
BNLearner(const std::string& filename,
const NodeProperty< Sequence< std::string > >& modalities,
bool parse_database = false);
const NodeProperty< Sequence< std::string > >& modalities );
*/
/**
* Wrapper for BNLearner (filename,modalities,parse_database) using a bn
* to find those modalities and nodeids.
**/
BNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src,
bool parse_database = false);
const gum::BayesNet< GUM_SCALAR >& src );
/// copy constructor
BNLearner(const BNLearner&);
......
......@@ -33,7 +33,6 @@
#include <agrum/learning/BNLearner.h>
#include <agrum/learning/BNLearnUtils/BNLearnerListener.h>
#include <agrum/learning/database/CSVParser.h>
namespace gum {
......@@ -45,41 +44,29 @@ namespace gum {
}
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(const DatabaseVectInRAM& db)
BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db)
: genericBNLearner(db) {
GUM_CONSTRUCTOR(BNLearner);
}
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
const NodeProperty< Sequence< std::string > >& modalities,
bool parse_database)
: genericBNLearner(filename, modalities, parse_database) {
GUM_CONSTRUCTOR(BNLearner);
}
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src,
bool parse_database)
: BNLearner(filename,
BNLearner< GUM_SCALAR >::__labelsFromBN(filename, src),
parse_database) {
// GUM_CONSTRUCTOR in BNLearner(filename,modalities,parse_database)
const gum::BayesNet< GUM_SCALAR >& bn )
: genericBNLearner(filename, bn ) {
GUM_CONSTRUCTOR (BNLearner)
}
/// copy constructor
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src)
: genericBNLearner(static_cast< const genericBNLearner& >(src)) {
: genericBNLearner( src ) {
GUM_CONSTRUCTOR(BNLearner);
}
/// move constructor
/**
 * @param src the learner whose content is moved into this one
 *
 * Inside the constructor, src is a named variable and therefore an lvalue:
 * passing it as is to the base class would select genericBNLearner's copy
 * constructor. The cast to an rvalue reference makes sure that the base
 * subobject is actually moved.
 */
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src)
    : genericBNLearner(static_cast< genericBNLearner&& >(src)) {
  GUM_CONSTRUCTOR(BNLearner);
}
......@@ -100,8 +87,7 @@ namespace gum {
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
operator=(const BNLearner< GUM_SCALAR >& src) {
static_cast< genericBNLearner* >(this)->operator=(
static_cast< const genericBNLearner& >(src));
genericBNLearner::operator=( src );
return *this;
}
......@@ -109,8 +95,7 @@ namespace gum {
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::
operator=(BNLearner< GUM_SCALAR >&& src) {
static_cast< genericBNLearner* >(this)->operator=(
static_cast< genericBNLearner&& >(src));
genericBNLearner::operator=( std::move (src) );
return *this;
}
......@@ -125,12 +110,12 @@ namespace gum {
return DAG2BNLearner::createBN<