...
 
Commits (2)
......@@ -134,7 +134,7 @@
(+ (* x x)))))))))
(defn err? [a b] (> (Math/abs (- a b)) 1e-2))
(defn err? [a b] (> (Math/abs (- a b)) 5e-2))
(defn check-gradient [code values]
......@@ -142,12 +142,13 @@
(eval
(finite-difference-grad code))
values)
rev-grad (let [[res bp]
(apply
(eval (apply reverse-diff* (rest code)))
values)]
(bp 1.0))]
(zero? (count (filter true? (map err? num-grad rev-grad))))))
[res bp] (apply
(eval (apply reverse-diff* (rest code)))
values)
rev-grad (bp 1.0)]
#_(prn code "Forward: " res " Grad: " rev-grad)
(and (not (err? res (apply (eval code) values)))
(zero? (count (filter true? (map err? num-grad rev-grad)))))))
......@@ -183,11 +184,33 @@
)))
(deftest foppl-exercise-test
(testing "Testing examples from the exercise."
(binding [*ns* (find-ns 'foppl-compiler.reverse-diff)]
(is (check-gradient '(fn [x] (exp (sin x))) [0]))
(is (check-gradient '(fn [x y] (+ (* x x) (sin x))) [0 10]))
(is (check-gradient '(fn [x] (if (> x 5) (* x x) (+ x 18))) [5.000001]))
(is (check-gradient '(fn [x] (log x)) [0.1]))
(is (check-gradient '(fn [x mu sigma]
(+ (- 0 (/ (* (- x mu) (- x mu))
(* 2 (* sigma sigma))))
(* (- 0 (/ 1 2)) (log (* 2 (* 3.141592653589793 (* sigma sigma)))))))
[10 0 2]))
(is (check-gradient '(fn [x mu sigma] (normpdf x mu sigma)) [10 0 2]))
(is (check-gradient '(fn [x1 x2 x3] (+ (+ (normpdf x1 2 5)
(if (> x2 7)
(normpdf x2 0 1)
(normpdf x2 10 1)))
(normpdf x3 -4 10)))
[2 7.01 5]))
)))