Commit b7117762 authored by Radford Neal's avatar Radford Neal

implement/document/test derivatives for beta and lbeta

parent 2eee17c4
......@@ -147,7 +147,7 @@ with respect all their real-valued arguments (unless noted):
abs, sqrt, expm1, exp, log1p, log2, log10, log (one-argument form only) \cr
cos, sin, tan, acos, asin, atan, atan2 \cr
cosh, sinh, tanh, acosh, asinh, atanh \cr
gamma, lgamma, digamma, trigamma \cr
gamma, lgamma, digamma, trigamma, beta, lbeta \cr
\cr
runif \cr
dexp, rexp \cr
......
......@@ -2466,8 +2466,8 @@ static void Datan2 (double y, double x, double *dy, double *dx, double v)
double r2 = x*x + y*y;
if (r2 == 0) {
if (dy) *dy = NA_REAL;
if (dx) *dx = NA_REAL;
if (dy) *dy = 0;
if (dx) *dx = 0;
}
else {
if (dy) *dy = x / r2;
......@@ -2475,6 +2475,32 @@ static void Datan2 (double y, double x, double *dy, double *dx, double v)
}
}
static void Dlbeta (double a, double b, double *da, double *db, double v)
{
if (a == 0 || b == 0) {
if (da) *da = a == 0 ? R_NegInf : 0;
if (db) *db = b == 0 ? R_NegInf : 0;
}
else {
double diab = digamma(a+b);
if (da) *da = digamma(a) - diab;
if (db) *db = digamma(b) - diab;
}
}
static void Dbeta (double a, double b, double *da, double *db, double v)
{
if (a == 0 || b == 0) {
if (da) *da = a == 0 ? R_NegInf : 0;
if (db) *db = b == 0 ? R_NegInf : 0;
}
else {
double diab = digamma(a+b);
if (da) *da = v * (digamma(a) - diab);
if (db) *db = v * (digamma(b) - diab);
}
}
static void Ddexp (double x, double scale, double *dx, double *dscale,
double v, int give_log)
{
......@@ -2527,8 +2553,8 @@ static double *Bessel_work_array (int n2, double *ap2)
static struct { double (*fncall)(); void (*Dcall)(); } math2_table[31] = {
{ 0, 0 },
{ atan2, Datan2 },
{ lbeta, 0 },
{ beta, 0 },
{ lbeta, Dlbeta },
{ beta, Dbeta },
{ lchoose, 0 },
{ choose, 0 },
{ dchisq, 0 },
......
......@@ -176,6 +176,7 @@ with_gradient (x=2,y=3) fiddler(x,y)
x <- 0.32739
x1 <- 0.47718; x2 <- 0.89472
z1 <- 11.4319; z2 <- 13.1133
i1 <- 3
bindgrads <- function (r1,r2)
......@@ -210,6 +211,22 @@ test2 <- function (fun,...) {
))
}
test2z <- function (fun,...) {
print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),"z1"),
with_gradient (z1) fun(z1,z2,...)))
print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),"z2"),
with_gradient (z2) fun(z1,z2,...)))
print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),c("z1","z2")),
with_gradient (z1,z2) fun(z1,z2,...)))
print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),c("z1","z2")),
{ r <- with_gradient (z1) { s <- with_gradient (z2) fun(z1,z2,...);
g2 <<- attr(s,"gradient"); s }
attr(r,"gradient") <- cbind(g1=attr(r,"gradient"),g2=g2)
r
}
))
}
test2i <- function (fun,...) {
print (bindgrads (numericDeriv(quote(fun(i1,x2,...)),"x2"),
with_gradient (x2) fun(i1,x2,...)))
......@@ -255,6 +272,13 @@ test1(digamma)
test1(trigamma)
test2(atan2)
test2z(atan2)
test2(beta)
test2z(beta)
test2(lbeta)
test2z(lbeta)
test2(dexp)
test2(dexp,log=TRUE)
......
......@@ -321,6 +321,7 @@ attr(,"gradient")
>
> x <- 0.32739
> x1 <- 0.47718; x2 <- 0.89472
> z1 <- 11.4319; z2 <- 13.1133
> i1 <- 3
>
> bindgrads <- function (r1,r2)
......@@ -355,6 +356,22 @@ attr(,"gradient")
+ ))
+ }
>
> test2z <- function (fun,...) {
+ print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),"z1"),
+ with_gradient (z1) fun(z1,z2,...)))
+ print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),"z2"),
+ with_gradient (z2) fun(z1,z2,...)))
+ print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),c("z1","z2")),
+ with_gradient (z1,z2) fun(z1,z2,...)))
+ print (bindgrads (numericDeriv(quote(fun(z1,z2,...)),c("z1","z2")),
+ { r <- with_gradient (z1) { s <- with_gradient (z2) fun(z1,z2,...);
+ g2 <<- attr(s,"gradient"); s }
+ attr(r,"gradient") <- cbind(g1=attr(r,"gradient"),g2=g2)
+ r
+ }
+ ))
+ }
>
> test2i <- function (fun,...) {
+ print (bindgrads (numericDeriv(quote(fun(i1,x2,...)),"x2"),
+ with_gradient (x2) fun(i1,x2,...)))
......@@ -484,6 +501,73 @@ r2 0.4899538 0.8701601 -0.4640815
g1 g2
r1 0.4899538 0.8701601 -0.4640815
r2 0.4899538 0.8701601 -0.4640815
> test2z(atan2)
[,1] [,2]
r1 0.7170028 0.0433287
r2 0.7170028 0.0433287
[,1] [,2]
r1 0.7170028 -0.03777305
r2 0.7170028 -0.03777305
z1 z2
r1 0.7170028 0.0433287 -0.03777305
r2 0.7170028 0.0433287 -0.03777305
g1 g2
r1 0.7170028 0.0433287 -0.03777305
r2 0.7170028 0.0433287 -0.03777305
>
> test2(beta)
[,1] [,2]
r1 2.239741 -4.457335
r2 2.239741 -4.457335
[,1] [,2]
r1 2.239741 -1.510769
r2 2.239741 -1.510769
x1 x2
r1 2.239741 -4.457335 -1.510769
r2 2.239741 -4.457335 -1.510769
g1 g2
r1 2.239741 -4.457335 -1.510769
r2 2.239741 -4.457335 -1.510769
> test2z(beta)
[,1] [,2]
r1 4.434014e-08 -3.493888e-08
r2 4.434014e-08 -3.493888e-08
[,1] [,2]
r1 4.434014e-08 -2.859912e-08
r2 4.434014e-08 -2.859912e-08
z1 z2
r1 4.434014e-08 -3.493888e-08 -2.859912e-08
r2 4.434014e-08 -3.493888e-08 -2.859912e-08
g1 g2
r1 4.434014e-08 -3.493888e-08 -2.859912e-08
r2 4.434014e-08 -3.493888e-08 -2.859912e-08
>
> test2(lbeta)
[,1] [,2]
r1 0.8063602 -1.990112
r2 0.8063602 -1.990112
[,1] [,2]
r1 0.8063602 -0.6745284
r2 0.8063602 -0.6745284
x1 x2
r1 0.8063602 -1.990112 -0.6745284
r2 0.8063602 -1.990112 -0.6745284
g1 g2
r1 0.8063602 -1.990112 -0.6745284
r2 0.8063602 -1.990112 -0.6745284
> test2z(lbeta)
[,1] [,2]
r1 -16.93138 -0.7879742
r2 -16.93138 -0.7879742
[,1] [,2]
r1 -16.93138 -0.644994
r2 -16.93138 -0.644994
z1 z2
r1 -16.93138 -0.7879742 -0.644994
r2 -16.93138 -0.7879742 -0.644994
g1 g2
r1 -16.93138 -0.7879742 -0.644994
r2 -16.93138 -0.7879742 -0.644994
>
> test2(dexp)
[,1] [,2]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment