/**
 * Population of leaky integrate-and-fire (LIF) neurons simulated with
 * TensorFlow.js tensors.
 *
 * Each step, a neuron's voltage relaxes exponentially toward its input current
 * J with time constant `tau_rc`. When the voltage exceeds `thresh` the neuron
 * emits a spike (output 1 for that step), its voltage is reset to 0, and it
 * stops integrating for a refractory period of `tau_ref`.
 *
 * State is held in `tf.variable`s; call `dispose()` to release it.
 */
class TensorFlowLIF {
  /**
   * @param {number} [n=1] - Number of neurons in the population.
   * @param {number} [tau_rc=0.02] - Membrane (RC) time constant, in the same
   *   time units as the `dt_num` passed to `step`.
   * @param {number} [tau_ref=0.2] - Refractory period after a spike, in the
   *   same time units as `dt_num`.
   *   NOTE(review): 0.2 is a long refractory period if the unit is seconds —
   *   confirm this default is intended.
   * @param {number} [thresh=1] - Voltage threshold; a neuron spikes when its
   *   voltage is strictly greater than this value.
   */
  constructor(n = 1, tau_rc = 0.02, tau_ref = 0.2, thresh = 1) {
    this.n = n;
    this.rc = tf.scalar(tau_rc);
    this.ref = tf.scalar(tau_ref);
    this.thresh = tf.scalar(thresh);
    // Membrane voltage of each neuron.
    this.v = tf.variable(tf.zeros([this.n]));
    // Spike indicator from the most recent step (1 = spiked, 0 = silent).
    this.output = tf.variable(tf.zeros([this.n]));
    // Refractory time still remaining at the start of the next step (0 = none).
    this.refractory_time = tf.variable(tf.zeros([this.n]));
  }

  /**
   * Advance every neuron by one time step.
   *
   * @param {number|tf.Tensor} J - Input current; a number or a tensor
   *   broadcastable to shape [n]. Treated as constant over the step.
   * @param {number} dt_num - Step size, in the same time units as tau_rc and
   *   tau_ref.
   * @returns {tf.Variable} `this.output`: 1 for neurons that spiked during this
   *   step, 0 otherwise. The same variable is overwritten on every call.
   */
  step(J, dt_num) {
    tf.tidy(() => {
      // Portion of this step spent outside the refractory period, i.e. how long
      // the neuron actually integrates input: dt minus the remaining refractory
      // time, clipped to [0, dt].
      const integrateTime = tf.sub(dt_num, this.refractory_time).clipByValue(0, dt_num);

      // Exponential relaxation of v toward J over integrateTime:
      // v' = J + (v - J) * exp(-integrateTime / tau_rc).
      const decay = tf.exp(integrateTime.div(this.rc).neg());
      const vNext = tf.add(J, tf.sub(this.v, J).mul(decay));

      // Float masks (1/0) so they can be combined arithmetically.
      const spiked = tf.cast(tf.greater(vNext, this.thresh), 'float32');
      const notSpiked = tf.sub(1, spiked);

      // Neurons that just spiked start a fresh refractory period; all others
      // count their remaining refractory time down by dt, floored at 0.
      const countedDown = tf.maximum(this.refractory_time.sub(dt_num), 0);
      const nextRefractory = spiked.mul(this.ref).add(notSpiked.mul(countedDown));

      this.v.assign(vNext.mul(notSpiked)); // reset spiking neurons to 0
      this.output.assign(spiked);
      this.refractory_time.assign(nextRefractory);
    });
    return this.output;
  }

  /**
   * Release every tensor and variable owned by this population. The instance
   * must not be used after this call.
   */
  dispose() {
    // tf.dispose handles both plain tensors and variables, so each one is
    // released exactly once here.
    tf.dispose([this.rc, this.ref, this.thresh, this.v, this.output, this.refractory_time]);
  }
}