Published
Edited
Mar 28, 2019
1 star
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
model = {
const model = tf.sequential();
// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));
// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(0.01),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
return model;
}
Insert cell
Insert cell
Insert cell
data = {
const data = [];
csvData.forEach((values) => {
// 'logit' data uses the 5 fields:
const x = [];
x.push(parseFloat(values.px));
x.push(parseFloat(values.pz));
x.push(parseFloat(values.sz_top));
x.push(parseFloat(values.sz_bot));
x.push(parseFloat(values.left_handed_batter));
// The label is simply 'is strike' or 'is ball':
const y = parseInt(values.is_strike, 10);
data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);
return data;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
const history = await model.fit(batches[index].x, batches[index].y, {
epochs: 1,
shuffle: false,
validationData: [batches[index].x, batches[index].y],
batchSize: CONSTANTS.BATCH_SIZE
});
mutable step++;
mutable loss = history.history.loss[0];
mutable accuracy = history.history.acc[0];
await tf.nextFrame();
updateHeatmap();
await tf.nextFrame();
}
Insert cell
Insert cell
// Trains one epoch and updates the heatmap after each batch:
async function runEpochTrainAndVisual() {
mutable isTraining = true;
for (let i = 0; i < batches.length; i++) {
await trainBatch(i);
// The tf.nextFrame() helper function returns a Promise when requestAnimationFrame()
// has completed. These calls ensure that the single-threaded JS event loop is not
// blocked during training:
await tf.nextFrame();

updateHeatmap();
await tf.nextFrame();
}
mutable isTraining = false;
}
Insert cell
// Function that ensures the model is trained to the number of epochs selected
// by the user in this codelab:
async function runTraining() {
while (mutable epoch < mutable totalEpochs) {
await runEpochTrainAndVisual();
mutable epoch++;
}
}
Insert cell
Insert cell
Insert cell
Insert cell
// Builds a (length + 1) x (length + 1) grid of pitch locations covering
// x in [xMin, xMax] and y in [yMin, yMax]. Rows run top (yMax) to bottom (yMin);
// within each row, columns run left (xMin) to right (xMax).
//
// Returns:
//   data   - a [points, 5] tensor of model inputs per grid point:
//            normalized x, normalized y, normalized fixed sz_top (3.5),
//            normalized fixed sz_bot (1.5), and the batter selection index.
//   coords - the raw {x, y} of each grid point, in the same order as `data` rows.
//   lefty  - the selectedIndex of the #batterSelect element.
function generateZone() {
  const yMin = 0;
  const yMax = 4;
  const xMin = -2;
  const xMax = 2;
  const length = 50;

  const zoneData = [];
  const zoneCoordinates = [];
  const isLeftHanded = document.getElementById("batterSelect").selectedIndex;

  // These two inputs are identical for every grid point, so normalize them once.
  const szTop = normalize(3.5, CONSTANTS.SZ_TOP_MIN, CONSTANTS.SZ_TOP_MAX);
  const szBot = normalize(1.5, CONSTANTS.SZ_BOT_MIN, CONSTANTS.SZ_BOT_MAX);

  // Walk the grid with integer counters and derive each coordinate from them.
  // Repeatedly adding a fractional step (e.g. x += 0.08) accumulates
  // floating-point error. That error can push the final row/column past the
  // `<=` bound and drop it, or leave edge coordinates slightly off the boundary.
  for (let row = 0; row <= length; row++) {
    const y = yMax - ((yMax - yMin) * row) / length;
    for (let col = 0; col <= length; col++) {
      const x = xMin + ((xMax - xMin) * col) / length;
      zoneData.push(normalize(x, CONSTANTS.PX_MIN, CONSTANTS.PX_MAX));
      zoneData.push(normalize(y, CONSTANTS.PZ_MIN, CONSTANTS.PZ_MAX));
      zoneData.push(szTop);
      zoneData.push(szBot);
      zoneData.push(isLeftHanded);
      zoneCoordinates.push({x: x, y: y});
    }
  }
  return {
    data: tf.tensor2d(zoneData, [zoneCoordinates.length, 5]),
    coords: zoneCoordinates,
    lefty: isLeftHanded
  };
}
Insert cell
Insert cell
Insert cell
function predictZone() {
const predictions = model.predictOnBatch(mutable zone.data);
const values = predictions.dataSync();

// Sort each value so the higher prediction is the first element in the array:
const results = [];
let index = 0;
for (let i = 0; i < values.length; i++) {
let list = [];
list.push({value: values[index++], strike: 0});
list.push({value: values[index++], strike: 1});
list = list.sort((a, b) => b.value - a.value);
results.push(list);
}
return results;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Linear map from the y of the zone's first grid point to the y of its last
// grid point, onto [0, HEATMAP_HEIGHT / HEATMAP_SIZE].
scaleY = d3
  .scaleLinear()
  .range([0, CONSTANTS.HEATMAP_HEIGHT / CONSTANTS.HEATMAP_SIZE])
  .domain([zone.coords[0].y, zone.coords[zone.coords.length - 1].y])
Insert cell
// Linear map from [0, 4] onto [HEATMAP_HEIGHT, 0]. The reversed range places
// larger inputs at smaller y coordinates.
csvScaleY = d3
  .scaleLinear()
  .range([CONSTANTS.HEATMAP_HEIGHT, 0])
  .domain([0, 4])
Insert cell
Insert cell
Insert cell
Insert cell
// Place each CSV pitch circle at its scaled (px, pz) location with a fixed radius.
// Fill is translucent orange when is_strike is the string '1', translucent blue otherwise.
circleAttrs = circles
  .attr('cx', (d) => csvScaleX(parseFloat(d.px)))
  .attr('cy', (d) => csvScaleY(parseFloat(d.pz)))
  .attr('r', () => CONSTANTS.CSV_SIZE)
  .style('fill', (d) => (d.is_strike === '1'
    ? "rgba(255, 165, 0, 0.25)"
    : "rgba(0, 90, 255, 0.25)"))
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Type JavaScript, then Shift-Enter. Ctrl-space for more options. Arrow ↑/↓ to switch modes.

Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more