Insert cell
Insert cell
sdk = await import("https://aaronge-2020.github.io/JSTransformer/sdk.js")
// sdk = await import("http://127.0.0.1:5500/sdk.js")
Insert cell
response = await fetch('https://aaronge-2020.github.io/JSTransformer/curated_dataset.json');

Insert cell
jsonData = await response.json();

Insert cell
processedData = sdk.processJson(jsonData.slice(start_point, start_point+numb_samples), "English", "Spanish", input_vocab_size, target_vocab_size, MAX_TOKENS)
Insert cell
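// Quick shape check before moving on: this inspection cell is illustrative and only
// assumes the { trainData, tokenizers } structure that the cells below already rely on.
({
  numTrainSentences: processedData.trainData.length,
  sampleEnglish: processedData.trainData[0].English,
  sampleSpanish: processedData.trainData[0].Spanish
})
Insert cell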
sentenceEnglishTokenized = processedData.trainData[10].English
Insert cell
sentenceSpanishTokenized = processedData.trainData[10].Spanish
Insert cell
Object.keys(processedData.tokenizers.English)
Insert cell
Object.keys(processedData.tokenizers.Spanish)
Insert cell
sdk.detokenizeSentence(sentenceEnglishTokenized, processedData.tokenizers.English)
Insert cell
sdk.detokenizeSentence(sentenceSpanishTokenized, processedData.tokenizers.Spanish)
Insert cell
tf = await import("https://esm.sh/@tensorflow/tfjs@4.10.0");

Insert cell
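// Optional sanity check: report which TF.js backend is active (e.g. webgl vs. cpu),
// since in-browser training speed depends heavily on it.
tf.getBackend()
Insert cell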
en_train = processedData.trainData.map((item) => item.English)
Insert cell
Insert cell
// en_validation = processedData.validationData.map((item) => item.English)
Insert cell
sp_train = processedData.trainData.map((item) => item.Spanish)
Insert cell
Insert cell
Insert cell
function countOccurrences(arrOfArrs, target) {
  let count = 0;

  // Loop through each sub-array
  for (let i = 0; i < arrOfArrs.length; i++) {
    const subArray = arrOfArrs[i];

    // Loop through each element in the sub-array
    for (let j = 0; j < subArray.length; j++) {
      if (subArray[j] === target) {
        count++;
      }
    }
  }

  return count;
}
Insert cell
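// Example use of countOccurrences: count how often a given token id appears across
// the tokenized Spanish training sentences. The id 0 is assumed here to be the
// padding token, purely for illustration.
countOccurrences(sp_train, 0)
Insert cell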
// sp_validation = processedData.validationData.map((item) => item.Spanish)
Insert cell
start_point = 0
Insert cell
batch_size = 32*4;
Insert cell
numb_of_batches = 125;

Insert cell
numb_samples = batch_size * numb_of_batches;
Insert cell
training_dataset = sdk.shiftTokens(sp_train, 2)
Insert cell
target_lang_input_train = training_dataset[0]

Insert cell
target_lang_label_train = training_dataset[1]

Insert cell
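// Inspection sketch: compare an original tokenized sentence with the decoder
// input/label pair produced by shiftTokens. Under teacher forcing the label is
// expected to be the decoder input shifted by one position, but the exact
// convention depends on sdk.shiftTokens, so this cell is only for eyeballing.
({
  original: sp_train[0],
  decoderInput: target_lang_input_train[0],
  label: target_lang_label_train[0]
})
Insert cell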
seq_len = target_lang_label_train[0].length; // assuming all items have the same length

Insert cell
word_probs_label_train = new Array(sp_train.length).fill(null).map(() =>
  new Array(seq_len).fill(null).map(() =>
    new Array(target_vocab_size).fill(0)
  )
);
Insert cell
// One-hot encode each label token id (ids at or above target_vocab_size are left as all-zero rows)
target_lang_label_train.forEach((batch, batchIndex) => {
  batch.forEach((token, tokenIndex) => {
    if (token < target_vocab_size) {
      word_probs_label_train[batchIndex][tokenIndex][token] = 1;
    }
  });
});
Insert cell
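// The nested-loop one-hot construction above can also be written with tf.oneHot,
// which maps each integer id to a one-hot row of length target_vocab_size and
// yields shape [numSentences, seq_len, target_vocab_size]. A sketch only; the
// hypothetical tensor below is not used by later cells, and ids at or above
// target_vocab_size should come out as all-zero rows, matching the guard above.
word_probs_label_train_tensor = tf.oneHot(
  tf.tensor(target_lang_label_train, undefined, "int32"),
  target_vocab_size
)
Insert cell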
// Note: this cell depends on sp_validation, which is currently commented out above.
word_probs_label_validation = new Array(sp_validation.length).fill(null).map(() =>
  new Array(seq_len).fill(null).map(() =>
    new Array(target_vocab_size).fill(0)
  )
);

Insert cell
// One-hot encode the validation labels (also requires sp_validation to be defined)
sp_validation.forEach((batch, batchIndex) => {
  batch.forEach((token, tokenIndex) => {
    if (token < target_vocab_size) {
      word_probs_label_validation[batchIndex][tokenIndex][token] = 1;
    }
  });
});
Insert cell
function translateProbabilityToSentence(probability, detokenizer) {
  // For each position, take the argmax over the vocabulary and look up its word
  return probability.map((sentence) =>
    sentence.map((word) => detokenizer[word.indexOf(Math.max(...word))])
  );
}
Insert cell
translateProbabilityToSentence(word_probs_label_train, processedData.tokenizers.Spanish)
Insert cell
en_train.map ((sentence) => sdk.detokenizeSentence(sentence, processedData.tokenizers.English))
Insert cell
Insert cell
// Test the Transformer
num_layers = 4

Insert cell
d_model = 128

Insert cell
dff = 128
Insert cell
numb_epochs = 500
Insert cell
num_heads = 8
Insert cell
dropout_rate = 0.1

Insert cell
input_vocab_size = 3000

Insert cell
target_vocab_size = 5500

Insert cell
MAX_TOKENS = 13;
Insert cell
transformerModel = new sdk.TransformerModel(num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, dropout_rate, input_vocab_size, target_vocab_size, MAX_TOKENS);

Insert cell
model = transformerModel.model;
Insert cell
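// Sanity check on the architecture: assuming transformerModel.model is a
// tf.LayersModel (as the compile/trainOnBatch calls below imply), summary()
// logs the layer shapes and parameter counts to the console.
model.summary()
Insert cell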
model.compile({
  loss: sdk.maskedLoss,
  optimizer: 'adam',
  metrics: sdk.maskedAccuracy
});

Insert cell
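// sdk.maskedLoss is the loss used above. For reference only (this is NOT the sdk
// implementation), a masked cross-entropy typically looks like the sketch below:
// padded positions have all-zero one-hot rows, so their mask is 0 and they drop
// out of the average.
function maskedCrossEntropySketch(yTrue, yPred) {
  return tf.tidy(() => {
    // Per-position cross-entropy: -sum(yTrue * log(softmax(yPred))) over the vocab axis
    const logProbs = tf.log(tf.softmax(yPred).add(1e-9));
    const perToken = tf.neg(tf.sum(tf.mul(yTrue, logProbs), -1));
    // 1 where a real token is present, 0 at padded positions
    const mask = tf.sum(yTrue, -1);
    return tf.sum(tf.mul(perToken, mask)).div(tf.sum(mask).add(1e-9));
  });
}
Insert cell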
en_train_batches = sdk.createMiniBatches(en_train, batch_size);
Insert cell
target_lang_input_train_batches = sdk.createMiniBatches(target_lang_input_train, batch_size);
Insert cell
word_probs_label_train_batches = sdk.createMiniBatches(word_probs_label_train, batch_size);
Insert cell
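// sdk.createMiniBatches is assumed to chunk an array into consecutive slices of
// batch_size. A hypothetical re-implementation for illustration only; the sdk
// version may differ, e.g. in how it treats a final partial batch.
function createMiniBatchesSketch(data, batchSize) {
  const batches = [];
  for (let i = 0; i < data.length; i += batchSize) {
    batches.push(data.slice(i, i + batchSize));
  }
  return batches;
}
Insert cell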
// loadedModel = await tf.loadLayersModel(
// "http://127.0.0.1:5500/my-model-epoch-70-batch-40.json"
// );
Insert cell
// loadedModel.compile({
// loss: sdk.maskedLoss,
// optimizer: "adam",
// metrics: sdk.maskedAccuracy,
// });
Insert cell
Insert cell
training = {
  for (let j = 0; j < numb_epochs; j++) {
    for (let i = 1; i < en_train_batches.length - 1; i++) {
      const train_x_batch = tf.tensor(en_train_batches[i]);
      const train_y_batch = tf.tensor(target_lang_input_train_batches[i]);
      const labels_batch = tf.tensor(word_probs_label_train_batches[i]);

      try {
        // Train the model on the current batch
        await model.trainOnBatch(
          [train_x_batch, train_y_batch],
          labels_batch
        );

        // Save the model every 20 batches
        if (i % 20 === 0) {
          await model.save('downloads://my-model' + '-epoch-' + (j + 1) + '-batch-' + i);
        }
      } catch (error) {
        console.error(error);
      }

      // Dispose tensors to free memory
      train_x_batch.dispose();
      train_y_batch.dispose();
      labels_batch.dispose();

      console.log(
        `Batch ${i + 1} completed. ${
          en_train_batches.length - i - 1
        } batches remaining.`
      );
    }

    console.log(
      `Epoch ${j + 1} completed. ${numb_epochs - j - 1} epochs remaining.`
    );
  }
}
Insert cell
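// Diagnostic only: when looping trainOnBatch in the browser it is easy to leak
// tensors. tf.memory() reports the backend's live tensor count and bytes; if
// numTensors keeps growing across batches, some intermediates are not disposed.
tf.memory()
Insert cell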
// target_lang_input_train_batches.length
Insert cell
Insert cell
Insert cell
// model.save("downloads://test-model-v2", {
// customLayers: true
// })
Insert cell
Insert cell
// loadedModelv1.weights[0].val
Insert cell
// loadedModel.input
Insert cell
English_sentences = en_train_batches[8].slice(0,10)
Insert cell
Spanish_sentences = target_lang_input_train_batches[8].slice(0,10)
Insert cell
English_sentences.map((sentence) => sdk.detokenizeSentence(sentence, processedData.tokenizers.English))
Insert cell
Spanish_sentences.map((sentence) => sdk.detokenizeSentence(sentence, processedData.tokenizers.Spanish))
Insert cell
Insert cell
translateProbabilityToSentence(predictions, processedData.tokenizers.Spanish)
Insert cell

function translate(sentence, tokenizers, transformer, MAX_TOKENS) {
  // Tokenize the English input sentence
  const encoderInput = sdk.wordsToIntTokens(sentence, tokenizers.English, MAX_TOKENS);

  // Initialize the Spanish output with the [START] token (id 1); [END] is id 2
  const startEnd = [1, 2];
  const start = startEnd[0];
  const end = startEnd[1];
  let outputArray = [start].concat(Array(MAX_TOKENS - 1).fill(0));

  // Greedy decoding: predict one token at a time, feeding the output back in
  for (let i = 0; i < MAX_TOKENS - 1; i++) {
    // Prepare encoder and decoder inputs
    const encoderInputTensor = tf.tensor([encoderInput]);
    const outputTensor = tf.tensor([outputArray]);

    // Get predictions and keep only the distribution for the current position
    const predictions = transformer.predict([encoderInputTensor, outputTensor]);
    const lastPrediction = predictions.slice([0, i, 0], [1, 1, predictions.shape[2]]);

    // Get the ID of the predicted token
    const predictedId = lastPrediction.argMax(-1).dataSync()[0];

    // Dispose intermediate tensors to avoid leaking memory
    encoderInputTensor.dispose();
    outputTensor.dispose();
    predictions.dispose();
    lastPrediction.dispose();

    // Replace the placeholder 0 at the next position with the predicted token
    outputArray[i + 1] = predictedId;

    // Stop once the [END] token is produced
    if (predictedId === end) {
      break;
    }
  }

  // Detokenize the output to get the translated sentence
  const translatedSentence = sdk.detokenizeSentence(outputArray, tokenizers.Spanish);

  return translatedSentence;
}

Insert cell
translate("Look back", processedData.tokenizers, loadedModel, 13)
Insert cell
