Published
Edited
Nov 20, 2020
1 fork
Insert cell
Insert cell
// Training inputs: four samples of three 0/1 values each.
// `train` feeds one row at a time into the network as layer_0.
streetlights = [[1, 0, 1],
[0, 1, 1],
[0, 0, 1],
[1, 1, 1]]
Insert cell
// Training targets: one single-element row per row of `streetlights`
// (`train` compares layer_2 against walkVsStop[i] for sample i).
walkVsStop = [[1], [1], [0], [0]]
Insert cell
// Learning rate: `train` multiplies every weight update by this factor.
alpha = 0.2
Insert cell
// Width of the hidden layer: column count of weights_0_1 and row count of weights_1_2 in `train`.
hiddenSize = 4
Insert cell
// Number of full passes over `streetlights` performed by `train`.
epochs = 60
Insert cell
Insert cell
Insert cell
train = {
let layer_0, layer_1, layer_2
let layer_1_delta, layer_2_delta
let weights_0_1 = randoms(streetlights[0].length, hiddenSize)
let weights_1_2 = randoms(hiddenSize, walkVsStop[0].length)
let logs = []
for (let epoch = 0; epoch < epochs; epoch++) {
let layer_2_error = 0
for (let i = 0; i < streetlights.length; i++) {
// Forward prediction
layer_0 = streetlights[i]
layer_1 = relu(dot(layer_0, weights_0_1))
layer_2 = dot(layer_1, weights_1_2)
layer_2_error += mse(layer_2, walkVsStop[i])
// Backpropagation (learn)
layer_2_delta = sub(layer_2, walkVsStop[i])
layer_1_delta = mul(dot(layer_2_delta, transpose(weights_1_2)), relu2deriv(layer_1))
weights_1_2 = sub(weights_1_2, scale(alpha, dot(transpose(layer_1), [layer_2_delta]))) // [1,4]
weights_0_1 = sub(weights_0_1, scale(alpha, dot(transpose(layer_0), [layer_1_delta]))) // [3,4]
}
logs.push({ epoch, error: layer_2_error, layer_0, weights_0_1, layer_1, weights_1_2, layer_2, layer_1_delta, layer_2_delta })
}
return logs
}
Insert cell
// Runs the trained network forward over a batch of inputs.
//   x - array of input rows
// Returns one layer_2 output per row, computed with the weights logged in the
// final entry of `train`.
predict = (x) => {
  const { weights_0_1, weights_1_2 } = train[train.length - 1];
  return x.map((row) => {
    const hidden = relu(dot(row, weights_0_1));
    return dot(hidden, weights_1_2);
  });
}
Insert cell
// Rows 1-4 repeat the training inputs in `streetlights`;
// rows 5-8 are combinations that do not appear there.
predict([[1, 0, 1],
[0, 1, 1],
[0, 0, 1],
[1, 1, 1],
[0, 0, 0],
[1, 1, 0],
[0, 1, 0],
[1, 0, 0]])
Insert cell
Insert cell
// Rectified linear unit: positive values pass through, everything else becomes 0.
// Accepts either a single number or a flat array of numbers.
relu = (x) => {
  const clamp = (value) => {
    if (value > 0) return value;
    return 0;
  };
  if (!Array.isArray(x)) return clamp(x);
  return x.map((value) => clamp(value));
}
Insert cell
// Slope of ReLU evaluated on its output: 1 where the value is positive, 0 elsewhere.
// Accepts either a single number or a flat array of numbers.
relu2deriv = (output) => {
  const slope = (value) => {
    if (value > 0) return 1;
    return 0;
  };
  if (!Array.isArray(output)) return slope(output);
  return output.map((value) => slope(value));
}
Insert cell
// Sum of squared differences between a layer's outputs and its target.
// (Despite the name, the total is not divided by the number of elements.)
//   layer - flat array of output values
//   test  - target: a number, a single-element array (applied to every output),
//           or an array matched element-by-element against `layer`
// Returns 0 for an empty layer.
mse = (layer, test) => {
  // Index the target explicitly instead of subtracting an array from a number,
  // which only worked through string coercion for single-element targets and
  // produced NaN for longer ones.
  const target = Array.isArray(test) ? test : [test];
  return layer.reduce((total, value, i) => {
    const expected = target.length === 1 ? target[0] : target[i];
    return total + Math.pow(value - expected, 2);
  }, 0);
}
Insert cell
mse([0.5, 0.5, 0.5], [1]) // 3 * (0.5 - 1)^2 = 0.75
Insert cell
// Builds a rows x cols matrix (array of row arrays) whose entries are
// independent random numbers in the half-open range [-1, 1).
randoms = (rows, cols) => {
  const sample = () => 2 * Math.random() - 1;
  const makeRow = () => Array.from({ length: cols }, sample);
  return Array.from({ length: rows }, makeRow);
}
Insert cell
// Dot / matrix product over plain arrays. A "matrix" is an array whose first
// element is itself an array; anything else is treated as a flat vector.
//   vector · vector -> number
//   vector · matrix -> vector (one entry per matrix column)
//   matrix · vector -> vector (one entry per matrix row)
//   matrix · matrix -> matrix
// No dimension checks are made; operands are expected to have compatible sizes.
dot = (a, b) => {
  // Seeded with 0 so two empty vectors give 0 instead of throwing.
  const dotV = (u, v) => u.reduce((sum, x, i) => sum + x * v[i], 0);
  // Scale row i of the matrix by u[i], then add the scaled rows column by column.
  const dotM = (u, m) => u.map((x, i) => m[i].map((y) => x * y))
    .reduce((acc, row) => acc.map((s, j) => s + row[j]));

  const aIsMatrix = Array.isArray(a[0]);
  const bIsMatrix = Array.isArray(b[0]);
  if (aIsMatrix && bIsMatrix) return a.map((row) => dotM(row, b));
  // Matrix · vector used to fall through to dotV and multiply whole rows by
  // scalars; handle it as one vector dot product per row instead.
  if (aIsMatrix) return a.map((row) => dotV(row, b));
  if (bIsMatrix) return dotM(a, b);
  return dotV(a, b);
}
Insert cell
dot([2,2,2], [2,3,4]) // vector · vector: 2*2 + 2*3 + 2*4 = 18
Insert cell
dot(transpose([1,1,1]), [[2,3,4,5]]) // 3x1 column · 1x4 row: [[2,3,4,5], [2,3,4,5], [2,3,4,5]]
Insert cell
dot([2,2], [[1,1,1],[2,4,6]]) // vector · matrix: [6, 10, 14]
Insert cell
// Transposes a matrix (array of rows) by swapping rows and columns.
// A flat vector is turned into a column: one single-element row per value.
transpose = (a) => {
  if (!Array.isArray(a[0])) {
    return a.map((value) => [value]);
  }
  const firstRow = a[0];
  return firstRow.map((_, col) => a.map((row) => row[col]));
}
Insert cell
// Multiplies every entry of a vector, or of a matrix (array of rows), by `scalar`.
// Returns new arrays; the input is not modified.
scale = (scalar, a) => {
  const scaleRow = (row) => row.map((value) => scalar * value);
  if (Array.isArray(a[0])) {
    return a.map((row) => scaleRow(row));
  }
  return scaleRow(a);
}
Insert cell
// Element-wise product of two vectors, or of two matrices (arrays of rows),
// pairing entries by position. The result takes its shape from `a`.
mul = (a, b) => {
  const mulRow = (left, right) => left.map((value, i) => value * right[i]);
  if (Array.isArray(a[0])) {
    return a.map((row, r) => mulRow(row, b[r]));
  }
  return mulRow(a, b);
}
Insert cell
// Element-wise difference a - b of two vectors, or of two matrices (arrays of rows),
// pairing entries by position. The result takes its shape from `a`.
sub = (a, b) => {
  const subRow = (left, right) => left.map((value, i) => value - right[i]);
  if (Array.isArray(a[0])) {
    return a.map((row, r) => subRow(row, b[r]));
  }
  return subRow(a, b);
}
Insert cell
// Renders a single-series line chart and returns it as an <svg> element.
//   x, y             - parallel arrays; point i is (x[i], y[i])
//   xLabel, yLabel   - text drawn near the right end of the x axis and above the y axis
//   xDomain, yDomain - optional scale domains; when falsy, d3.extent of the data is used
// Reads `width`, `html` and `svg`, which are not defined in this cell
// (presumably supplied by the Observable runtime — confirm).
lineChart = (x, y, xLabel, yLabel, xDomain, yDomain) => {
const height = 240;
const margin = ({top: 20, right: 30, bottom: 30, left: 40})
// Horizontal positions use a linear scale spanning the area inside the left/right margins.
const xScale = d3.scaleLinear()
.domain(xDomain || d3.extent(x))
.range([margin.left, width - margin.right]);

// Vertical positions use a square-root scale; the range is flipped so larger values sit higher.
const yScale = d3.scaleSqrt()
.domain(yDomain || d3.extent(y))
.range([height - margin.bottom, margin.top]);
// Bottom axis, shifted down to the lower edge of the plotting area.
const xAxis = g => g.attr("transform", `translate(0,${height - margin.bottom})`)
.call(d3.axisBottom(xScale).ticks(10));

// Left axis; tick labels are formatted as whole-number percentages (".0%").
const yAxis = g => g
.attr("transform", `translate(${margin.left},0)`)
.call(d3.axisLeft(yScale).ticks(height / 20).tickFormat(d3.format(".0%")));
const line = d3.line().x(d => xScale(d.x)).y(d => yScale(d.y));
// Zip the two input arrays into the {x, y} records the line generator reads.
const data = x.map((_, i) => { return { x: x[i], y: y[i] } });

return html`<svg viewBox="0 0 ${width} ${height}">
<path d="${line(data)}" fill="none" stroke="steelblue" stroke-width="2" stroke-miterlimit="1"></path>
${d3.select(svg`<g>`).call(xAxis).node()}
${d3.select(svg`<g>`).call(yAxis).node()}
${d3.select(svg`<g><text x="${width-margin.right - 30}" y="${height - margin.bottom / 2}" style="font-size: 8pt; font-family: sans-serif; align: right;">${xLabel}</text>`).node()}
${d3.select(svg`<g><text x="${margin.left - 5}" y="${margin.top - 5}" style="font-size: 8pt; font-family: sans-serif">${yLabel}</text>`).node()}
</svg>`
}
Insert cell
d3 = require("d3@6")
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more