NIR: Handling LIF neuron threshold for tensor with gradients
The issue arises when handling the `threshold` parameter, which can be either a `torch.Tensor` with gradients or a `numpy.ndarray` (supported by default).
When the threshold is a tensor that requires gradients, the code fails because it calls `.numpy()` on the tensor directly:
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
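The failure can be reproduced outside the converter with a minimal sketch. It assumes only that the threshold is a tensor created with `requires_grad=True` (for example, a learnable threshold). It also shows that detaching the tensor first makes the conversion succeed.

```python
import numpy as np
import torch

# Stand-in for a learnable threshold (requires_grad=True).
v_threshold = torch.tensor([1.0, 2.0], requires_grad=True)

try:
    v_threshold.numpy()  # raises: the tensor is tracked by autograd
    error_message = None
except RuntimeError as err:
    error_message = str(err)

# detach() returns a tensor outside the autograd graph, so numpy() works.
converted = v_threshold.detach().numpy()

print(error_message)
print(converted)
```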
A proposed fix is the following edit in `s2_nir.py`:
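```python
# Fragment from s2_nir.py: node, scale, alpha_decay, v_leak_factor, bias,
# w_scale and config are defined by the surrounding function.
threshold = node.v_threshold
if isinstance(threshold, torch.Tensor):
    # detach() avoids the RuntimeError for tensors with requires_grad=True
    # (and is harmless otherwise); cpu() handles tensors stored on a GPU.
    threshold = threshold.detach().cpu().numpy()

neuron_params = {
    "threshold": threshold.flatten() * scale,
    "alpha_decay": alpha_decay,
    "i_offset": v_leak_factor * node.v_leak.flatten() * scale + bias.flatten() * w_scale,
    "reset": get_s2_reset_method(config.reset),
}
return neuron_params, scale
```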
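The conversion can be checked in isolation with the standalone sketch below. The helper name `threshold_to_numpy` is hypothetical and is not part of `s2_nir.py`. It feeds a `numpy.ndarray`, a plain tensor, and a tensor with gradients through the same path and confirms that each one yields a flattened, scaled numpy array.

```python
import numpy as np
import torch


def threshold_to_numpy(v_threshold, scale):
    """Hypothetical helper: return the flattened, scaled threshold as numpy."""
    if isinstance(v_threshold, torch.Tensor):
        # detach() drops autograd tracking (a no-op for tensors without grad);
        # cpu() is required before numpy() if the tensor lives on a GPU.
        v_threshold = v_threshold.detach().cpu().numpy()
    return np.asarray(v_threshold).flatten() * scale


inputs = [
    np.array([[1.0, 2.0]]),                          # numpy array
    torch.tensor([[1.0, 2.0]]),                      # tensor without grad
    torch.tensor([[1.0, 2.0]], requires_grad=True),  # tensor with grad
]
results = [threshold_to_numpy(x, scale=0.5) for x in inputs]

for r in results:
    print(type(r).__name__, r)
```

Converting every tensor unconditionally, rather than only those with `requires_grad=True`, keeps the output type consistent: the threshold is always a `numpy.ndarray`, whatever form the input takes.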