...
 
Commits (3)
......@@ -3,7 +3,7 @@
:url "http://example.com/FIXME"
:license {:name "Eclipse Public License"
:url "http://www.eclipse.org/legal/epl-v10.html"}
:dependencies [[org.clojure/clojure "1.10.0-beta1"]
:dependencies [[org.clojure/clojure "1.10.0-beta3"]
[anglican "1.0.0"]]
:min-lein-version "2.0.0")
......@@ -4,6 +4,7 @@
[foppl-compiler.free-vars :refer [free-vars]]
[foppl-compiler.desugar :refer [desugar]]
[foppl-compiler.primitives]
[foppl-compiler.gensym :refer [*my-gensym*]]
[foppl-compiler.partial-evaluation :refer [fixed-point-simplify]]))
......@@ -170,7 +171,7 @@
(let [[_ e] exp
e (fixed-point-simplify e)
[rho {:keys [V A P Y]} E] (analyze rho phi e)
v (gensym "sample")
v (*my-gensym* "sample")
Z (free-vars E *bound*)
F (list 'sample* (fixed-point-simplify E))]
[rho
......@@ -193,7 +194,7 @@
[rho1 G1 E1] (analyze rho phi e)
[rho2 G2 E2] (analyze rho phi obs)
{:keys [V A P Y]} (merge-graphs G1 G2)
v (gensym "observe")
v (*my-gensym* "observe")
F1 (list 'observe* e obs)
F (fixed-point-simplify (list 'if phi F1 1))
Z (disj (free-vars e *bound*) v)
......
......@@ -19,7 +19,6 @@
c children]
{c #{p}})))
(defn topo-sort [{:keys [V A P]}]
(let [terminals
(loop [terminals []
......@@ -34,7 +33,6 @@
V))))]
terminals))
(defn graph->instructions [[rho G E]]
(conj
(vec
......@@ -44,10 +42,10 @@
(defn eval-instructions [instructions]
(reduce (fn [acc [s v]]
(let [scope (list 'let (vec (apply concat acc))
v)]
(binding [*ns* (find-ns 'foppl-compiler.core)]
(conj acc [s (eval scope)]))))
(binding [*ns* (find-ns 'foppl-compiler.core)]
(conj acc [s ((eval `(fn [{:syms ~(vec (take-nth 2 (apply concat acc)))}]
~v))
(into {} acc))])))
[]
instructions))
......@@ -91,181 +89,18 @@
observes->samples
eval-instructions))
(defn bind-free-variables [G])
(defn code->graph [code]
(->> code
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph))
(defn count-graph [code]
(let [G
(->> code
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
second
)]
(let [[_ G _] (code->graph code)]
[(count-vertices G) (count-edges G)]))
(comment
(->> '((defn observe-data [_ data slope bias]
(let [xn (first data)
yn (second data)
zn (+ (* slope xn) bias)]
(observe (normal zn 1.0) yn)
(rest (rest data))))
(let [slope (sample (normal 0.0 10.0))
bias (sample (normal 0.0 10.0))
data (vector 1.0 2.1 2.0 3.9 3.0 5.3
4.0 7.7 5.0 10.2 6.0 12.9)]
(loop 6 data observe-data slope bias)
(vector slope bias)))
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
graph->instructions
#_observes->samples
#_eval-instructions
)
(->> '((let [a (sample (normal 0 1))
b (sample (normal 0 1))
c (sample (normal (second [(normal 0 1) b]) 1))]
(observe (normal 0 1) 3)
c))
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
graph->instructions
observes->samples
eval-instructions
)
(->> '((let [x (sample (normal 0 1))
y (sample (normal 1 2))]
(observe (normal x (+ 1 5)) 5)))
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
#_graph->instructions
#_eval-instructions
sample-from-prior)
(analyze empty-env true '(loop 10 nil loop-iter a b))
(analyze empty-env true '(if (sample (normal 0 1)) 2 3))
;; TODO improve fixed-point-simplify
(count-graph
'((defn observe-data [_ data slope bias]
(let [xn (first data)
yn (second data)
zn (+ (* slope xn) bias)]
(observe (normal zn 1.0) yn)
(rest (rest data))))
(let [slope (sample (normal 0.0 10.0))
bias (sample (normal 0.0 10.0))
data (vector 1.0 2.1 2.0 3.9 3.0 5.3
4.0 7.7 5.0 10.2 6.0 12.9)]
(loop 6 data observe-data slope bias)
(vector slope bias))))
(count-graph
'((defn hmm-step [t states data trans-dists likes]
(let [z (sample (get trans-dists
(last states)))]
(observe (get likes z)
(get data t))
(append states z)))
(let [data [0.9 0.8 0.7 0.0 -0.025 -5.0 -2.0 -0.1
0.0 0.13 0.45 6 0.2 0.3 -1 -1]
trans-dists [(discrete [0.10 0.50 0.40])
(discrete [0.20 0.20 0.60])
(discrete [0.15 0.15 0.70])]
likes [(normal -1.0 1.0)
(normal 1.0 1.0)
(normal 0.0 1.0)]
states [(sample (discrete [0.33 0.33 0.34]))]]
(loop 16 states hmm-step data trans-dists likes))))
(->> '((defn hmm-step [t states data trans-dists likes]
(let [z (sample (get trans-dists
(last states)))]
(observe (get likes z)
(get data t))
(append states z)))
(let [data [0.9 0.8 0.7 0.0 -0.025 -5.0 -2.0 -0.1
0.0 0.13 0.45 6 0.2 0.3 -1 -1]
trans-dists [(discrete [0.10 0.50 0.40])
(discrete [0.20 0.20 0.60])
(discrete [0.15 0.15 0.70])]
likes [(normal -1.0 1.0)
(normal 1.0 1.0)
(normal 0.0 1.0)]
states [(sample (discrete [0.33 0.33 0.34]))]]
(loop 16 states hmm-step data trans-dists likes)))
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
#_program->graph
#_graph->instructions
#_observes->samples
#_eval-instructions)
(->> '((let [weight-prior (normal 0 1)
W_0 (foreach 10 []
(foreach 1 [] (sample weight-prior)))
W_1 (foreach 10 []
(foreach 10 [] (sample weight-prior)))
W_2 (foreach 1 []
(foreach 10 [] (sample weight-prior)))
b_0 (foreach 10 []
(foreach 1 [] (sample weight-prior)))
b_1 (foreach 10 []
(foreach 1 [] (sample weight-prior)))
b_2 (foreach 1 []
(foreach 1 [] (sample weight-prior)))
x (mat-transpose [[1] [2] [3] [4] [5]])
y [[1] [4] [9] [16] [25]]
h_0 (mat-tanh (mat-add (mat-mul W_0 x)
(mat-repmat b_0 1 5)))
h_1 (mat-tanh (mat-add (mat-mul W_1 h_0)
(mat-repmat b_1 1 5)))
mu (mat-transpose
(mat-tanh (mat-add (mat-mul W_2 h_1)
(mat-repmat b_2 1 5))))]
(foreach 5 [y_r y
mu_r mu]
(foreach 1 [y_rc y_r
mu_rc mu_r]
(observe (normal mu_rc 1) y_rc)))
[W_0 b_0 W_1 b_1]))
#_count-graph
(map partial-evaluation)
(map symbolic-simplify)
(map desugar)
program->graph
graph->instructions
#_observes->samples
eval-instructions
)
)
(ns foppl-compiler.desugar)
(ns foppl-compiler.desugar
(:require [foppl-compiler.gensym :refer [*my-gensym*]]))
(defn dispatch-desugar
[exp]
......@@ -46,7 +47,7 @@
(expand-bindings r)))
((fn expand-body [[f & r]]
(if-not (empty? r)
(list 'let [(gensym "dontcare") (desugar f)]
(list 'let [(*my-gensym* "dontcare") (desugar f)]
(expand-body r))
(desugar f)))
body))))
......@@ -57,7 +58,7 @@
[i c acc f es]
(if (= i c)
acc
(let [new-acc (gensym "acc")]
(let [new-acc (*my-gensym* "acc")]
(list 'let [new-acc (apply list (concat (list f i acc) es))]
(expand-loop (inc i) c new-acc f es)))))
......@@ -65,7 +66,7 @@
(defmethod desugar :loop
[exp]
(let [[_ c acc f & es] exp
as (map (fn [_] (gensym "a")) es)]
as (map (fn [_] (*my-gensym* "a")) es)]
(desugar
(list 'let (vec (interleave as es))
(expand-loop 0 c acc f as)))))
......
(ns foppl-compiler.gensym)
(def ^:dynamic *my-gensym* gensym)
(ns foppl-compiler.metropolis-within-gibbs
(:require [foppl-compiler.core :refer [code->graph sample-from-prior topo-sort]]
[foppl-compiler.analyze :refer [*bound*]]
[foppl-compiler.free-vars :refer [free-vars]]
[anglican.runtime :refer [observe* sample*
normal uniform-continuous sqrt exp
discrete gamma dirichlet flip]]))
(defn build-proposal-map [P]
(->> P
(filter (fn [[k v]] (re-find #"sample\d+" (name k))))
(map (fn [[k v]] [(keyword k) (list 'fn [] (second v))]))
(into {})))
(defn sample->observe [[sym exp]]
(if (re-find #"sample\d+" (name sym))
(list 'observe* (second exp) sym)
exp))
(defn build-log-likelihood-map [P]
(->> P
(map (fn [[k v]]
[(keyword k) (list 'fn [] (sample->observe [k v]))]))
(into {})))
(defn build-log-likelihoods [graph]
`(~'fn [{:syms ~(vec (:V graph))} x#]
((get ~(build-log-likelihood-map (:P graph)) (keyword x#)))))
(defn build-proposals [graph]
`(~'fn [{:syms ~(vec (:V graph))} x#]
((get ~(build-proposal-map (:P graph)) (keyword x#)))))
(fn [{:syms [sample123 sample45]} x]
(get {:sample123 (fn [] (normal sample45 1))} (keyword x)))
(defn accept-markov-blanket [{:keys [proposals likelihoods graph]} x X' X]
(let [q (proposals X x)
q' (proposals X' x)
log-alpha (- (observe* q' (X x))
(observe* q (X' x)))
V_x (conj ((:A graph) x) x)]
(exp
(reduce (fn [log-alpha v]
(+ log-alpha
(likelihoods X' v)
(- (likelihoods X v))))
log-alpha
V_x))))
(defn tol? [a b] (< (Math/abs (- a b)) 1e-10))
(defn gibbs-substep [{:keys [proposals graph] :as params} X x]
(let [X' (assoc X x (sample* (proposals X x)))
alpha-markov (accept-markov-blanket params x X' X)
u (sample* (uniform-continuous 0 1))]
(if (< u alpha-markov)
X'
X)))
(defn gibbs-step [params X]
(reduce (partial gibbs-substep params)
X
(:var-order params)))
(defn metropolis-within-gibbs
([code]
(let [[rho graph return] (code->graph code)
initial-X (->> [rho graph return]
sample-from-prior
butlast
(into {}))
return-fn (eval `(fn [{:syms ~(vec (:V graph))}]
~return))]
(->>
(metropolis-within-gibbs initial-X [rho graph return])
(map return-fn))))
([initial-X [rho graph return]]
(let [var-order (->> (topo-sort graph)
(filter #(re-find #"sample\d+" (name %))))
;; we build fast pre-compiled routines here
proposals (binding [*ns* (find-ns 'foppl-compiler.metropolis-within-gibbs)]
(eval (build-proposals graph)))
likelihoods (binding [*ns* (find-ns 'foppl-compiler.metropolis-within-gibbs)]
(eval (build-log-likelihoods graph)))
params {:likelihoods likelihoods
:proposals proposals
:graph graph
:var-order var-order}]
((fn metropolis-gibbs-internal [X]
(lazy-seq
(let [new-X (gibbs-step params X)]
(cons X
(metropolis-gibbs-internal new-X)))))
initial-X))))
(ns foppl-compiler.reverse-diff
"Reverse mode auto-diff."
(:require [anglican.runtime :refer [observe* normal]]))
(:require [anglican.runtime :refer [observe* normal]]
[foppl-compiler.gensym :refer [*my-gensym*]]))
;; The following code so far follows
;; http://www-bcl.cs.may.ie/~barak/papers/toplas-reverse.pdf
......@@ -30,9 +31,6 @@
(set! *unchecked-math* :warn-on-boxed))
(def ^:dynamic *my-gensym* gensym)
;; some aliasing for formula sanity
(def ** (fn [x p] (Math/pow x p)))
......
This diff is collapsed.
(ns foppl-compiler.reverse-diff-test
(:require [foppl-compiler.reverse-diff :refer :all]
[clojure.test :refer [deftest testing is]]))
[clojure.test :refer [deftest testing is]]
[foppl-compiler.gensym :refer [*my-gensym*]]
[foppl-compiler.test-helpers :refer [local-gensym err?]]))
......@@ -54,12 +56,6 @@
))
(defn local-gensym []
(let [gcounter (atom 0)]
(fn [s]
(symbol (str s (swap! gcounter inc))))))
(deftest tape-expr-test
(testing "Testing tape transformation."
(binding [*my-gensym* (local-gensym)]
......@@ -98,9 +94,6 @@
:backward
[]}))))))
(deftest reverse-diff-test
(testing "Testing reverse symbolic trafo."
(binding [*my-gensym* (local-gensym)]
......@@ -133,10 +126,6 @@
'(fnr [x y]
(+ (* x x)))))))))
(defn err? [a b] (> (Math/abs (- a b)) 5e-2))
(defn check-gradient [code values]
(let [num-grad (apply
(eval
......@@ -146,12 +135,10 @@
(eval (apply reverse-diff* (rest code)))
values)
rev-grad (bp 1.0)]
#_(prn code "Forward: " res " Grad: " rev-grad)
#_(prn #_code "Forward: " res " Grad: " rev-grad)
(and (not (err? res (apply (eval code) values)))
(zero? (count (filter true? (map err? num-grad rev-grad)))))))
(deftest foppl-examples-test
(testing "Testing examples from the lecture."
(binding [*ns* (find-ns 'foppl-compiler.reverse-diff)]
......@@ -183,7 +170,6 @@
)))
(deftest foppl-exercise-test
(testing "Testing examples from the exercise."
(binding [*ns* (find-ns 'foppl-compiler.reverse-diff)]
......
(ns foppl-compiler.test-helpers)
(defn local-gensym []
(let [gcounter (atom 0)]
(fn [s]
(symbol (str s (swap! gcounter inc))))))
(defn err? [a b] (> (Math/abs (- a b)) 5e-2))