Published
Edited
Jun 9, 2018
2 forks
30 stars
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
model = {
const model = tf.sequential();
// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));
// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(0.01),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
return model;
}
Insert cell
Insert cell
Insert cell
data = {
const data = [];
csvData.forEach((values) => {
// 'logit' data uses the 5 fields:
const x = [];
x.push(parseFloat(values.px));
x.push(parseFloat(values.pz));
x.push(parseFloat(values.sz_top));
x.push(parseFloat(values.sz_bot));
x.push(parseFloat(values.left_handed_batter));
// The label is simply 'is strike' or 'is ball':
const y = parseInt(values.is_strike, 10);
data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);
return data;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
const history = await model.fit(batches[index].x, batches[index].y, {
epochs: 1,
shuffle: false,
validationData: [batches[index].x, batches[index].y],
batchSize: CONSTANTS.BATCH_SIZE
});
mutable step++;
mutable loss = history.history.loss[0];
mutable accuracy = history.history.acc[0];
await tf.nextFrame();
updateHeatmap();
await tf.nextFrame();
}
Insert cell
Insert cell
// Trains one epoch and updates the heatmap after each batch:
async function runEpochTrainAndVisual() {
mutable isTraining = true;
for (let i = 0; i < batches.length; i++) {
await trainBatch(i);
// The tf.nextFrame() helper function returns a Promise when requestAnimationFrame()
// has completed. These calls ensure that the single-threaded JS event loop is not
// blocked during training:
await tf.nextFrame();

updateHeatmap();
await tf.nextFrame();
}
mutable isTraining = false;
}
Insert cell
// Function that ensures the model is trained to the number of epochs selected
// by the user in this codelab:
async function runTraining() {
while (mutable epoch < mutable totalEpochs) {
await runEpochTrainAndVisual();
mutable epoch++;
}
}
Insert cell
Insert cell
Insert cell
Insert cell
// Builds the grid of hypothetical pitch locations that the heatmap is drawn from.
//
// Samples an inclusive (length + 1) x (length + 1) grid: x runs from xMin to
// xMax within a row, and rows run from y = yMax down to y = yMin. Every point
// is paired with a fixed strike zone (top 3.5, bottom 1.5) and the current
// selectedIndex of the #batterSelect element. Features are emitted in the same
// order the `data` cell uses (px, pz, sz_top, sz_bot, left_handed_batter).
//
// Returns {data, coords, lefty}:
//   data   - tf.tensor2d of shape [coords.length, 5] holding the normalized inputs
//   coords - the un-normalized {x, y} of each row of `data`, in the same order
//   lefty  - the batterSelect selectedIndex used as the fifth feature
function generateZone() {
  const yMin = 0;
  const yMax = 4;
  const xMin = -2;
  const xMax = 2;
  const length = 50;

  const zoneData = [];
  const zoneCoordinates = [];
  const isLeftHanded = document.getElementById("batterSelect").selectedIndex;

  // These two features are the same for every grid point; normalize them once.
  const szTop = normalize(3.5, CONSTANTS.SZ_TOP_MIN, CONSTANTS.SZ_TOP_MAX);
  const szBot = normalize(1.5, CONSTANTS.SZ_BOT_MIN, CONSTANTS.SZ_BOT_MAX);

  // Iterate with integer counters and derive each coordinate from its index.
  // Repeatedly adding a float step (0.08) accumulates rounding error, which can
  // make the inclusive end-point comparisons silently drop the last row or column.
  for (let row = 0; row <= length; row++) {
    const y = yMax - (row * (yMax - yMin)) / length;
    for (let col = 0; col <= length; col++) {
      const x = xMin + (col * (xMax - xMin)) / length;
      zoneData.push(
        normalize(x, CONSTANTS.PX_MIN, CONSTANTS.PX_MAX),
        normalize(y, CONSTANTS.PZ_MIN, CONSTANTS.PZ_MAX),
        szTop,
        szBot,
        isLeftHanded
      );
      zoneCoordinates.push({x: x, y: y});
    }
  }
  return {
    data: tf.tensor2d(zoneData, [zoneCoordinates.length, 5]),
    coords: zoneCoordinates,
    lefty: isLeftHanded
  };
}
Insert cell
Insert cell
Insert cell
function predictZone() {
const predictions = model.predictOnBatch(mutable zone.data);
const values = predictions.dataSync();

// Sort each value so the higher prediction is the first element in the array:
const results = [];
let index = 0;
for (let i = 0; i < values.length; i++) {
let list = [];
list.push({value: values[index++], strike: 0});
list.push({value: values[index++], strike: 1});
list = list.sort((a, b) => b.value - a.value);
results.push(list);
}
return results;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Linear scale from the y-values of the first and last zone coordinates onto
// [0, HEATMAP_HEIGHT / HEATMAP_SIZE].
// NOTE(review): the range covers only HEATMAP_HEIGHT / HEATMAP_SIZE rather than
// the full heatmap height — confirm that is what the consumer expects.
scaleY = d3.scaleLinear()
  .range([0, CONSTANTS.HEATMAP_HEIGHT / CONSTANTS.HEATMAP_SIZE])
  .domain([zone.coords[0].y, zone.coords[zone.coords.length - 1].y])
Insert cell
// Maps a pitch height in [0, 4] onto [HEATMAP_HEIGHT, 0]; the range is
// inverted so larger values land nearer the top of the drawing.
csvScaleY = d3.scaleLinear()
  .range([CONSTANTS.HEATMAP_HEIGHT, 0])
  .domain([0, 4])
Insert cell
Insert cell
Insert cell
Insert cell
// Positions each CSV pitch circle from its px/pz fields and colors it by the
// is_strike field (the string '1' -> translucent orange, otherwise translucent blue).
circleAttrs = circles
  .attr('cx', (d) => csvScaleX(parseFloat(d.px)))
  .attr('cy', (d) => csvScaleY(parseFloat(d.pz)))
  .attr('r', () => CONSTANTS.CSV_SIZE)
  .style('fill', (d) => (d.is_strike === '1'
    ? "rgba(255, 165, 0, 0.25)"
    : "rgba(0, 90, 255, 0.25)"))
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Type JavaScript, then Shift-Enter. Ctrl-space for more options. Arrow ↑/↓ to switch modes.

Insert cell
Insert cell

One platform to build and deploy the best data apps

Experiment and prototype by building visualizations in live JavaScript notebooks. Collaborate with your team and decide which concepts to build out.
Use Observable Framework to build data apps locally. Use data loaders to build in any language or library, including Python, SQL, and R.
Seamlessly deploy to Observable. Test before you ship, use automatic deploy-on-commit, and ensure your projects are always up-to-date.
Learn more