Commit 4d0c0c5c authored by Radford Neal's avatar Radford Neal

implemented gradients for dunif, punif, and qunif, other tweak

parent 7b774e7c
......@@ -149,7 +149,7 @@ cos, sin, tan, acos, asin, atan, atan2 \cr
cosh, sinh, tanh, acosh, asinh, atanh \cr
gamma, lgamma, digamma, trigamma, beta, lbeta \cr
\cr
runif \cr
dunif, punif, qunif, runif \cr
dexp, pexp, qexp, rexp \cr
dgeom, pgeom \cr
dpois \cr
......
......@@ -2880,7 +2880,7 @@ static void Dqcauchy (double p, double location, double scale,
{
if (scale <= 0) {
if (dp) *dp = 0;
if (dlocation) *dlocation = 0;
if (dlocation) *dlocation = 1;
if (dscale) *dscale = 0;
}
else {
......@@ -2959,7 +2959,7 @@ static void Dqnorm (double p, double mu, double sigma,
{
if (sigma <= 0) {
if (dp) *dp = 0;
if (dmu) *dmu = 0;
if (dmu) *dmu = 1;
if (dsigma) *dsigma = 0;
}
else {
......@@ -2975,6 +2975,84 @@ static void Dqnorm (double p, double mu, double sigma,
}
}
static void Ddunif (double x, double a, double b,
double *dx, double *da, double *db,
int give_log, double v)
{
if (dx) *dx = 0;
if (b <= a || x <= a || x >= b) {
if (da) *da = 0;
if (db) *db = 0;
}
else {
double t = 1 / (b-a);
if (give_log) {
if (da) *da = t;
if (db) *db = -t;
}
else {
if (da) *da = t * v;
if (db) *db = -t * v;
}
}
}
static void Dpunif (double q, double a, double b,
double *dq, double *da, double *db,
int lower_tail, int log_p, double v)
{
if (b <= a || q <= a || q >= b) {
if (dq) *dq = 0;
if (da) *da = 0;
if (db) *db = 0;
}
else {
double t = 1 / (b-a);
if (lower_tail) {
if (dq) *dq = t;
if (da) *da = ((q-a)*t - 1) * t;
if (db) *db = (a-q) * t * t;
}
else {
if (dq) *dq = -t;
if (da) *da = ((a-q)*t + 1) * t;
if (db) *db = (q-a) * t * t;
}
if (log_p) {
double expv = exp(-v);
if (dq) *dq *= expv;
if (da) *da *= expv;
if (db) *db *= expv;
}
}
}
static void Dqunif (double p, double a, double b,
double *dp, double *da, double *db,
int lower_tail, int log_p, double v)
{
if (b <= a) {
if (dp) *dp = 0;
}
else {
if (dp) *dp = lower_tail ? b-a : a-b;
}
if (log_p) {
if (da) *da = lower_tail ? -expm1(p) : exp(p);
if (db) *db = lower_tail ? exp(p) : -expm1(p);
if (dp) *dp *= exp(p);
}
else {
if (da) *da = lower_tail ? 1-p : p;
if (db) *db = lower_tail ? p : 1-p;
}
}
static void Ddlogis (double x, double location, double scale,
double *dx, double *dlocation, double *dscale,
int give_log, double v)
......@@ -3039,7 +3117,7 @@ static void Dqlogis (double p, double location, double scale,
{
if (scale <= 0) {
if (dp) *dp = 0;
if (dlocation) *dlocation = 0;
if (dlocation) *dlocation = 1;
if (dscale) *dscale = 0;
}
else {
......@@ -3086,9 +3164,9 @@ static struct { double (*fncall)(); void (*Dcall)(); } math3_table[48] = {
{ dnorm, Ddnorm },
{ pnorm, Dpnorm },
{ qnorm, Dqnorm },
{ dunif, 0 },
{ punif, 0 },
{ qunif, 0 },
{ dunif, Ddunif },
{ punif, Dpunif },
{ qunif, Dqunif },
{ dweibull, 0 },
{ pweibull, 0 },
{ qweibull, 0 },
......
......@@ -178,6 +178,7 @@ x <- 0.32739
x1 <- 0.47718; x2 <- 0.89472; x3 <- 0.67325
y1 <- -0.3721; y2 <- -0.8131; y3 <- 1.22213
z1 <- 11.4319; z2 <- 13.1133; z3 <- 6.68901
w1 <- 0.8389; w2 <- 0.1123; w3 <- 4.68701
i1 <- 3
bindgrads <- function (r1,r2)
......@@ -288,6 +289,17 @@ test3z <- function (fun,...) {
with_gradient (z1,z2,z3) fun(z1,z2,z3,...)))
}
test3w <- function (fun,...) {
print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w1"),
with_gradient (w1) fun(w1,w2,w3,...)))
print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w2"),
with_gradient (w2) fun(w1,w2,w3,...)))
print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w3"),
with_gradient (w3) fun(w1,w2,w3,...)))
print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),c("w1","w2","w3")),
with_gradient (w1,w2,w3) fun(w1,w2,w3,...)))
}
test1(abs)
test1(sqrt)
......@@ -416,6 +428,19 @@ test3y(qnorm,log=TRUE,lower=FALSE)
test2r(rnorm)
test3w(dunif)
test3w(dunif,log=TRUE)
test3w(punif)
test3w(punif,log=TRUE)
test3w(punif,lower=FALSE)
test3w(punif,log=TRUE,lower=FALSE)
test3w(qunif)
test3y(qunif,log=TRUE)
test3w(qunif,lower=FALSE)
test3y(qunif,log=TRUE,lower=FALSE)
test2r(runif)
test2r(rweibull)
......@@ -323,6 +323,7 @@ attr(,"gradient")
> x1 <- 0.47718; x2 <- 0.89472; x3 <- 0.67325
> y1 <- -0.3721; y2 <- -0.8131; y3 <- 1.22213
> z1 <- 11.4319; z2 <- 13.1133; z3 <- 6.68901
> w1 <- 0.8389; w2 <- 0.1123; w3 <- 4.68701
> i1 <- 3
>
> bindgrads <- function (r1,r2)
......@@ -433,6 +434,17 @@ attr(,"gradient")
+ with_gradient (z1,z2,z3) fun(z1,z2,z3,...)))
+ }
>
> test3w <- function (fun,...) {
+ print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w1"),
+ with_gradient (w1) fun(w1,w2,w3,...)))
+ print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w2"),
+ with_gradient (w2) fun(w1,w2,w3,...)))
+ print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),"w3"),
+ with_gradient (w3) fun(w1,w2,w3,...)))
+ print (bindgrads (numericDeriv(quote(fun(w1,w2,w3,...)),c("w1","w2","w3")),
+ with_gradient (w1,w2,w3) fun(w1,w2,w3,...)))
+ }
>
> test1(abs)
[,1] [,2]
r1 0.32739 1
......@@ -1392,6 +1404,139 @@ r2 -1.416619 -2.385397 1 -0.4938256
r1 0.774704 1 0.3325331
r2 0.774704 1 0.3325331
>
> test3w(dunif)
[,1] [,2]
r1 0.2185931 0
r2 0.2185931 0
[,1] [,2]
r1 0.2185931 0.04778294
r2 0.2185931 0.04778294
[,1] [,2]
r1 0.2185931 -0.04778294
r2 0.2185931 -0.04778294
w1 w2 w3
r1 0.2185931 0 0.04778294 -0.04778294
r2 0.2185931 0 0.04778294 -0.04778294
> test3w(dunif,log=TRUE)
[,1] [,2]
r1 -1.520543 0
r2 -1.520543 0
[,1] [,2]
r1 -1.520543 0.2185931
r2 -1.520543 0.2185931
[,1] [,2]
r1 -1.520543 -0.2185931
r2 -1.520543 -0.2185931
w1 w2 w3
r1 -1.520543 0 0.2185931 -0.2185931
r2 -1.520543 0 0.2185931 -0.2185931
>
> test3w(punif)
[,1] [,2]
r1 0.1588297 0.2185931
r2 0.1588297 0.2185931
[,1] [,2]
r1 0.1588297 -0.183874
r2 0.1588297 -0.183874
[,1] [,2]
r1 0.1588297 -0.03471908
r2 0.1588297 -0.03471908
w1 w2 w3
r1 0.1588297 0.2185931 -0.183874 -0.03471908
r2 0.1588297 0.2185931 -0.183874 -0.03471908
> test3w(punif,log=TRUE)
[,1] [,2]
r1 -1.839922 1.376273
r2 -1.839922 1.376273
[,1] [,2]
r1 -1.839922 -1.15768
r2 -1.839922 -1.15768
[,1] [,2]
r1 -1.839922 -0.2185931
r2 -1.839922 -0.2185931
w1 w2 w3
r1 -1.839922 1.376273 -1.15768 -0.2185931
r2 -1.839922 1.376273 -1.15768 -0.2185931
> test3w(punif,lower=FALSE)
[,1] [,2]
r1 0.8411703 -0.2185931
r2 0.8411703 -0.2185931
[,1] [,2]
r1 0.8411703 0.183874
r2 0.8411703 0.183874
[,1] [,2]
r1 0.8411703 0.03471908
r2 0.8411703 0.03471908
w1 w2 w3
r1 0.8411703 -0.2185931 0.183874 0.03471908
r2 0.8411703 -0.2185931 0.183874 0.03471908
> test3w(punif,log=TRUE,lower=FALSE)
[,1] [,2]
r1 -0.1729612 -0.2598678
r2 -0.1729612 -0.2598678
[,1] [,2]
r1 -0.1729612 0.2185930
r2 -0.1729612 0.2185931
[,1] [,2]
r1 -0.1729612 0.04127474
r2 -0.1729612 0.04127474
w1 w2 w3
r1 -0.1729612 -0.2598678 0.2185930 0.04127474
r2 -0.1729612 -0.2598678 0.2185931 0.04127474
>
> test3w(qunif)
[,1] [,2]
r1 3.950024 4.57471
r2 3.950024 4.57471
[,1] [,2]
r1 3.950024 0.1611003
r2 3.950024 0.1611000
[,1] [,2]
r1 3.950024 0.8389
r2 3.950024 0.8389
w1 w2 w3
r1 3.950024 4.57471 0.1611003 0.8389
r2 3.950024 4.57471 0.1611000 0.8389
> test3y(qunif,log=TRUE)
[,1] [,2]
r1 0.5897541 1.402854
r2 0.5897541 1.402854
[,1] [,2]
r1 0.5897541 0.3107147
r2 0.5897541 0.3107147
[,1] [,2]
r1 0.5897541 0.6892853
r2 0.5897541 0.6892853
y1 y2 y3
r1 0.5897541 1.402854 0.3107147 0.6892853
r2 0.5897541 1.402854 0.3107147 0.6892853
> test3w(qunif,lower=FALSE)
[,1] [,2]
r1 0.8492858 -4.57471
r2 0.8492858 -4.57471
[,1] [,2]
r1 0.8492858 0.8389001
r2 0.8492858 0.8389000
[,1] [,2]
r1 0.8492858 0.1611
r2 0.8492858 0.1611
w1 w2 w3
r1 0.8492858 -4.57471 0.8389001 0.1611
r2 0.8492858 -4.57471 0.8389000 0.1611
> test3y(qunif,log=TRUE,lower=FALSE)
[,1] [,2]
r1 -0.1807241 -1.402854
r2 -0.1807241 -1.402854
[,1] [,2]
r1 -0.1807241 0.6892853
r2 -0.1807241 0.6892853
[,1] [,2]
r1 -0.1807241 0.3107147
r2 -0.1807241 0.3107147
y1 y2 y3
r1 -0.1807241 -1.402854 0.6892853 0.3107147
r2 -0.1807241 -1.402854 0.6892853 0.3107147
>
> test2r(runif)
x1 x2
r1 0.7403373 0.3697434 0.6302566
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment