Commit 1f1b380b authored by Jaroslaw Zola's avatar Jaroslaw Zola

Release 2.100

parent fc1c3863
2.100 - 2019.02.07
* Several network post-processing tools added.
* Parameters estimation via MLE added.
* Minor bug fixes.
2.001 - 2018.07.29
* Initial support for hard priors (aka structure constraints) added.
* New level of performance unlocked by OMP parallelism.
......
......@@ -2,10 +2,13 @@
**Authors:**
Subhadeep Karan <skaran@buffalo.edu>,
Jaroslaw Zola <jaroslaw.zola@hush.com>
**Contributors:**
John Demetros <johndeme@buffalo.edu>,
Matthew Eichhorn <maeichho@buffalo.edu>,
Blake Hurlburt <blakehur@buffalo.edu>,
Grant Iraci <grantira@buffalo.edu>,
Jaroslaw Zola <jaroslaw.zola@hush.com>
Grant Iraci <grantira@buffalo.edu>
## About
......@@ -54,10 +57,21 @@ The example below shows how to learn an exact BN under the MDL score using BFS w
2. Use the resulting MPS to find an optimal ordering: `./sabna-exsl-bfs-ope -n 26 --mps-file autos.mps --ord-file autos.ord`
3. Convert the resulting ordering into the corresponding structure in BIF format: `./sabna-order2net --csv-file autos.csv --mps-file autos.mps --ord-file autos.ord --net-name autos --format bif`
#### End-to-end example
The example below shows how to learn exact BN given some categorical data (with variables and their states represented by strings). We use `asia.csv` dataset with 8 variables and 200 observations (available in `data/`).
1. Convert input data into SABNA compatible format: `./csv-prepare -H True asia.csv asia`
2. Run SABNA learning steps on the resulting data:
* `./sabna-exsl-mpsbuild --csv-file asia/asia.sabna.csv --mps-file asia.sabna.mps`
* `./sabna-exsl-bfs-ope -n 8 --mps-file asia.sabna.mps --ord-file asia.sabna.ord`
* `./sabna-order2net --csv-file asia/asia.sabna.csv --mps-file asia.sabna.mps --ord-file asia.sabna.ord --net-name asia.sabna --format net`
3. Convert final network back to annotated format: `./net-format.py --map-variables asia/asia.sabna.variables --map-states asia/asia.sabna.states asia.sabna.net asia.net`
## Priors
SABNA provides a mechanism to specify _hard_ priors, which are specific constraints on network structure. Priors are provided in a text file, using simple but flexible syntax. See [data/priors.txt](data/priors.txt) for explanation. This feature is currently experimental.
SABNA provides a mechanism to specify _hard_ priors, which are specific constraints on network structure. Priors are provided in a text file, using simple but flexible syntax. See [data/priors.txt](data/priors.txt) for explanation.
## License
......
# C++/Python interface for the core SABNA functionality.
# SABNAtk: Fast Counting in Machine Learning Applications
**Authors:**
Subhadeep Karan <skaran@buffalo.edu>,
Matthew Eichhorn <maeichho@buffalo.edu>,
Blake Hurlburt <blakehur@buffalo.edu>,
Grant Iraci <grantira@buffalo.edu>,
Jaroslaw Zola <jaroslaw.zola@hush.com>
## About
SABNAtk is a small C++14 library, together with Python bindings, to efficiently execute counting queries over categorical data. Such queries are common to Machine Learning applications, for example, they show up in Probabilistic Graphical Models, regression analysis, etc.. In practical applications, SABNAtk [significantly outperforms](https://gitlab.com/SCoRe-Group/SABNAtk-Benchmarks) typical approaches based on e.g., hash tables or ADtrees. Currently, SABNAtk is powering [SABNA](https://gitlab.com/SCoRe-Group/SABNA-Release), our Bayesian networks learning engine. We are working on further improving performance, so stay tuned!
## User Guide
In preparation - we will provide extended documentation soon. Please, refer to `examples/` directory to see several use examples. If you have immediate questions, please do not hesitate to contact Jaric Zola <jaroslaw.zola@hush.com>.
## References
To cite SABNAtk, refer to this repository and our 2018 UAI paper:
* S. Karan, M. Eichhorn, B. Hurlburt, G. Iraci, J. Zola, _Fast Counting in Machine Learning Applications_, In Proc. Uncertainty in Artificial Intelligence (UAI), 2018. <https://arxiv.org/abs/1804.04640>.
......@@ -26,6 +26,7 @@
template <int N> class BVCounter {
public:
typedef uint_type<N> set_type;
typedef uint8_t data_type;
int n() const { return n_; }
......@@ -48,6 +49,7 @@ public:
F[i](m_);
F[i](data_[xi_vect[i]].first[r_id].weight(), m_);
}
return;
}
auto pa_vect = as_vector(pa);
......@@ -70,10 +72,9 @@ public:
}
} // apply (state specific queries)
template <typename score_functor>
void apply(const set_type& xi, const set_type& pa, std::vector<score_functor>& F) const {
std::vector<int> xi_vect = as_vector(xi);
template <typename score_functor>
void apply(const std::vector<int>& xi_vect, const set_type& pa, std::vector<score_functor>& F) const {
int qpa = m_q__(pa);
for (int i = 0; i < F.size(); ++i) F[i].init(r(xi_vect[i]), qpa);
......@@ -163,6 +164,21 @@ public:
for (int i = 0; i < F.size(); ++i) F[i].finalize(qi_obs);
} // apply
template <typename score_functor>
void apply(const set_type& xi, const set_type& pa, std::vector<score_functor>& F) const {
std::vector<int> xi_vect = as_vector(xi);
apply(xi_vect, pa, F);
} // apply
template <typename score_functor>
void apply(int xi, const set_type& pa, score_functor& F) const {
std::vector<int> xi_vect{xi};
std::vector<score_functor> F_vect{F};
apply(xi_vect, pa, F_vect);
F = F_vect[0];
} // apply
// reorder variables to improve expected query performance
bool reorder(const std::vector<int>& order) {
std::vector<std::vector<int>> temp_r_idx;
......
......@@ -20,9 +20,10 @@
#include <bit_util.hpp>
template <int N, typename data_type = uint8_t> class RadCounter {
template <int N, typename Data = uint8_t> class RadCounter {
public:
using set_type = uint_type<N>;
using data_type = Data;
using pair_data_type = std::pair<data_type, data_type>;
using pair_int = std::pair<int, int>;
......@@ -34,8 +35,9 @@ public:
bool is_reorderable() { return true; }
template <typename score_functor>
void apply(const set_type& set_xi, const set_type& pa, const std::vector<data_type>& state_xi, const std::vector<data_type>& state_pa, std::vector<score_functor>& F) {
void apply(const set_type& set_xi, const set_type& pa, const std::vector<data_type>& state_xi, const std::vector<data_type>& state_pa, std::vector<score_functor>& F) const {
std::vector<pair_data_type> state_range_pa(state_pa.size());
std::vector<pair_data_type> state_range_xi(state_xi.size());
......@@ -45,8 +47,9 @@ public:
m_radcounter_core__(false, set_xi, pa, state_range_xi, state_range_pa, F);
} // apply (state_specific_queries)
template <typename score_functor>
void apply(const set_type& set_xi, const set_type& pa, std::vector<score_functor>& F) {
void apply(const set_type& set_xi, const set_type& pa, std::vector<score_functor>& F) const {
std::vector<pair_data_type> state_pa(set_size(pa));
std::vector<pair_data_type> state_xi(F.size());
......@@ -58,6 +61,23 @@ public:
m_radcounter_core__(true, set_xi, pa, state_xi, state_pa, F);
} // apply
template <typename score_functor>
void apply(const std::vector<int>& xi_vect, const set_type& pa, std::vector<score_functor>& F) const {
set_type set_xi = set_empty<set_type>();
for (auto xi : xi_vect) set_xi = set_add(set_xi, xi);
apply(set_xi, pa, F);
} // apply
template <typename score_functor>
void apply(int xi, const set_type& pa, score_functor& F) const {
set_type set_xi = set_empty<set_type>();
set_xi = set_add(set_xi, xi);
std::vector<score_functor> F_vect{F};
apply(set_xi, pa, F_vect);
F = F_vect[0];
} // apply
bool reorder(const std::vector<int>& norder) {
std::vector<data_type> r_temp;
std::vector<data_type> D_temp;
......@@ -89,7 +109,7 @@ private:
friend RadCounter <M, data_type_copy> create_RadCounter(int n, int m, Iter it);
template<typename score_functor>
void m_radcounter_core__(bool skip_unique_row, const set_type& set_xi, const set_type pa, const std::vector<pair_data_type>& state_xi, const std::vector<pair_data_type>& state_pa, std::vector<score_functor>& F) {
void m_radcounter_core__(bool skip_unique_row, const set_type& set_xi, const set_type pa, const std::vector<pair_data_type>& state_xi, const std::vector<pair_data_type>& state_pa, std::vector<score_functor>& F) const {
int q_pa = m_compute_q_pa__(pa, state_pa);
int pa_size = set_size(pa);
......@@ -154,7 +174,7 @@ private:
} // m_rad_counter_core__
template<typename score_functor>
int m_radix_sort_core__ (bool skip_unique_row, int xi, bool is_lsb, const pair_data_type& range_r, const std::vector<pair_int>& ro_bracket, std::vector<pair_int>& wo_bracket, std::vector<std::vector<int>>& Rxi, std::vector<int>& row_id, score_functor& F) {
int m_radix_sort_core__ (bool skip_unique_row, int xi, bool is_lsb, const pair_data_type& range_r, const std::vector<pair_int>& ro_bracket, std::vector<pair_int>& wo_bracket, std::vector<std::vector<int>>& Rxi, std::vector<int>& row_id, score_functor& F) const {
int idx = 0;
int count_unique_row = 0;
......@@ -202,7 +222,7 @@ private:
int m_;
int max_r_ = -1;
std::vector<data_type> r_;
std::vector<data_type> r_; // we do not expect more than 255 states
std::vector<int> idx_r_;
std::vector<int> max_count_r_;
// stores the row id's for each xi in state 'r'
......@@ -213,7 +233,7 @@ private:
}; // class RadCounter
template <int N, typename data_type, typename Iter>
template <int N, typename data_type = uint8_t, typename Iter>
RadCounter<N, data_type> create_RadCounter(int n, int m, Iter it) {
RadCounter<N, data_type> rad;
......@@ -226,18 +246,19 @@ RadCounter<N, data_type> create_RadCounter(int n, int m, Iter it) {
for (int i = 0; i < n * m; ++i, ++it) { rad.D_.push_back(*it); }
for (int xi = 0, r_sum = 0; xi < n; ++xi) {
auto min_max = std::minmax_element(rad.D_.begin() + xi * m, rad.D_.begin() + (xi+1) * m);
std::transform(rad.D_.begin() + xi * m, rad.D_.begin() + (xi+1) * m, rad.D_.begin() + xi * m, [min_max](data_type x) { return x - *(min_max.first); } );
auto min_max = std::minmax_element(rad.D_.begin() + xi * m, rad.D_.begin() + (xi + 1) * m);
std::transform(rad.D_.begin() + xi * m, rad.D_.begin() + (xi + 1) * m, rad.D_.begin() + xi * m, [min_max](data_type x) { return x - *(min_max.first); } );
rad.r_[xi] = *min_max.second - *min_max.first + 1;
rad.max_r_ = rad.max_r_ > rad.r_[xi] ? rad.max_r_ : rad.r_[xi];
rad.max_r_ = std::max<int>(rad.max_r_, rad.r_[xi]);
rad.idx_r_[xi] = r_sum;
r_sum += rad.r_[xi];
}
rad.Rxi_.resize(rad.idx_r_[n - 1] + rad.r_[n - 1]);
rad.max_count_r_.resize(rad.max_r_, -1);
for (int xi = 0; xi < n; ++xi) {
for (int i = 0; i < m; ++i) { rad.Rxi_[rad.idx_r_[xi] + rad.D_[ xi * m + i]].push_back(i); }
for (int i = 0; i < m; ++i) { rad.Rxi_[rad.idx_r_[xi] + rad.D_[xi * m + i]].push_back(i); }
for (int r = 0, idx = 0; r < rad.r_[xi]; ++r) {
const int temp = rad.Rxi_[rad.idx_r_[xi] + r].size();
rad.max_count_r_[r] = rad.max_count_r_[r] > temp ? rad.max_count_r_[r] : temp;
......
......@@ -3,3 +3,4 @@
| n | m | File | Categories (min/max) |
| --- + ----- + -------------- + --------------------- |
| 26 | 159 | autos.csv | 2/2 |
| 8 | 200 | asia.csv | 2/2 |
asia,tub,smoke,lung,bronc,either,xray,dysp
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,no
no,no,yes,no,no,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,yes,no,yes,yes,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,no
yes,no,no,no,no,no,no,no
no,no,yes,yes,yes,yes,yes,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,yes,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,yes,yes,no,yes,yes,yes
no,no,yes,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,no
no,no,yes,yes,no,yes,yes,yes
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,yes,no
no,no,yes,no,no,no,yes,no
no,no,yes,no,no,no,no,no
yes,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
yes,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,yes,no,no,no,yes,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,no
no,no,yes,no,no,no,no,yes
no,no,no,no,no,no,yes,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,yes
no,no,yes,no,no,no,yes,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,yes,yes
no,no,no,no,no,no,no,no
no,yes,no,no,no,yes,yes,no
no,yes,no,no,yes,yes,yes,yes
no,no,yes,yes,yes,yes,yes,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,yes,no
no,no,yes,no,yes,no,no,no
no,no,no,no,no,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,no
no,no,yes,yes,yes,yes,yes,yes
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,yes
no,no,yes,no,yes,no,no,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,yes,yes,yes,yes,yes
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,no
no,no,yes,yes,yes,yes,yes,yes
no,no,no,no,no,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,yes,yes,yes,yes,yes,yes
yes,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,yes,no
no,no,yes,no,yes,no,no,no
no,no,yes,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,yes,no,no,no,no,no
no,no,no,yes,no,yes,yes,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,no,no
no,no,yes,no,yes,no,yes,no
no,no,yes,no,no,no,no,no
no,no,yes,no,no,no,no,yes
no,no,no,no,no,no,no,no
no,no,yes,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,yes,no,no,yes
no,no,no,no,yes,no,no,yes
no,no,no,no,no,no,no,no
no,no,no,no,no,no,no,no
no,no,yes,no,no,no,yes,no
no,no,no,no,no,no,no,no
......@@ -75,7 +75,7 @@ private:
} // m_extend_and_insert__
void m_max_pa_size__() {
if ((this->pa_size < 1) || (this->pa_size > n_ - 1)) this->pa_size = n_;
if ((this->pa_size < 1) || (this->pa_size > n_ - 1)) this->pa_size = n_ - 1;
for (int i = 0; i <= this->pa_size; ++i) md_[i] = X_;
Log.info() << "maximal number of parents limited to " << this->pa_size << std::endl;
} // m_max_pa_size_
......
......@@ -48,6 +48,9 @@ public:
Log.debug() << "MDLEngine ready!" << std::endl;
} // MDLEngine
MDLEngine(CountingQueryEngine& cqe, MDL& mdl, int ps = -1)
: MDLEngine(cqe, mdl, Priors<N>(), ps) { }
int n() const { return n_; }
int m() const { return m_; }
......@@ -79,7 +82,7 @@ public:
++idx;
}
// if we EXTEND than MPS may not be able to answer all queries d()
// if we EXTEND then MPS may not be able to answer all queries d()
// we fix it by inserting emptyset with infinity
if (!is_emptyset(ch_ext) && is_emptyset(pa)) {
for (int xi = 0; xi < n_; ++xi) {
......
......@@ -252,8 +252,7 @@ public:
} // check_sanitize
private:
const double INF = std::numeric_limits<double>::max();
const MPSNode temp_mps_{INF, set_empty<set_type>()};
const MPSNode temp_mps_{SABNA_DBL_INFTY, set_empty<set_type>()};
std::vector<std::vector<MPSNode>> mps_list_;
int n_ = -1;
......
......@@ -27,9 +27,18 @@
template <int N>
class Priors {
public:
using set_type = uint_type<N>;
enum Result { IGNORE, EXTEND, ADD_AND_EXTEND };
using set_type = uint_type<N>;
enum prior_class { EDGE, NOEDGE };
struct prior_type {
int s;
int t;
prior_class R;
}; // struct prior_type
bool empty() const { return allowed_.empty(); }
......@@ -90,9 +99,28 @@ public:
} // check
private:
enum prior_class { EDGE, NOEDGE };
// the list is not reordered (i.e., CQE ordering is not considered)
// this is original list taken from input file
const std::vector<prior_type>& list() const { return priors_; }
std::vector<prior_type> list_compacted() const {
std::vector<prior_type> L;
int n = allowed_.size();
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (i != j) {
if (in_set(required_[i], j)) L.push_back(prior_type{j, i, EDGE});
else if (!in_set(allowed_[i], j)) L.push_back(prior_type{j, i, NOEDGE});
}
}
}
return L;
} // list_compacted
private:
static prior_class m_str2prior__(const std::string& s) {
if (s == "-") return EDGE;
return NOEDGE;
......@@ -103,13 +131,6 @@ private:
return "!";
} // m_prior2str__
struct prior_type {
int s;
int t;
prior_class R;
}; // struct prior_type
friend std::ostream& operator<<(std::ostream& os, const prior_type& pt) {
os << pt.s << " " << Priors::m_prior2str__(pt.R) << " " << pt.t;
return os;
......
......@@ -52,6 +52,13 @@ std::tuple<bool, int, int> read_csv(std::ifstream& f, std::vector<T>& data) {
else if (m != l) return std::make_tuple(false, -1, -1);
} // for it
// sanity check
for (int i = 0; i < n; ++i) {
auto mm = std::minmax_element(data.data() + i * m, data.data() + (i + 1) * m);
int d = *mm.second - *mm.first;
if ((d < 1) || (d > 254)) return std::make_tuple(false, -1, -1);
}
return std::make_tuple(true, n, m);
} // read_csv
......
/***
* $Id$
**
* File: eval.hpp
* Created: Sep 12, 2018
*
* Author: Jaroslaw Zola <jaroslaw.zola@hush.com>
* Copyright (c) 2018 SCoRe Group http://www.score-group.org/
* Distributed under the MIT License.
* See accompanying file LICENSE.
*/
#ifndef EVAL_HPP
#define EVAL_HPP
#include <unordered_map>
#include <vector>
#include "bit_util.hpp"
template <int N> class Eval {
public:
explicit Eval(int n) : cache_(n) { }
template <typename ScoreFunction, typename CQE>
double operator()(const std::vector<uint_type<N>>& net, CQE& cqe, ScoreFunction F) {
double S = 0.0;
double s = 0.0;
for (int xi = 0; xi < net.size(); ++xi) {
auto pos = cache_[xi].find(net[xi]);
if (pos != cache_[xi].end()) s = pos->second;
else {
cqe.apply(xi, net[xi], F);
s = std::get<0>(F.score());
cache_[xi].insert({net[xi], s});
}
S += s;
}
return S;
} // operator()
private:
std::vector<std::unordered_map<uint_type<N>, double, uint_hash>> cache_;
}; // class Eval
#endif // EVAL_HPP
This diff is collapsed.
......@@ -37,7 +37,7 @@ namespace jaz {
else if (sz < (1024 * 1024)) {
ss << std::setprecision(4) << (static_cast<float>(sz) / 1024) << "KB";
} else if (sz < (1024 * 1024 * 1024)) {
ss << std::setprecision(4) << (static_cast<float>(sz) / (1024 * 1024)) << "MB";
ss << std::setprecision(4) << (static_cast<float>(sz) / (1024 * 1024)) << "MB";
} else {
ss << std::setprecision(4) << (static_cast<float>(sz) / (1024 * 1024 * 1024)) << "GB";
}
......
......@@ -48,6 +48,42 @@ namespace jaz {
} // split
/** Function: join
* Merges a sequence of strings into a single string.
*
* Parameters:
* pat - Separator string.
* first - Beginning of the sequence to join.
* last - End of the sequence to join.
* init - Prefix to add to the sequence.
*/
template <typename Iter, typename charT, typename traits, typename Alloc>
std::basic_string<charT, traits, Alloc>
join(const charT* pat, Iter first, Iter last, const std::basic_string<charT, traits, Alloc>& init) {
std::basic_string<charT, traits, Alloc> s(init);
if (s.empty() == true) s = *(first++);
for (; first != last; ++first) s += std::string(pat) + std::basic_string<charT, traits, Alloc>(*first);
return s;
} // join
/** Function: join
* Merges a sequence of strings into a single string.
*
* Parameters:
* pat - Separator string.
* first - Beginning of the sequence to join.
* last - End of the sequence to join.
* init - Prefix to add to the sequence.
*/
template <typename Iter, typename charT, typename traits, typename Alloc>
std::basic_string<charT, traits, Alloc>
join(const std::basic_string<charT, traits, Alloc>& pat, Iter first, Iter last, const std::basic_string<charT, traits, Alloc>& init = "") {
std::basic_string<charT, traits, Alloc> s(init);
if (s.empty() == true) s = *(first++);
for (; first != last; ++first) s += pat + std::basic_string<charT, traits, Alloc>(*first);
return s;
} // join
/** Function: join
* Merges a sequence of strings into a single string.
*
......@@ -66,6 +102,12 @@ namespace jaz {
return s;
} // join
/** Function: join
*/
template <typename Iter> inline std::string join(const char* pat, Iter first, Iter last) {
return join(pat, first, last, std::string(""));
} // join
/** Function: join
*/
template <typename Iter> inline std::string join(char pat, Iter first, Iter last) {
......
/***
* $Id$
**
* File: model_fit.hpp
* Created: Aug 04, 2018
*
* Author: Jaroslaw Zola <jaroslaw.zola@hush.com>
* Copyright (c) 2018 SCoRe Group http://www.score-group.org/
* Distributed under the MIT License.
* See accompanying file LICENSE.
*/
#ifndef MODEL_FIT_HPP
#define MODEL_FIT_HPP
#include "graph_util.hpp"
#include "omp.hpp"
class estimator {
public:
void init(int ri, int qi) { state_ = 0.0; }
void finalize(int qi) { }
void operator()(int Nij) { }
void operator()(int Nijk, int Nij) { state_ = static_cast<double>(Nijk) / Nij; }
double state() const { return state_; }
private:
double state_ = 0.0;
}; // class estimator