/**
 * Generates captions for an image by running an image model to obtain a
 * feature vector, then repeatedly running a caption model one word at a time
 * while keeping the most likely partial sentences (beam search).
 *
 * A beam entry is an array `[states, sentence, last_word, likelihood]`:
 *   - `states`     : `{ h, c }` copies of the caption model's state outputs
 *                    (called "lstm_states" by the original author),
 *   - `sentence`   : array of word ids chosen so far,
 *   - `last_word`  : id of the most recently chosen word (`bos_id` initially),
 *   - `likelihood` : sum of `Math.log(probability)` over the chosen words.
 */
class ImageCaptionGenerator {
  /**
   * @param {object} runner_image - Runner of the image model. Must provide
   *   `getInputViews()`, `getOutputViews()` and an awaitable `run()`.
   * @param {object} runner_caption - Runner of the caption model, same
   *   interface. Its input views are read in the order
   *   [image feature, word, image switch, word switch, h, c] and its output
   *   views in the order [word probabilities, h, c].
   * @param {object} word_data - Vocabulary description; this class reads
   *   `hidden_num`, `bos_id`, `eos_id` and `id_to_word` from it.
   */
  constructor(runner_image, runner_caption, word_data) {
    this.runner_image = runner_image;
    this.runner_caption = runner_caption;
    this.word_data = word_data;

    const image_in_views = this.runner_image.getInputViews();
    const image_out_views = this.runner_image.getOutputViews();
    this.view_image_raw_in = image_in_views[0];
    this.view_image_feature_out = image_out_views[0];

    const cap_in_views = this.runner_caption.getInputViews();
    const cap_out_views = this.runner_caption.getOutputViews();
    this.view_cap_image_feature_in = cap_in_views[0];
    this.view_cap_word_in = cap_in_views[1];
    this.view_cap_image_switch = cap_in_views[2];
    this.view_cap_word_switch = cap_in_views[3];
    this.view_cap_h_in = cap_in_views[4];
    this.view_cap_c_in = cap_in_views[5];
    this.view_cap_word_prob = cap_out_views[0];
    this.view_cap_h_out = cap_out_views[1];
    this.view_cap_c_out = cap_out_views[2];

    // All-zero and all-one vectors of length hidden_num, written to the
    // switch inputs; the zero vector doubles as the initial h/c state.
    this.switch_off = new Float32Array(word_data.hidden_num);
    this.zero_state = this.switch_off;
    this.switch_on = new Float32Array(word_data.hidden_num);
    this.switch_on.fill(1);

    this.beam_stack = null;
    this.beam_width = 10; // number of candidate sentences kept per step
    this.max_length = 20; // maximum number of words generated per sentence
  }

  /**
   * Runs the full pipeline: image -> feature -> captions.
   *
   * @param {ArrayLike<number>} image_data - Values written to the image
   *   model's first input view.
   * @returns {Promise<string[]>} Candidate captions, most likely first.
   */
  async generate_caption(image_data) {
    console.log('Extracting feature');
    const image_feature = await this.extract_image_feature(image_data);
    console.log('Initializing caption model');
    return await this.generate_caption_from_image_feature(image_feature);
  }

  /**
   * Beam-searches captions for an already extracted image feature.
   *
   * @param {ArrayLike<number>} image_feature - Feature fed to the caption
   *   model's image-feature input.
   * @returns {Promise<string[]>} Up to `beam_width` captions (words joined by
   *   a single space), sorted by likelihood in descending order.
   */
  async generate_caption_from_image_feature(image_feature) {
    await this.set_image_feature(image_feature);

    // Normalised to a number so the strict comparisons below still match
    // when the id was supplied as a numeric string.
    const eos_id = Number(this.word_data.eos_id);

    for (let step = 0; step < this.max_length; step++) {
      const next_stack = [];
      let any_updated = false;

      for (const current_status of this.beam_stack) {
        if (current_status[2] === eos_id) {
          // Finished sentence: carry it over unchanged.
          next_stack.push(current_status);
        } else {
          await this.predict_next_word(current_status, next_stack);
          any_updated = true;
        }
      }

      // Sort by likelihood, descending. Equal scores are treated as ties
      // explicitly because (-Infinity) - (-Infinity) would be NaN.
      next_stack.sort((a, b) => (a[3] === b[3] ? 0 : b[3] - a[3]));
      next_stack.splice(this.beam_width);
      this.beam_stack = next_stack;

      if (!any_updated) {
        break; // every surviving sentence has ended with EOS
      }
    }

    return this.beam_stack.map(([, word_ids]) => {
      // Strip the trailing EOS only if it is really there: a sentence that
      // hit max_length without EOS must keep its last word.
      const ends_with_eos =
        word_ids.length > 0 && word_ids[word_ids.length - 1] === eos_id;
      const visible_ids = ends_with_eos ? word_ids.slice(0, -1) : word_ids;
      return visible_ids
        .map(word_id => this.word_data.id_to_word[word_id])
        .join(' ');
    });
  }

  /**
   * Runs the image model on `image_data`.
   *
   * @param {ArrayLike<number>} image_data - Values for the image input view.
   * @returns {Promise<*>} Whatever the feature output view's `toActual()`
   *   returns after the run.
   */
  async extract_image_feature(image_data) {
    this.view_image_raw_in.set(image_data);
    await this.runner_image.run();
    return this.view_image_feature_out.toActual();
  }

  /**
   * Feeds the image feature to the caption model (image switch on, word
   * switch off, zero initial states), runs it once, and resets `beam_stack`
   * to a single empty sentence starting at `bos_id`. Afterwards the switches
   * are flipped so that later runs consume the word input instead.
   *
   * @param {ArrayLike<number>} image_feature - Feature for the caption model.
   * @returns {Promise<void>}
   */
  async set_image_feature(image_feature) {
    this.view_cap_image_feature_in.set(image_feature);
    this.view_cap_word_in.set(new Float32Array([0]));
    this.view_cap_image_switch.set(this.switch_on);
    this.view_cap_word_switch.set(this.switch_off);
    this.view_cap_h_in.set(this.zero_state);
    this.view_cap_c_in.set(this.zero_state);
    await this.runner_caption.run();

    // Entry layout: lstm_states, sentence, last_word, likelihood.
    // slice() copies the state outputs before the next run is issued
    // (presumably the views are backed by buffers the runner reuses).
    this.beam_stack = [
      [
        {
          h: this.view_cap_h_out.toActual().slice(),
          c: this.view_cap_c_out.toActual().slice()
        },
        [],
        this.word_data.bos_id,
        0.0
      ]
    ];

    this.view_cap_image_switch.set(this.switch_off);
    this.view_cap_word_switch.set(this.switch_on);
  }

  /**
   * Expands one beam entry: runs the caption model from the entry's states
   * and last word, then appends one new entry per top-`beam_width` word to
   * `next_stack`.
   *
   * @param {Array} current_status - Beam entry to expand.
   * @param {Array} next_stack - Receives the expanded entries (mutated).
   * @returns {Promise<void>}
   */
  async predict_next_word(current_status, next_stack) {
    const [states, sentence, last_word, likelihood] = current_status;

    this.view_cap_word_in.set(new Float32Array([last_word]));
    this.view_cap_h_in.set(states.h);
    this.view_cap_c_in.set(states.c);
    await this.runner_caption.run();

    // The copied states are shared by all entries pushed below; they are
    // only read afterwards.
    const next_states = {
      h: this.view_cap_h_out.toActual().slice(),
      c: this.view_cap_c_out.toActual().slice()
    };
    const word_probs = this.view_cap_word_prob.toActual();
    const top_words = WebDNN.Math.argmax(word_probs, this.beam_width);

    for (let i = 0; i < top_words.length; i++) {
      const selected_word = top_words[i];
      next_stack.push([
        { h: next_states.h, c: next_states.c },
        sentence.concat(selected_word),
        selected_word,
        likelihood + Math.log(word_probs[selected_word])
      ]);
    }
  }
}