Public
Edited
Sep 2, 2020
1 fork
Importers
3 stars
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
/**
 * Train a multinomial logistic-regression classifier with TensorFlow.js.
 *
 * The model is a single dense softmax layer (L2-regularized kernel and bias)
 * trained with Adam on categorical cross-entropy.
 *
 * @param {Array<Object>} features - row objects; must carry a `columns` array
 *   naming the keys read from each row (also used, in the same order, by predict).
 * @param {Array} labels - class labels; passed through `oneHotEncode` unless they
 *   are already array rows with a `columns` property.
 * @param {boolean} [normalize=true] - standardize each feature with training-set
 *   mean/std; the same statistics are applied in predict().
 * @param {number} [epochs=50] - training epochs.
 * @param {number} [learningRate=0.1] - Adam learning rate.
 * @param {number} [batchSize=50] - mini-batch size.
 * @param {number} [l2Regularization=0.01] - L2 penalty for kernel and bias.
 * @returns {Promise<{predict: Function, dispose: Function}>} resolves after
 *   training. `predict(rows)` resolves to per-class probability rows with
 *   `.columns` set to the label columns; `dispose()` frees the model, optimizer
 *   and normalization tensors. On training failure everything is freed and the
 *   promise rejects with the original error.
 */
async function logisticRegression(features,
                                  labels,
                                  normalize = true,
                                  epochs = 50,
                                  learningRate = 0.1,
                                  batchSize = 50,
                                  l2Regularization = 0.01) {
  // Use labels as-is only when they already look one-hot encoded
  // (array rows plus a `columns` header); otherwise encode them here.
  const isEncoded = Array.isArray(labels[0]) && Boolean(labels.columns);
  const oneHotLabels = isEncoded ? labels : oneHotEncode(labels);
  const numClasses = oneHotLabels.columns.length;
  const numFeatures = features.columns.length;

  // A single softmax layer is multinomial logistic regression.
  const model = tf.sequential();
  model.add(
    tf.layers.dense({
      units: numClasses,
      activation: "softmax",
      inputShape: [numFeatures],
      useBias: true,
      kernelRegularizer: tf.regularizers.l2({l2: l2Regularization}),
      biasRegularizer: tf.regularizers.l2({l2: l2Regularization}),
    })
  );

  const optimizer = tf.train.adam(learningRate);
  model.compile({
    optimizer,
    // Softmax over one-hot targets pairs with categorical (not binary)
    // cross-entropy; this also makes "accuracy" mean categorical accuracy.
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"],
  });

  // Row objects -> dense matrix, always in the training column order.
  const featureMatrix = rows =>
    rows.map(row => features.columns.map(col => row[col]));

  // Normalization statistics, kept alive for predict() and freed by dispose().
  let mean = null;
  let std = null;

  const disposeAll = () => {
    model.dispose();
    optimizer.dispose();
    if (mean) mean.dispose();
    if (std) std.dispose();
  };

  let inputTensor = null;
  let labelTensor = null;
  try {
    labelTensor = tf.tensor2d(oneHotLabels);
    inputTensor = tf.tensor2d(featureMatrix(features));
    if (normalize) {
      ({mean, std} = tf.tidy(() => {
        const moments = tf.moments(inputTensor, 0);
        const rawStd = tf.sqrt(moments.variance);
        return {
          mean: moments.mean,
          // A constant feature has zero std; divide by 1 instead of producing NaN.
          std: tf.where(rawStd.equal(0), tf.onesLike(rawStd), rawStd),
        };
      }));
      const standardized = tf.tidy(() => inputTensor.sub(mean).div(std));
      inputTensor.dispose();
      inputTensor = standardized;
    }
    await model.fit(inputTensor, labelTensor, {
      batchSize,
      epochs,
      shuffle: true,
    });
  } catch (error) {
    // No caller will ever get a handle to dispose(), so release everything now.
    disposeAll();
    throw error;
  } finally {
    if (inputTensor) inputTensor.dispose();
    if (labelTensor) labelTensor.dispose();
  }

  return {
    predict: async featuresPredict => {
      // tidy() frees the intermediates and keeps only the returned tensor,
      // which is released explicitly once its data has been read.
      const predictionsTensor = tf.tidy(() => {
        let testTensor = tf.tensor2d(featureMatrix(featuresPredict));
        if (normalize) testTensor = testTensor.sub(mean).div(std);
        return model.predict(testTensor);
      });
      try {
        const predictions = await predictionsTensor.array();
        predictions.columns = oneHotLabels.columns;
        return predictions;
      } finally {
        predictionsTensor.dispose();
      }
    },
    dispose: disposeAll,
  };
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Train the classifier and free its tensors when this cell is invalidated.
// Errors are logged and rethrown so the cell surfaces the real failure instead
// of silently resolving to undefined (which would break `model.predict` later).
model = logisticRegression(X, y)
  .then(trained => {
    invalidation.then(() => trained.dispose());
    return trained;
  })
  .catch(error => {
    console.error(error);
    throw error;
  })
Insert cell
Insert cell
Insert cell
// Class-probability rows for X_test from the trained model. predict() returns a
// promise of an array of rows whose `.columns` is copied from the one-hot label columns.
y_prob = model.predict(X_test)
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more