/**
 * A scalar node in a dynamically built computation graph with
 * reverse-mode automatic differentiation.
 *
 * Every operation returns a new `Value` whose `children` are its operands and
 * whose `operator` records how it was produced. Calling `backward()` on a
 * node walks the graph and accumulates d(node)/d(child) into each reachable
 * node's `gradient`. Gradients are accumulated (`+=`), not reset, so callers
 * must zero them between passes if needed.
 */
class Value {
  /**
   * @param {number} value - Scalar stored in this node.
   * @param {string} [label=""] - Optional human-readable name.
   * @param {string} [operator=""] - Operation that produced this node ("" for leaves).
   * @param {Value[]} [children=[]] - Operands of `operator`.
   * @param {number} [exponent=1] - Exponent applied by a "^" node.
   * @param {number} [ixp=0] - Index of the positive class for an "SCE" node.
   */
  constructor(value, label = "", operator = "", children = [], exponent = 1, ixp = 0) {
    this.value = value;
    this.label = label;
    this.operator = operator;
    this.gradient = 0;
    this.exponent = exponent;
    this.ixp = ixp;
    this.children = children;
  }

  /**
   * Wraps a plain number as a leaf `Value`; anything else is returned as-is.
   * @param {Value|number} other
   * @returns {Value}
   */
  static _toValue(other) {
    return typeof other === "number" ? new Value(other) : other;
  }

  /**
   * Numerically stable log(sum(exp(values))).
   * The max is subtracted before exponentiating so no term overflows, and at
   * least one term is exp(0) = 1 so the sum cannot underflow to 0.
   * @param {number[]} values
   * @returns {number}
   */
  static _logSumExp(values) {
    // reduce instead of Math.max(...values): spreading very large arrays can
    // exceed the engine's argument-count limit.
    const max = values.reduce((m, v) => Math.max(m, v), -Infinity);
    const sum = values.reduce((acc, v) => acc + Math.exp(v - max), 0);
    return max + Math.log(sum);
  }

  /** @param {Value|number} other @returns {Value} this + other */
  add(other) {
    const rhs = Value._toValue(other);
    return new Value(this.value + rhs.value, "", "+", [this, rhs]);
  }

  /** @param {Value|number} other @returns {Value} this - other */
  sub(other) {
    const rhs = Value._toValue(other);
    return this.add(rhs.mul(-1));
  }

  /** @param {Value|number} other @returns {Value} this * other */
  mul(other) {
    const rhs = Value._toValue(other);
    return new Value(this.value * rhs.value, "", "*", [this, rhs]);
  }

  /** @param {Value|number} other @returns {Value} this / other */
  div(other) {
    const rhs = Value._toValue(other);
    return this.mul(rhs.pow(-1));
  }

  /** @returns {Value} -this */
  neg() {
    return this.mul(-1);
  }

  /**
   * Raises this node to a constant power.
   * @param {number} x - Constant exponent.
   * @returns {Value} this ** x
   */
  pow(x) {
    // Only the new node records the exponent. The base must NOT be mutated:
    // if it is itself a "^" node, overwriting its `exponent` would corrupt
    // its own power-rule gradient during backward().
    return new Value(Math.pow(this.value, x), "", "^", [this], x);
  }

  /** @returns {Value} e ** this */
  exp() {
    return new Value(Math.exp(this.value), "", "exp", [this]);
  }

  /** @returns {Value} tanh(this) */
  tanh() {
    return new Value(Math.tanh(this.value), "", "tanh", [this]);
  }

  /** @returns {Value} max(0, this) */
  relu() {
    return new Value(this.value < 0 ? 0 : this.value, "", "ReLU", [this]);
  }

  /**
   * Softmax cross-entropy loss: -log(softmax(xs)[ixp]).
   * Computed as logsumexp(xs) - xs[ixp], which stays finite even when
   * xs[ixp] is far below the largest prediction (a direct
   * -log(exp(..)/sum) would underflow to -log(0) = Infinity).
   * @param {Value[]} xs - Predictions (logits).
   * @param {number} ixp - Index of the positive class in `xs`.
   * @returns {Value} Loss node with operator "SCE" whose children are `xs`.
   * @throws {RangeError} If `ixp` is not a valid integer index into `xs`.
   */
  static sce(xs, ixp) {
    if (!Number.isInteger(ixp) || ixp < 0 || ixp >= xs.length) {
      throw new RangeError(`sce: ixp=${ixp} is not a valid index into ${xs.length} predictions`);
    }
    const logits = xs.map((x) => x.value);
    const loss = Value._logSumExp(logits) - logits[ixp];
    return new Value(loss, "", "SCE", [...xs], 1, ixp);
  }

  /**
   * Back-propagates from this node: seeds its gradient with 1, then visits
   * every reachable node in reverse topological order, pushing gradients to
   * its children.
   * The topological sort is iterative (explicit stack) so very deep graphs,
   * such as long summation chains, do not overflow the call stack. It yields
   * the same post-order a recursive DFS would.
   */
  backward() {
    this.gradient = 1;
    const topo = [];
    const visited = new Set([this]);
    // Each frame is [node, index of the next child to explore].
    const stack = [[this, 0]];
    while (stack.length > 0) {
      const frame = stack[stack.length - 1];
      const [node, childIndex] = frame;
      if (childIndex < node.children.length) {
        frame[1] = childIndex + 1;
        const child = node.children[childIndex];
        if (!visited.has(child)) {
          visited.add(child);
          stack.push([child, 0]);
        }
      } else {
        stack.pop();
        topo.push(node); // all children already emitted => post-order
      }
    }
    for (let i = topo.length - 1; i >= 0; i--) {
      topo[i]._setChildGradients();
    }
  }

  /**
   * Applies this node's local derivative (chain rule), adding
   * `this.gradient * d(this)/d(child)` to each child's `gradient`.
   * @throws {Error} For an operator with no gradient rule.
   */
  _setChildGradients() {
    switch (this.operator) {
      case "+": {
        const [left, right] = this.children;
        left.gradient += this.gradient;
        right.gradient += this.gradient;
        break;
      }
      case "*": {
        const [left, right] = this.children;
        left.gradient += this.gradient * right.value;
        right.gradient += this.gradient * left.value;
        break;
      }
      case "^": {
        const [c] = this.children;
        c.gradient += this.exponent * Math.pow(c.value, this.exponent - 1) * this.gradient;
        break;
      }
      case "tanh": {
        // d tanh(x)/dx = 1 - tanh(x)^2, and this.value already holds tanh(x).
        const [c] = this.children;
        c.gradient += this.gradient * (1 - Math.pow(this.value, 2));
        break;
      }
      case "exp": {
        // d e^x/dx = e^x, which is this.value.
        const [c] = this.children;
        c.gradient += this.gradient * this.value;
        break;
      }
      case "ReLU": {
        // Uses slope 1 at exactly 0.
        const [c] = this.children;
        c.gradient += this.gradient * (c.value < 0 ? 0 : 1);
        break;
      }
      case "SCE": {
        // d(loss)/d(x_i) = softmax_i - [i === ixp]; softmax_i is taken as
        // exp(x_i - logsumexp(x)) for numerical stability.
        const logits = this.children.map((x) => x.value);
        const lse = Value._logSumExp(logits);
        this.children.forEach((child, i) => {
          const softmax = Math.exp(logits[i] - lse);
          const target = i === this.ixp ? 1 : 0;
          child.gradient += this.gradient * (softmax - target);
        });
        break;
      }
      case "":
        // Leaf node: nothing to propagate.
        break;
      default:
        throw new Error(`Operator '${this.operator}' not implemented!`);
    }
  }
}