Commit 19b8ddd7 authored by Jaroslaw Zola's avatar Jaroslaw Zola

Network scoring example extended. Small fix to RadCounter.

parent 4a38125d
#!/usr/bin/env python
__author__ = "Jaroslaw Zola"
__copyright__ = "Copyright (c) 2018 SCoRe Group http://www.score-group.org/"
__license__ = "MIT"
__version__ = "1.0.0"
__maintainer__ = "Jaroslaw Zola"
__email__ = "jaroslaw.zola@hush.com"
__status__ = "Development"
import argparse
import csv
import os
import sys
from bitarray import bitarray
from sabnatk.BVCounter import AIC64
from sabnatk.BVCounter import BDeu64
from sabnatk.BVCounter import MDL64
from sabnatk.BVCounter import AIC256
from sabnatk.BVCounter import BDeu256
from sabnatk.BVCounter import MDL256
def extant_file(fname):
if not os.path.isfile(fname): raise argparse.ArgumentTypeError("file {0} not found".format(fname))
return fname
if __name__ == "__main__":
if len(sys.argv) != 3:
print("usage: bnscore.py data network")
exit(-1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("data", metavar="", help = "input csv file", type = extant_file)
parser.add_argument("network", metavar="", help = "input sif file", type = extant_file)
parser.add_argument("--sep", metavar="", help = "data separator, guessed by default", nargs='?', const = None, type = str)
parser.add_argument("-s", "--score", metavar="", help = "scoring function [aic|bdeu|mdl]", type = str, default = "mdl")
if len(sys.argv)==1:
parser.print_help()
sys.exit(-1)
data = sys.argv[1]
args = parser.parse_args()
data = args.data
net = args.network
sep = args.sep
score = args.score
# read data
X = {}
n = 0
m = 0
T = {}
with open(data, "rt") as csvfile:
if not csv.Sniffer().has_header(csvfile.read(100000)):
print("error: csv header missing")
sys.exit(-1)
csvfile.seek(0)
if not sep:
dialect = csv.Sniffer().sniff(csvfile.read(100000))
csvfile.seek(0)
sep = dialect.delimiter
h = next(csvfile)
h = h.replace("\n", "").replace("\r", "").split(sep)
n = len(h)
if (n > 255):
print("too many variables")
sys.exit(-1)
for xi in range(n):
X[h[xi]] = xi
T[xi] = []
for l in csvfile:
l = l.replace("\n", "").replace("\r", "").split(sep)
for xi in range(n):
T[xi].append(l[xi])
m = m + 1
# transform data
D = []
# read data
with open(data, "rt") as cf:
rd = csv.reader(cf, delimiter = ' ')
for row in rd:
n = n + 1
m = len(row)
D = D + row
for xi in range(n):
t = 0
M = {}
for val in T[xi]:
if val not in M:
D.append(t)
M[val] = t
t = t + 1
else:
D.append(M[val])
D = list(map(int, D))
# init graph
G = []
......@@ -38,25 +105,29 @@ if __name__ == "__main__":
G = G + [u]
# read network
net = sys.argv[2]
with open(net, "rt") as nf:
for l in nf.readlines():
s, _, t = l.rstrip().split(" ")
if (s != t):
G[int(t)][int(s)] = 1;
G[X[t]][X[s]] = 1;
# init score
s = MDL64()
s = MDL256()
if (score == "aic"):
s = AIC256()
elif (score == "bdeu"):
s = BDeu256()
s.init(n, m, D)
# get score
score = 0.0
S = 0.0
for i in range(n):
xi = bitarray(n, endian = "little")
xi.setall(0)
xi[i] = 1
score = score + s.score(xi.tobytes(), G[i].tobytes())[0][0]
S = S + s.score(xi.tobytes(), G[i].tobytes())[0][0]
print(score)
print(score + " score: " + str(S) + "\n")
......@@ -17,4 +17,4 @@ D = rbn(N, m)
write.bif("randbn.bif", N)
write.dot("randbn.dot", N)
write.table(t(D), "randbn.csv", row.names = FALSE, col.names = FALSE, quote = FALSE)
\ No newline at end of file
write.table(t(D), "randbn.csv", row.names = FALSE, col.names = FALSE, quote = FALSE)
......@@ -63,14 +63,14 @@ public:
template <typename score_functor>
void apply(const std::vector<int>& xi_vect, const set_type& pa, std::vector<score_functor>& F) const {
set_type set_xi = set_empty<N>();
set_type set_xi = set_empty<set_type>();
for (auto xi : xi_vect) set_xi = set_add(set_xi, xi);
apply(set_xi, pa, F);
} // apply
template <typename score_functor>
void apply(int xi, const set_type& pa, score_functor& F) const {
set_type set_xi = set_empty<N>();
set_type set_xi = set_empty<set_type>();
set_xi = set_add(set_xi, xi);
std::vector<score_functor> F_vect{F};
apply(set_xi, pa, F_vect);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment