Commit c45bd258 authored by Christian Weilbach's avatar Christian Weilbach

Implement reverse-mode source transformation a la Stalingrad.

parent d98b2e5b
Pipeline #32466673 failed with stages
in 12 minutes and 59 seconds
(ns foppl-compiler.reverse-diff
"Reverse mode auto-diff."
(:require [anglican.runtime :refer [observe* normal]]))
;; The following code so far follows
;; http://www-bcl.cs.may.ie/~barak/papers/toplas-reverse.pdf
;; and Griewank A. Evaluating derivatives. 2008.
;; Proposed roadmap
;; 1. generalization
;; + function composition (boundary type)
;; + arbitrary Anglican style nested values
;; + external primitive functions
;; + integrate into CPS trafo of Anglican
;; 3. implement tape version through operator overloading
;; following diffsharp
;; 3. linear algebra support
;; + extend to core.matrix
;; + support simple deep learning style composition
;; 4. performance optimizations
(set! *warn-on-reflection* true)
(comment
(set! *unchecked-math* :warn-on-boxed))
(def ^:dynamic *my-gensym* gensym)
;; some aliasing for formula sanity
(def ** (fn [x p] (Math/pow x p)))
(def sqrt (fn [x] (Math/sqrt x)))
(def log (fn [x] (Math/log x)))
(def exp (fn [x] (Math/exp x)))
(def pow (fn [x p] (Math/pow x p)))
(def sin (fn [x] (Math/sin x)))
(def cos (fn [x] (Math/cos x)))
(defn normpdf [x mu sigma]
(observe* (normal mu sigma) x)
#_(+ (- 0 (/ (* (- x mu) (- x mu))
(* 2 (* sigma sigma))))
(* (- 0 (/ 1 2)) (log (* 2 (* 3.141592653589793 (* sigma sigma)))))))
(defn term? [exp]
(or (number? exp)
(symbol? exp)))
(defn dispatch-exp [exp p]
(assert (seq? exp) "All differentiation happens on arithmetic expressions.")
(assert (zero? (count (filter seq? exp))) "Differentiation works on flat (not-nested) expressions only.")
(keyword (first exp)))
;; derivative definitions
(defmulti partial-deriv dispatch-exp)
(defmethod partial-deriv :+ [[_ & args] p]
(seq (into '[+]
(reduce (fn [nargs a]
(if (= a p)
(conj nargs 1)
nargs))
[]
args))))
(defmethod partial-deriv :- [[_ & args] p]
(seq (into '[-]
(reduce (fn [nargs a]
(if (= a p)
(conj nargs 1)
(conj nargs 0)))
[]
args))))
(defmethod partial-deriv :* [[_ & args] p]
(let [pn (count (filter #(= % p) args))]
(seq (into ['* pn (list 'pow p (dec pn))]
(filter #(not= % p) args)))))
(defmethod partial-deriv :/ [[_ & [a b]] p]
;; TODO support any arity
(if (= a p)
(if (= b p)
0
(list '/ 1 b))
(if (= b p)
(list '- (list '* a (list 'pow b -2)))
0)))
(defmethod partial-deriv :sin [[_ a] p]
(if (= a p)
(list 'cos a)
0))
(defmethod partial-deriv :cos [[_ a] p]
(if (= a p)
(list 'sin a)
0))
(defmethod partial-deriv :exp [[_ a] p]
(if (= a p)
(list 'exp a)
0))
(defmethod partial-deriv :log [[_ a] p]
(if (= a p)
(list '/ 1 a)
0))
(defmethod partial-deriv :pow [[_ & [base expo]] p]
(if (= base p)
(if (= expo p)
;; TODO p^p only defined for p > 0
(list '* (list '+ 1 (list 'log p))
(list 'pow p p))
(list '* expo (list 'pow p (list 'dec expo))))
(if (= expo p)
(list '* (list 'log base) (list 'pow base p))
0)))
(defmethod partial-deriv :normpdf [[_ x mu sigma] p]
(cond (= x p)
(list '*
(list '- (list '/ 1 (list '* sigma sigma)))
(list '- x mu))
(= mu p)
(list '*
(list '- (list '/ 1 (list '* sigma sigma)))
(list '- mu x))
(= sigma p)
(list '-
(list '*
(list '/ 1 (list '* sigma sigma sigma))
(list 'pow (list '- x mu) 2))
(list '/ 1 sigma))
:else
0))
(def empty-tape {:forward [] :backward []})
(defn adjoint-sym [sym]
(symbol (str sym "_")))
(defn tape-expr
"The tape returns flat variable assignments for forward and backward pass.
It allows multiple assignments following Griewank p. 125 or chapter 3.2.
Once lambdas are supported this should be A-normal form of code."
[bound sym exp tape]
(cond (and (seq? exp)
(= (first exp) 'if))
(let [[_ condition then else] exp
{:keys [forward backward]} tape
then-s (*my-gensym* "then")
else-s (*my-gensym* "else")
{then-forward :forward
then-backward :backward} (tape-expr bound then-s then empty-tape)
{else-forward :forward
else-backward :backward} (tape-expr bound else-s else empty-tape)
if-forward (concat (map (fn [[s e]] [s (list 'if condition e 0)])
then-forward)
(map (fn [[s e]] [s (list 'if-not condition e 0)])
else-forward))
if-backward (concat (map (fn [[s e]] [s (list 'if condition e s)])
then-backward)
(map (fn [[s e]] [s (list 'if-not condition e s)])
else-backward))]
{:forward (vec (concat forward
if-forward
[[sym (list 'if condition then-s else-s)]]))
:backward (vec (concat backward
if-backward
[[(adjoint-sym then-s) (adjoint-sym sym)]
[(adjoint-sym else-s) (adjoint-sym sym)]]))})
:else
(let [[f & args] exp
new-*my-gensym*s (atom [])
nargs (map (fn [a] (if (term? a) a
(let [ng (*my-gensym* "v")]
(swap! new-*my-gensym*s conj ng)
ng))) args)
nexp (conj nargs f)
{:keys [forward backward]}
(reduce (fn [{:keys [forward backward] :as tape} [s a]]
(if (term? a)
tape
(tape-expr bound s a tape)))
tape
(partition 2 (interleave nargs args)))
bound (into bound (map first forward))]
{:forward
(conj forward
[sym nexp])
:backward
(vec (concat backward
;; reverse chain-rule (backpropagator)
(for [a (distinct nargs)
:when (bound a) ;; we only do backward on our vars
:let [a-back (adjoint-sym a)]]
[a-back
(list '+ a-back
(list '*
(adjoint-sym sym)
(partial-deriv nexp a)))])
;; initialize new variables with 0
(map (fn [a]
[(adjoint-sym a) 0])
@new-*my-gensym*s)))})))
(defn sensitivities [args]
(mapv (fn [a] (symbol (str (name a) "_")))
args))
(defn init-sensitivities [args]
(->> (interleave (sensitivities args) (repeat 0))
(partition 2)
(apply concat)))
(defn reverse-diff*
"Splice the tape "
[args code]
(let [{:keys [forward backward]} (tape-expr (into #{} args)
(*my-gensym* "v")
code
{:forward
[]
:backward
[]})
ret (first (last forward))]
(list 'fn args
(list 'let (vec (apply concat forward))
[ret
(list 'fn [(symbol (str ret "_"))]
(list 'let
(vec
(concat (init-sensitivities args)
(apply concat
(reverse
backward))))
(sensitivities args)))]))))
(defmacro fnr [args code]
`~(reverse-diff* args code))
(defn addd [exprl i d]
(if (= i 0)
(reduce conj [`(~'+ ~d ~(first exprl))] (subvec exprl 1))
(reduce conj (subvec exprl 0 i)
(reduce conj [`(~'+ ~d ~(get exprl i))] (subvec exprl (+ i 1))))))
(defn finite-difference-expr [expr args i d]
`(~'/ (~'- (~expr ~@(addd args i d)) (~expr ~@args)) ~d))
(defn finite-difference-grad [expr]
(let [[op args body] expr
d (*my-gensym*)
fdes (mapv #(finite-difference-expr expr args % d) (range (count args)))
argsyms (map (fn [x] `(~'quote ~x)) args)]
`(~'fn [~@args]
(~'let [~d 0.001]
~fdes
#_~(zipmap argsyms fdes)))))
(ns foppl-compiler.reverse-diff-test
(:require [foppl-compiler.reverse-diff :refer :all]
[clojure.test :refer [deftest testing is]]))
(deftest partial-deriv-tests
(testing "Testing partial derivatives."
(is (= '(- 0 0 1)
(partial-deriv '(- 1 2 x) 'x)))
(is (= '(- 1 0)
(partial-deriv '(- x 1) 'x)))
(is (= '(* 1 (pow x 0) 1 2)
(partial-deriv '(* 1 2 x) 'x)))
(is (= '(- (* 2 (pow x -2)))
(partial-deriv '(/ 2 x) 'x)))
(is (= '(/ 1 2)
(partial-deriv '(/ x 2) 'x)))
(is (= 0
(partial-deriv '(/ 1 2) 'x)))
(is (= 0
(partial-deriv '(/ x x) 'x)))
(is (= '(/ 1 y)
(partial-deriv '(/ x y) 'x)))
(is (= '(* 2 (pow x (dec 2)))
(partial-deriv '(pow x 2) 'x)))
(is (= '(* (log 2) (pow 2 x))
(partial-deriv '(pow 2 x) 'x)))
(is (= '(* (+ 1 (log x)) (pow x x))
(partial-deriv '(pow x x) 'x)))
(is (= '(* (- (/ 1 (* sigma sigma))) (- x mu))
(partial-deriv '(normpdf x mu sigma) 'x)))
(is (= '(* (- (/ 1 (* sigma sigma))) (- mu x))
(partial-deriv '(normpdf x mu sigma) 'mu)))
(is (= '(- (* (/ 1 (* sigma sigma sigma)) (pow (- x mu) 2)) (/ 1 sigma))
(partial-deriv '(normpdf x mu sigma) 'sigma)))
))
(defn local-gensym []
(let [gcounter (atom 0)]
(fn [s]
(symbol (str s (swap! gcounter inc))))))
(deftest tape-expr-test
(testing "Testing tape transformation."
(binding [*my-gensym* (local-gensym)]
(is (= {:forward '[[v4 (if (> x x) (+ 5 x) 0)]
[then2 (if (> x x) (* 2 v4) 0)]
[else3 (if-not (> x x) (+ x 5) 0)]
[v1 (if (> x x) then2 else3)]],
:backward '[[x_ (if (> x x) (+ x_ (* v4_ (+ 1))) x_)]
[v4_ (if (> x x) (+ v4_ (* then2_ (* 1 (pow v4 0) 2))) v4_)]
[v4_ (if (> x x) 0 v4_)]
[x_ (if-not (> x x) (+ x_ (* else3_ (+ 1))) x_)]
[then2_ v1_]
[else3_ v1_]]}
(tape-expr #{'x}
(*my-gensym* "v")
'(if (> x x) (* 2 (+ 5 x)) (+ x 5)) ;; TODO support non-seqs
{:forward
[]
:backward
[]}))))
(binding [*my-gensym* (local-gensym)]
(is (= {:forward '[[v2 (- 3 x)]
[v3 (x y)]
[v1 (* v2 v3)]],
:backward '[[x_ (+ x_ (* v2_ (- 0 1)))]
[v2_ (+ v2_ (* v1_ (* 1 (pow v2 0) v3)))]
[v3_ (+ v3_ (* v1_ (* 1 (pow v3 0) v2)))]
[v2_ 0]
[v3_ 0]]}
(tape-expr #{'x}
(*my-gensym* "v")
'(* (- 3 x) ( x y))
{:forward
[]
:backward
[]}))))))
(deftest reverse-diff-test
(testing "Testing reverse symbolic trafo."
(binding [*my-gensym* (local-gensym)]
(is (= '(fn [x]
(let [v2 (+ x x) v1 (* v2 y)]
[v1 (fn [v1_]
(let [x_ 0
v2_ 0
v2_ (+ v2_ (* v1_ (* 1 (pow v2 0) y)))
x_ (+ x_ (* v2_ (+ 1 1)))]
[x_]))]))
(reverse-diff* '[x] '(* (+ x x) y)))))))
(deftest fnr-macro-test
(testing "Macro expension for reverse-diff macro."
(binding [*my-gensym* (local-gensym)]
(is (= '(fn [x y]
(let [v2 (* x x) v1 (+ v2)]
[v1 (fn [v1_]
(let [x_ 0
y_ 0
v2_ 0
v2_ (+ v2_ (* v1_ (+ 1)))
x_ (+ x_ (* v2_ (* 2 (pow x 1))))]
[x_ y_]))]))
(macroexpand-1
'(fnr [x y]
(+ (* x x)))))))))
(defn err? [a b] (> (Math/abs (- a b)) 1e-2))
(defn check-gradient [code values]
(let [num-grad (apply
(eval
(finite-difference-grad code))
values)
rev-grad (let [[res bp]
(apply
(eval (apply reverse-diff* (rest code)))
values)]
(bp 1.0))]
(zero? (count (filter true? (map err? num-grad rev-grad))))))
(deftest foppl-examples-test
(testing "Testing examples from the lecture."
(is (check-gradient '(fn [x] (exp (sin x))) [3.2]))
(is (check-gradient '(fn [x y] (+ (* x x) (sin x))) [5.1 8.7]))
(is (check-gradient '(fn [x] (if (> x 5) (* x x) (+ x 18))) [3]))
(is (check-gradient '(fn [x] (if (> x 5) (* x x) (+ x 18))) [6]))
(is (check-gradient '(fn [x] (log x)) [2.7]))
(is (check-gradient '(fn [x mu sigma]
(+ (- 0 (/ (* (- x mu) (- x mu))
(* 2 (* sigma sigma))))
(* (- 0 (/ 1 2)) (log (* 2 (* 3.141592653589793 (* sigma sigma)))))))
[3.1 -2.5 8]))
(is (check-gradient '(fn [x mu sigma] (normpdf x mu sigma)) [3.1 -2.5 8]))
(is (check-gradient '(fn [x1 x2 x3] (+ (+ (normpdf x1 2 5)
(if (> x2 7)
(normpdf x2 0 1)
(normpdf x2 10 1)))
(normpdf x3 -4 10)))
[1.2 2.1 4]))
))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment