Commit dd2401c7 authored by Radford Neal's avatar Radford Neal

bug fix / cleanup

parent ad7b256b
......@@ -34,9 +34,11 @@
/* Get gradient identified by env from the gradients in R_gradient, which
is protected for the duration of this function. */
is protected for the duration of this function. The gradient is return
as a named vector list for multiple variables, or a single gradient if
there is only one gradient variable for this environment. */
static inline SEXP get_gradient (SEXP env)
static SEXP get_gradient (SEXP env)
{
if (! (R_variant_result & VARIANT_GRADIENT_FLAG))
R_gradient = R_NilValue;
......@@ -58,20 +60,24 @@ static inline SEXP get_gradient (SEXP env)
for (p = R_gradient; p != R_NilValue; p = CDR(p)) {
if (TAG(p) == env) {
int ix = GRADINDEX(p);
SET_NAMEDCNT_MAX(CAR(p)); /* may be able to be less drastic */
if (nv == 1) {
if (ix != 1 || r != R_NilValue) abort();
r = CAR(p);
break;
/*break;*/ /* could stop, but continue for debug error check */
}
else {
if (ix<1 || ix>nv || VECTOR_ELT(r,ix-1) != R_NilValue) abort();
SET_VECTOR_ELT (r, ix-1, CAR(p));
}
if (GRADINDEX(p) < 1 || GRADINDEX(p) > nv) abort();
SET_VECTOR_ELT (r, GRADINDEX(p)-1, CAR(p));
}
}
if (r == R_NilValue) {
r = ScalarRealMaybeConst(0.0);
if (nv == 1) {
if (r == R_NilValue) r = ScalarRealMaybeConst(0.0);
}
else if (nv > 1) {
else {
int i;
for (i = 0; i < nv; i++) {
if (VECTOR_ELT(r,i) == R_NilValue)
......@@ -85,7 +91,8 @@ static inline SEXP get_gradient (SEXP env)
}
/* Get gradients excluding those for xenv from those in R_gradient, which
is protected for the duration of this function. */
is protected for the duration of this function. The gradients are
returned as a pairlist. */
static inline SEXP get_other_gradients (SEXP xenv)
{
......@@ -248,21 +255,25 @@ static SEXP do_gradient (SEXP call, SEXP op, SEXP args, SEXP env, int variant)
SET_GRADVARS(newenv,gv);
/* Evaluate body. */
SEXP result = evalv (CAR(p), newenv,
VARIANT_GRADIENT | (variant & VARIANT_PENDING_OK));
PROTECT(result);
PROTECT_INDEX rix;
PROTECT_WITH_INDEX(result,&rix);
int res_has_grad = R_variant_result & VARIANT_GRADIENT_FLAG;
/* Attach gradient, and propage gradients backwards with the chain rule. */
SEXP result_grad = get_gradient (newenv);
PROTECT(result_grad);
R_variant_result = 0;
/* For with_gradient, attach gradient attribute. */
if (PRIMVAL(op) == 0 /* with_gradient */ && TYPEOF(result) == REALSXP
&& LENGTH(result) == 1 /* for now */) {
if (NAMEDCNT_GT_0(result))
result = duplicate(result);
REPROTECT (result = duplicate(result), rix);
if (result_grad == R_NilValue)
setAttrib (result, R_GradientSymbol, ScalarRealMaybeConst(0.0));
else {
......@@ -271,13 +282,16 @@ static SEXP do_gradient (SEXP call, SEXP op, SEXP args, SEXP env, int variant)
}
}
if ((R_variant_result & VARIANT_GRADIENT_FLAG)
&& (variant & VARIANT_GRADIENT)) {
/* Propage gradients backwards with the chain rule. */
R_variant_result = 0;
if (res_has_grad && (variant & VARIANT_GRADIENT)) {
SEXP other_grads = get_other_gradients (newenv);
PROTECT(other_grads);
if (result_grad != R_NilValue) {
PROTECT(other_grads);
for (i = 0; i < nv; i++) {
SEXP g = vargrad[i];
SEXP r = nv > 1 ? VECTOR_ELT(result_grad,i) : result_grad;
......@@ -286,6 +300,7 @@ static SEXP do_gradient (SEXP call, SEXP op, SEXP args, SEXP env, int variant)
add_scaled_gradients (other_grads, g, *REAL(r));
}
}
UNPROTECT(1);
}
if (other_grads != R_NilValue) {
......@@ -293,7 +308,6 @@ static SEXP do_gradient (SEXP call, SEXP op, SEXP args, SEXP env, int variant)
R_variant_result = VARIANT_GRADIENT_FLAG;
}
UNPROTECT(1);
}
UNPROTECT(4+nv);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment