Commit 7b774e7c authored by Radford Neal's avatar Radford Neal

add gradient tracing

parent 59cf0f6e
......@@ -1135,6 +1135,15 @@ extern0 Rboolean R_mat_mult_with_BLAS [R_mat_mult_with_BLAS_len]
#endif
;
/* Tracing of gradient computations. */
extern0 int R_gradient_trace INI_as(0); /* trace gradient tracking? */
extern void Rf_gradient_trace(SEXP); /* function to trace R_gradient */
#define GRADIENT_TRACE(call) \
do { if (R_gradient_trace) Rf_gradient_trace(call); } while (0)
/* Can BLAS routines be done in helper threads? */
extern0 Rboolean R_BLAS_in_helpers
......
......@@ -1580,6 +1580,7 @@ SEXP attribute_hidden R_binary (SEXP call, int opcode, SEXP x, SEXP y,
R_variant_result = VARIANT_GRADIENT_FLAG;
break;
}
GRADIENT_TRACE(call);
}
UNPROTECT(nprotect);
......@@ -1709,6 +1710,7 @@ SEXP attribute_hidden R_unary (SEXP call, int opcode, SEXP s1, int obj1,
double d = opcode == MINUSOP ? -1 : 1;
R_gradient = copy_scaled_gradients(grad1,d);
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
return ans;
......@@ -2262,6 +2264,7 @@ SEXP attribute_hidden do_math1 (SEXP call, SEXP op, SEXP args, SEXP env,
double d = R_math1_deriv_table[opcode] (opr, res);
R_gradient = copy_scaled_gradients(grad,d);
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
}
......@@ -2450,6 +2453,7 @@ SEXP do_abs(SEXP call, SEXP op, SEXP args, SEXP env, int variant)
if (TYPEOF(s) == REALSXP && LENGTH(s) == 1 && !ISNAN(*REAL(s))) {
R_gradient = copy_scaled_gradients (g, sign(opr));
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
}
UNPROTECT(3); /* g, args, s */
......@@ -2772,6 +2776,7 @@ SEXP do_math2 (SEXP call, SEXP op, SEXP args, SEXP env)
(copy_scaled_gradients (g1, grad1), g2, grad2);
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
if (naflag) NaN_warning();
......@@ -3218,6 +3223,7 @@ SEXP do_math3 (SEXP call, SEXP op, SEXP args, SEXP env)
}
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
if (naflag) NaN_warning();
......
......@@ -30,6 +30,7 @@
#endif
#define USE_FAST_PROTECT_MACROS
#define R_USE_SIGNALS
#include "Defn.h"
......@@ -485,6 +486,43 @@ static SEXP do_all_gradients_of (SEXP call, SEXP op, SEXP args, SEXP env,
}
/* Trace tracking of the gradients in R_gradient. */
attribute_hidden void Rf_gradient_trace (SEXP call)
{
REprintf("GRADIENT TRACE: ");
SEXP p;
for (p = R_gradient; p != R_NilValue; p = CDR(p)) {
int ix = GRADINDEX(p);
SEXP env = TAG(p);
SEXP gv = GRADVARS(env);
if (gv==R_NoObject || TYPEOF(gv)!=VECSXP || ix < 1 || ix > LENGTH(gv))
REprintf("?? ");
else
REprintf("%s ",CHAR(PRINTNAME(VECTOR_ELT(gv,ix-1))));
}
RCNTXT *cptr;
REprintf (": ");
if (call != R_NilValue && TYPEOF(CAR(call)) == SYMSXP)
REprintf ("\"%s\" ", CHAR(PRINTNAME(CAR(call))));
for (cptr = R_GlobalContext; cptr; cptr = cptr->nextcontext) {
if ((cptr->callflag & (CTXT_FUNCTION | CTXT_BUILTIN))
&& TYPEOF(cptr->call) == LANGSXP) {
SEXP fun = CAR(cptr->call);
REprintf ("\"%s\" ",
TYPEOF(fun) == SYMSXP ? CHAR(PRINTNAME(fun)) : "<Anonymous>");
}
}
REprintf("\n");
}
/* .Internal, for debugging gradient implementation. */
static SEXP do_tracking_gradients (SEXP call, SEXP op, SEXP args, SEXP env,
......
......@@ -390,6 +390,11 @@ void attribute_hidden InitOptions(void)
SET_TAG(v, install("helpers_trace"));
SETCAR(v, ScalarLogical(getenv("R_HELPERS_TRACE")!=0));
SETCDR(v,CONS(R_NilValue,R_NilValue));
v = CDR(v);
SET_TAG(v, install("gradient_trace"));
SETCAR(v, ScalarLogical(FALSE));
SET_SYMVALUE(install(".Options"), CDR(val));
UNPROTECT(1);
}
......@@ -588,6 +593,13 @@ static SEXP do_options(SEXP call, SEXP op, SEXP args, SEXP rho)
break;
case 'g':
if (streql(opname, "gradient_trace")) {
if (TYPEOF(argi) != LGLSXP || LENGTH(argi) != 1)
error(_("invalid value for '%s'"), opname);
R_gradient_trace = asLogical(argi);
val = ScalarLogical(R_gradient_trace);
goto set;
}
break;
case 'h':
......
......@@ -113,6 +113,7 @@ static SEXP do_random1(SEXP call, SEXP op, SEXP args, SEXP rho)
if (Dcall != 0) {
R_gradient = copy_scaled_gradients (g, Dcall(r,av));
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
}
}
......@@ -304,6 +305,7 @@ static SEXP do_random2(SEXP call, SEXP op, SEXP args, SEXP rho)
g2, gv2);
}
R_variant_result = VARIANT_GRADIENT_FLAG;
GRADIENT_TRACE(call);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment