/**
 * A scalar node in a computation graph supporting reverse-mode automatic
 * differentiation. Every arithmetic method returns a NEW `Value` that records
 * its operands (`children`) and the operator that produced it, so `backward()`
 * can propagate gradients from an output node back to every input node.
 */
class Value {
  /**
   * @param {number} value - Scalar held by this node.
   * @param {string} [label=""] - Optional human-readable name; not used in computation.
   * @param {string} [operator=""] - Operator that produced this node ("" for leaves).
   * @param {Value[]} [children=[]] - Operand nodes this node was computed from.
   * @param {number} [exponent=1] - Exponent applied by a "^" node; read during backprop.
   */
  constructor(value, label = "", operator = "", children = [], exponent = 1) {
    this.value = value;
    this.label = label;
    this.operator = operator;
    this.gradient = 0;
    this.exponent = exponent;
    this.children = children;
  }

  /**
   * @param {Value|number} other - Right operand; a number is wrapped in a leaf Value.
   * @returns {Value} A "+" node holding `this.value + other.value`.
   */
  add(other) {
    const rhs = typeof other === "number" ? new Value(other) : other;
    return new Value(this.value + rhs.value, "", "+", [this, rhs]);
  }

  /**
   * Implemented as `this + (other * -1)` so no separate "-" operator is needed.
   * @param {Value|number} other - Right operand; a number is wrapped in a leaf Value.
   * @returns {Value} Node holding `this.value - other.value`.
   */
  sub(other) {
    const rhs = typeof other === "number" ? new Value(other) : other;
    return this.add(rhs.mul(-1));
  }

  /**
   * @param {Value|number} other - Right operand; a number is wrapped in a leaf Value.
   * @returns {Value} A "*" node holding `this.value * other.value`.
   */
  mul(other) {
    const rhs = typeof other === "number" ? new Value(other) : other;
    return new Value(this.value * rhs.value, "", "*", [this, rhs]);
  }

  /**
   * Implemented as `this * other^-1` so no separate "/" operator is needed.
   * @param {Value|number} other - Divisor; a number is wrapped in a leaf Value.
   * @returns {Value} Node holding `this.value / other.value`.
   */
  div(other) {
    const rhs = typeof other === "number" ? new Value(other) : other;
    return this.mul(rhs.pow(-1));
  }

  /** @returns {Value} Node holding `-this.value` (as `this * -1`). */
  neg() {
    return this.mul(-1);
  }

  /**
   * Raises this node to a constant power.
   * @param {number} x - Constant exponent (not differentiated with respect to).
   * @returns {Value} A "^" node holding `this.value ** x`, carrying `x` as its exponent.
   */
  pow(x) {
    // The exponent lives only on the RESULT node. Writing it onto `this` would
    // clobber the exponent of an operand that is itself a "^" node and corrupt
    // that node's gradient (e.g. `b.pow(2).pow(3)` or `a.div(b.pow(2))`).
    return new Value(Math.pow(this.value, x), "", "^", [this], x);
  }

  /** @returns {Value} An "exp" node holding `Math.exp(this.value)`. */
  exp() {
    return new Value(Math.exp(this.value), "", "exp", [this]);
  }

  /** @returns {Value} A "tanh" node holding `Math.tanh(this.value)`. */
  tanh() {
    return new Value(Math.tanh(this.value), "", "tanh", [this]);
  }

  /**
   * Backpropagates from this node: seeds `this.gradient = 1`, then visits every
   * reachable node in reverse topological order and pushes its gradient to its
   * children. Gradients are accumulated with `+=` and existing gradients of
   * other nodes are NOT reset first, so calling this repeatedly accumulates.
   */
  backward() {
    this.gradient = 1;

    // Iterative post-order DFS (children before parents). It yields the same
    // order as a recursive DFS over `children`, without risking a stack
    // overflow on very deep graphs such as long sequential sums.
    const order = [];
    const visited = new Set([this]);
    const stack = [{ node: this, next: 0 }];
    while (stack.length > 0) {
      const frame = stack[stack.length - 1];
      if (frame.next < frame.node.children.length) {
        const child = frame.node.children[frame.next];
        frame.next += 1;
        if (!visited.has(child)) {
          visited.add(child);
          stack.push({ node: child, next: 0 });
        }
      } else {
        stack.pop();
        order.push(frame.node);
      }
    }

    // Parents come last in `order`, so walk it backwards: each node's gradient
    // is complete before it is distributed to its children.
    for (let i = order.length - 1; i >= 0; i -= 1) {
      order[i]._setChildGradients();
    }
  }

  /**
   * Applies the local chain rule for this node's operator, adding
   * `this.gradient * d(this)/d(child)` to each child's gradient.
   * @throws {Error} If `this.operator` is not one of the supported operators.
   */
  _setChildGradients() {
    switch (this.operator) {
      case "+": {
        const [left, right] = this.children;
        left.gradient += this.gradient;
        right.gradient += this.gradient;
        break;
      }
      case "*": {
        const [left, right] = this.children;
        left.gradient += this.gradient * right.value;
        right.gradient += this.gradient * left.value;
        break;
      }
      case "^": {
        // d/dc c^k = k * c^(k-1), with k stored on this (the result) node.
        const [c] = this.children;
        c.gradient += this.exponent * Math.pow(c.value, this.exponent - 1) * this.gradient;
        break;
      }
      case "tanh": {
        // d/dx tanh(x) = 1 - tanh(x)^2, and this.value already is tanh(x).
        const [c] = this.children;
        c.gradient += this.gradient * (1 - Math.pow(this.value, 2));
        break;
      }
      case "exp": {
        // d/dx exp(x) = exp(x), which is this.value.
        const [c] = this.children;
        c.gradient += this.gradient * this.value;
        break;
      }
      case "":
        // Leaf node: nothing to propagate.
        break;
      default:
        throw new Error(`Operator '${this.operator}' not implemented!`);
    }
  }
}