/**
 * Weight table between `numPre` presynaptic and `numPost` postsynaptic units,
 * adjusted with a trace-based additive rule (the class name suggests
 * spike-timing-dependent plasticity).
 *
 * State held on the instance:
 *  - `x`: one trace value per presynaptic unit, bumped up by `a_plus` on a pre spike.
 *  - `y`: one trace value per postsynaptic unit, bumped down by `a_minus` on a post spike.
 *  - `w`: weights indexed as `w[pre][post]`, kept inside `[g_min, g_max]`.
 *
 * Relies on `times`, `times2d` and `getIndicesGreaterThan0`, which are defined
 * outside this class (presumably array fillers and an index filter - confirm
 * against their definitions).
 */
class STDPWeights {
  /**
   * @param {number} numPre - Number of presynaptic units (first index of `w`).
   * @param {number} numPost - Number of postsynaptic units (second index of `w`).
   * @param {number} [tau_plus=0.02] - Time constant used to decay `x` in `step`.
   * @param {number} [tau_minus=0.02] - Time constant used to decay `y` in `step`.
   * @param {number} [a_plus=0.01] - Amount added to `x` when a presynaptic unit fires.
   * @param {number} [a_minus=0.011] - Amount subtracted from `y` when a postsynaptic unit fires.
   * @param {number} [g_min=0] - Lower bound applied to every updated weight.
   * @param {number} [g_max=1] - Upper bound applied to every updated weight.
   */
  constructor(numPre, numPost, tau_plus = 0.02, tau_minus = 0.02, a_plus = 0.01, a_minus = 0.011, g_min = 0, g_max = 1) {
    // Fields are listed in a fixed order so the instance's own-key order stays stable.
    Object.assign(this, {
      numPre,
      numPost,
      tau_plus,
      tau_minus,
      x: times(numPre, 0),
      y: times(numPost, 0),
      a_plus,
      a_minus,
      g_min,
      g_max,
      w: times2d(numPre, numPost, 0),
    });
  }

  /**
   * Applies one round of trace bumps and weight changes for the given activity.
   *
   * Order matters: both traces are bumped first, so a unit that fires in this
   * call already contributes its fresh trace value to the weight changes below.
   *
   * @param {*} preOutputs - Presynaptic outputs, passed to `getIndicesGreaterThan0`.
   * @param {*} postOutputs - Postsynaptic outputs, passed to `getIndicesGreaterThan0`.
   */
  updateWeights(preOutputs, postOutputs) {
    const firedPre = getIndicesGreaterThan0(preOutputs);
    const firedPost = getIndicesGreaterThan0(postOutputs);

    for (const pre of firedPre) {
      this.x[pre] += this.a_plus;
    }
    for (const post of firedPost) {
      this.y[post] -= this.a_minus;
    }

    // Trace contributions are scaled by the width of the allowed weight range.
    const span = this.g_max - this.g_min;

    // Lower bound is applied first, then the upper bound, using comparisons only.
    const bounded = (value) => {
      const floored = value < this.g_min ? this.g_min : value;
      return floored > this.g_max ? this.g_max : floored;
    };

    // A firing presynaptic unit changes its whole row by the postsynaptic traces.
    for (const pre of firedPre) {
      const row = this.w[pre];
      for (let post = 0; post < this.numPost; post += 1) {
        row[post] = bounded(row[post] + span * this.y[post]);
      }
    }

    // A firing postsynaptic unit changes its whole column by the presynaptic traces.
    for (const post of firedPost) {
      for (let pre = 0; pre < this.numPre; pre += 1) {
        const row = this.w[pre];
        row[post] = bounded(row[post] + span * this.x[pre]);
      }
    }
  }

  /**
   * Decays both traces in place by the linear factor `1 - t_step / tau`.
   *
   * Note that this factor is negative whenever `t_step` is larger than the
   * corresponding time constant; no guard is applied here.
   *
   * @param {number} t_step - Elapsed time, in the same unit as `tau_plus` / `tau_minus`.
   */
  step(t_step) {
    const keepPre = 1 - t_step / this.tau_plus;
    const keepPost = 1 - t_step / this.tau_minus;

    for (const i of this.x.keys()) {
      this.x[i] *= keepPre;
    }
    for (const i of this.y.keys()) {
      this.y[i] *= keepPost;
    }
  }
}