Public
Edited
May 5, 2023
1 fork
1 star
Insert cell
Insert cell
/**
 * A scalar node in a reverse-mode automatic-differentiation graph.
 *
 * Every arithmetic method returns a NEW Value whose `children` are the
 * operands and whose `operator` records how it was produced, so that
 * `backward()` can propagate gradients from an output back to its inputs.
 * Operand nodes are never modified by the arithmetic methods.
 */
class Value {
  /**
   * @param {number} value - The scalar held by this node.
   * @param {string} [label=""] - Optional display name.
   * @param {string} [operator=""] - Operator that produced this node; "" marks a leaf.
   * @param {Value[]} [children=[]] - Operand nodes this node was computed from.
   * @param {number} [exponent=1] - Constant power applied by a "^" node; 1 otherwise.
   */
  constructor(value, label = "", operator = "", children = [], exponent = 1) {
    this.value = value;
    this.label = label;
    this.operator = operator;
    // Accumulated derivative of the root (the node backward() was called on)
    // with respect to this node.
    this.gradient = 0;
    this.exponent = exponent;
    this.children = children;
  }

  /**
   * @param {Value|number} other - Plain numbers are wrapped in a leaf Value.
   * @returns {Value} A "+" node holding `this.value + other.value`.
   */
  add(other) {
    if (typeof other === "number") other = new Value(other);
    const newValue = this.value + other.value;
    return new Value(newValue, "", "+", [this, other]);
  }

  /**
   * Subtraction, expressed as `this + (other * -1)` so no "-" operator is needed.
   * @param {Value|number} other
   * @returns {Value}
   */
  sub(other) {
    if (typeof other === "number") other = new Value(other);
    return this.add(other.mul(-1));
  }

  /**
   * @param {Value|number} other - Plain numbers are wrapped in a leaf Value.
   * @returns {Value} A "*" node holding `this.value * other.value`.
   */
  mul(other) {
    if (typeof other === "number") other = new Value(other);
    const newValue = this.value * other.value;
    return new Value(newValue, "", "*", [this, other]);
  }

  /**
   * Division, expressed as `this * other^-1`.
   * @param {Value|number} other
   * @returns {Value}
   */
  div(other) {
    if (typeof other === "number") other = new Value(other);
    return this.mul(other.pow(-1));
  }

  /** @returns {Value} `this * -1`. */
  neg() {
    return this.mul(-1);
  }

  /**
   * Raise this node to a constant power.
   * @param {number} x - Constant exponent (gradients do not flow into it).
   * @returns {Value} A "^" node whose `exponent` is `x`.
   */
  pow(x) {
    // The exponent belongs to the NEW node only. Writing it onto `this` would
    // corrupt `this`'s own gradient whenever it is itself a "^" node
    // (e.g. a.pow(2).pow(3)) and would mislabel operands when displayed.
    const newValue = Math.pow(this.value, x);
    return new Value(newValue, "", "^", [this], x);
  }

  /** @returns {Value} An "exp" node holding `e ** this.value`. */
  exp() {
    const newValue = Math.exp(this.value);
    return new Value(newValue, "", "exp", [this]);
  }

  /** @returns {Value} A "tanh" node holding `tanh(this.value)`. */
  tanh() {
    const newValue = Math.tanh(this.value);
    return new Value(newValue, "", "tanh", [this]);
  }

  /**
   * Back-propagate from this node through the whole graph beneath it.
   *
   * Seeds this node's gradient with 1 and visits every reachable node in
   * reverse topological order, adding each node's contribution into its
   * children's gradients. Gradients are ADDED (`+=`), so callers must reset
   * the `gradient` of reused nodes (e.g. parameters) between calls.
   */
  backward() {
    this.gradient = 1;

    const topo = [];
    const visited = new Set();

    // Post-order DFS: a node is pushed only after all of its children.
    const buildTopo = (v) => {
      if (!visited.has(v)) {
        visited.add(v);
        for (const child of v.children) {
          buildTopo(child);
        }
        topo.push(v);
      }
    };

    buildTopo(this);

    for (const node of topo.reverse()) {
      node._setChildGradients();
    }
  }

  /**
   * Apply the local chain rule for this node's operator, adding
   * `this.gradient * d(this)/d(child)` into each child's gradient.
   * @throws {Error} If `operator` is not one of the supported operators.
   */
  _setChildGradients() {
    switch (this.operator) {
      case "+": {
        const [left, right] = this.children;
        left.gradient += this.gradient;
        right.gradient += this.gradient;
        break;
      }
      case "*": {
        const [left, right] = this.children;
        left.gradient += this.gradient * right.value;
        right.gradient += this.gradient * left.value;
        break;
      }
      case "^": {
        // d/dc c^n = n * c^(n-1)
        const [c] = this.children;
        c.gradient += this.exponent * (Math.pow(c.value, this.exponent - 1)) * this.gradient;
        break;
      }
      case "tanh": {
        // d/dc tanh(c) = 1 - tanh(c)^2, and this.value already holds tanh(c).
        const [c] = this.children;
        c.gradient += this.gradient * (1 - Math.pow(this.value, 2));
        break;
      }
      case "exp": {
        // d/dc e^c = e^c, which is this.value.
        const [c] = this.children;
        c.gradient += this.gradient * this.value;
        break;
      }
      case "":
        // Leaf node: nothing to propagate.
        break;
      default:
        throw new Error(`Operator '${this.operator}' not implemented!`);
    }
  }
}
Insert cell
/**
 * Clamp an infinite number to the largest finite double of the same sign.
 *
 * Finite inputs are returned unchanged. `Infinity` becomes `Number.MAX_VALUE`
 * and `-Infinity` becomes `-Number.MAX_VALUE`. Note that `Number.MIN_VALUE`
 * is the smallest POSITIVE double (~5e-324), so it cannot stand in for
 * -Infinity. NaN is not an infinity and is returned as-is.
 *
 * @param {number} x
 * @returns {number}
 */
function preventInfinity(x) {
  if (isFinite(x)) return x;
  if (x > 0) return Number.MAX_VALUE;
  if (x < 0) return -Number.MAX_VALUE;
  return x;
}
Insert cell
/**
 * Render a Value and the graph of operands beneath it as nested HTML, using
 * the notebook's `html` tagged template.
 *
 * Each node shows its label, value and gradient. Below that it shows the
 * operator that produced it ("eˣ" for exp), with the exponent appended when
 * it is not 1. Finally it shows its children, rendered recursively in
 * operand order.
 *
 * @param {Value} value - Root of the (sub)graph to render.
 * @returns {*} Whatever `html` produces for the markup below.
 */
function visualize(value) {
  // reduce passes (accumulator, current): append each child AFTER the ones
  // already rendered so operands appear in the same order as `children`.
  const children = value.children
    .map((child) => visualize(child))
    .reduce((rendered, next) => html`${rendered}${next}`, html``);
  return html`
<div class="tree">
<div><b>${value.label}</b>(${value.value}, grad = ${value.gradient})</div>
<div class="tree-branch-wrapper">
<div class="operator">${value.operator === "exp" ? "eˣ" : value.operator}${value.exponent !== 1 ? value.exponent : ""}</div>
<div class="tree-branch">
${children}
</div>
</div>
</div>
`;
}
Insert cell
Insert cell
{
const a = new Value(2, "a");
const b = new Value(3, "b");
const c = a.add(b); c.label = "c";
const d = new Value(5, "d");
const e = d.mul(b); e.label = "e";
const f = c.mul(e); f.label = "f";
f.backward();
return visualize(f)
}
Insert cell
/**
 * A single tanh unit computing tanh(sum_i w[i] * x[i] + b).
 */
class Neuron {
  /**
   * @param {number} nin - Number of inputs; one weight is created per input.
   */
  constructor(nin) {
    // Weights first, then the bias; each starts uniformly random in [-1, 1).
    const randomParam = () => new Value(Math.random() * 2 - 1);
    const weights = [];
    while (weights.length < nin) weights.push(randomParam());
    this.w = weights;
    this.b = randomParam();
  }

  /**
   * @param {Array<number|Value>} x - One input per weight.
   * @returns {Value} The neuron's activation.
   */
  call(x) {
    const weightedSum = x.reduce(
      (sum, xi, i) => sum.add(this.w[i].mul(xi)),
      new Value(0),
    );
    return weightedSum.add(this.b).tanh();
  }

  /** @returns {Value[]} All trainable parameters: the weights, then the bias. */
  parameters() {
    return this.w.concat(this.b);
  }
}
Insert cell
/**
 * A fully connected layer of independent neurons that all see the same input.
 */
class Layer {
  /**
   * @param {number} nin - Inputs per neuron.
   * @param {number} nout - Number of neurons (outputs).
   */
  constructor(nin, nout) {
    this.neurons = [];
    while (this.neurons.length < nout) {
      this.neurons.push(new Neuron(nin));
    }
  }

  /**
   * @param {Array<number|Value>} x - Input vector.
   * @returns {*} The single output when the layer has exactly one neuron,
   *   otherwise the array of outputs.
   */
  call(x) {
    const outs = this.neurons.map((neuron) => neuron.call(x));
    if (outs.length === 1) {
      return outs[0];
    }
    return outs;
  }

  /** @returns {Array} Every neuron's parameters, flattened in neuron order. */
  parameters() {
    return this.neurons.flatMap((neuron) => neuron.parameters());
  }
}
Insert cell
/**
 * A multi-layer perceptron: layers applied in sequence, each layer's output
 * feeding the next.
 */
class MLP {
  /**
   * @param {number} nin - Size of the input vector.
   * @param {number[]} nouts - Output size of each successive layer.
   */
  constructor(nin, nouts) {
    // Layer i maps sizes[i] inputs to sizes[i + 1] outputs.
    const sizes = [nin].concat(nouts);
    this.layers = Array.from(
      { length: nouts.length },
      (_, i) => new Layer(sizes[i], sizes[i + 1]),
    );
  }

  /**
   * @param {Array<number|Value>} x - Input vector.
   * @returns {*} The last layer's output.
   */
  call(x) {
    return this.layers.reduce((activation, layer) => layer.call(activation), x);
  }

  /** @returns {Array} Every layer's parameters, flattened in layer order. */
  parameters() {
    return this.layers.flatMap((layer) => layer.parameters());
  }
}
Insert cell
{
const n = new MLP(3, [4, 4, 1]);
const xs = [
[2.0, 3.0, -1.0],
[3.0, -1.0, 0.5],
[0.5, 1.0, 1.0],
[1.0, 1.0, -1.0],
];
const ys = [1.0, -1.0, -1.0, 1.0];
let ypred;
let loss;
const iterations = 1000;
const learningRate = 0.1;
for (let k = 0; k < iterations; k++) {
// forward pass
ypred = xs.map(x => n.call(x));
loss = ys.reduce((acc, ygt, index) => {
const yout = ypred[index];
return yout.sub(ygt).pow(2).add(acc);
}, 0);

// backward pass
for (const p of n.parameters()) {
p.gradient = 0;
}
loss.backward();

// learning
for (const p of n.parameters()) {
p.value -= p.gradient * learningRate;
}
}
return { loss, ypred };
}
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more