{
const { context, adapter, device } = await gpu.init(1, 1)
const extentShaderModule = device.createShaderModule({
label: 'extent quantization compute module',
code: `
@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read_write> quantized_extent: array<atomic<i32>>;
const QUANTIZE_FACTOR = 32768.0;
fn updateExtent(index: u32, value: f32) {
let quantizedValue = i32(value * QUANTIZE_FACTOR);
if (index == 0) {
atomicStore(&quantized_extent[0], quantizedValue);
atomicStore(&quantized_extent[1], quantizedValue);
}
atomicMin(&quantized_extent[0], quantizedValue);
atomicMax(&quantized_extent[1], quantizedValue);
}
@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
updateExtent(i, data[i]);
}
`,
})
const extentDequantShaderModule = device.createShaderModule({
label: 'extent dequantization compute module',
code: `
@group(0) @binding(0) var<storage, read> quantized_extent: array<i32>;
@group(0) @binding(1) var<storage, read_write> extent: array<f32, 2>;
const DEQUANTIZE_FACTOR = 1.0 / 32768.0;
fn getExtent(index: u32) -> vec2f {
// Loads the quantized normal values into a vector and dequantizes them.
return vec2f(f32(quantized_extent[0]), f32(quantized_extent[1])) * DEQUANTIZE_FACTOR;
}
@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
let ext = getExtent(i);
extent = array(ext.x, ext.y);
}
`,
})
// ==================================================
// pipelines
// ==================================================
const extentPipeline = device.createComputePipeline({
label: 'extent quant compute pipeline',
layout: 'auto',
compute: {
module: extentShaderModule,
entryPoint: 'cs',
},
})
const extentDequantPipeline = device.createComputePipeline({
label: 'extent dequant compute pipeline',
layout: 'auto',
compute: {
module: extentDequantShaderModule,
entryPoint: 'cs',
},
})
// ==================================================
// buffers
// ==================================================
const dataBuffer = device.createBuffer({
label: 'data buffer',
size: input.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(dataBuffer, 0, input);
// this buffer contains quantized u32 values that represent f32s.
// we will convert them to f32s in a subsequent shader pass.
const extentQuantBuffer = device.createBuffer({
size: 2 * 4, // min and max * 4 bytes per (u32),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})
// this contains f32 values after they've been converted in the
// dequantization shader
const extentDequantBuffer = device.createBuffer({
size: 2 * 4, // min and max * 4 bytes per (f32),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})
// mapped extent results buffer that lets us read gpu-derived values
// on the cpu
const resultBuffer = device.createBuffer({
label: 'extent dequant result buffer',
size: extentDequantBuffer.size,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
})
// ==================================================
// bind groups
// ==================================================
const extentBindGroup = device.createBindGroup({
label: 'extent quant bind group',
layout: extentPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: dataBuffer } },
{ binding: 1, resource: { buffer: extentQuantBuffer } },
],
})
const extentDequantBindGroup = device.createBindGroup({
label: 'extent dequant bind group',
layout: extentDequantPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: extentQuantBuffer } },
{ binding: 1, resource: { buffer: extentDequantBuffer } },
],
})
// ==================================================
// render
// ==================================================
const encoder = device.createCommandEncoder({ label: 'encoder' })
const pass = encoder.beginComputePass({ label: 'compute pass' })
pass.setPipeline(extentPipeline)
pass.setBindGroup(0, extentBindGroup)
pass.dispatchWorkgroups(input.length)
pass.setPipeline(extentDequantPipeline)
pass.setBindGroup(0, extentDequantBindGroup)
pass.dispatchWorkgroups(input.length)
pass.end()
encoder.copyBufferToBuffer(extentDequantBuffer, 0, resultBuffer, 0, resultBuffer.size)
device.queue.submit([encoder.finish()])
// ==================================================
// gpu -> cpu
// ==================================================
const output = await gpu.readBuffer(resultBuffer)
return JSON.stringify({ input, min_and_max_output_from_compute_shader: output }, null, 2)
}