{
const { context, adapter, device } = await gpu.init(1, 1)
const scaleShaderModule = device.createShaderModule({
label: 'scale shader module',
code: `
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
// should come up with a shader architecture that allows easy composition
fn scale(value: f32) -> f32 {
return sin(value);
}
@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
data[i] = scale(data[i]);
}
`
})
const extentShaderModule = device.createShaderModule({
label: 'extent quantization compute module',
code: `
@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read_write> quantized_extent: array<atomic<i32>>;
const QUANTIZE_FACTOR = 32768.0;
fn updateExtent(index: u32, value: f32) {
let quantizedValue = i32(value * QUANTIZE_FACTOR);
if (index == 0) {
atomicStore(&quantized_extent[0], quantizedValue);
atomicStore(&quantized_extent[1], quantizedValue);
}
atomicMin(&quantized_extent[0], quantizedValue);
atomicMax(&quantized_extent[1], quantizedValue);
}
@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
updateExtent(i, data[i]);
}
`,
})
const extentDequantShaderModule = device.createShaderModule({
label: 'extent dequantization compute module',
code: `
@group(0) @binding(0) var<storage, read> quantized_extent: array<i32>;
@group(0) @binding(1) var<storage, read_write> extent: array<f32, 2>;
const DEQUANTIZE_FACTOR = 1.0 / 32768.0;
fn getExtent(index: u32) -> vec2f {
// Loads the quantized normal values into a vector and dequantizes them.
return vec2f(f32(quantized_extent[0]), f32(quantized_extent[1])) * DEQUANTIZE_FACTOR;
}
@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
let ext = getExtent(i);
extent = array(ext.x, ext.y);
}
`,
})
const scalePipeline = device.createComputePipeline({
label: 'scale compute pipeline',
layout: 'auto',
compute: {
module: scaleShaderModule,
entryPoint: 'cs',
},
})
const extentPipeline = device.createComputePipeline({
label: 'extent quant compute pipeline',
layout: 'auto',
compute: {
module: extentShaderModule,
entryPoint: 'cs',
},
})
const extentDequantPipeline = device.createComputePipeline({
label: 'extent dequant compute pipeline',
layout: 'auto',
compute: {
module: extentDequantShaderModule,
entryPoint: 'cs',
},
})
const dataBuffer = device.createBuffer({
label: 'data buffer',
size: input.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(dataBuffer, 0, input);
const extentQuantBuffer = device.createBuffer({
size: 2 * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})
const extentDequantBuffer = device.createBuffer({
size: 2 * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})
const extentResultBuffer = device.createBuffer({
label: 'extent dequant result buffer',
size: extentDequantBuffer.size,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
})
const dataResultBuffer = device.createBuffer({
label: 'data result buffer',
size: dataBuffer.size,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
})
const scaleBindGroup = device.createBindGroup({
label: 'scale bind group',
layout: scalePipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: dataBuffer } },
],
})
const extentBindGroup = device.createBindGroup({
label: 'extent quant bind group',
layout: extentPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: dataBuffer } },
{ binding: 1, resource: { buffer: extentQuantBuffer } },
],
})
const extentDequantBindGroup = device.createBindGroup({
label: 'extent dequant bind group',
layout: extentDequantPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: extentQuantBuffer } },
{ binding: 1, resource: { buffer: extentDequantBuffer } },
],
})
const encoder = device.createCommandEncoder({ label: 'encoder' })
const pass = encoder.beginComputePass({ label: 'compute pass' })
pass.setPipeline(scalePipeline)
pass.setBindGroup(0, scaleBindGroup)
pass.dispatchWorkgroups(input.length)
pass.setPipeline(extentPipeline)
pass.setBindGroup(0, extentBindGroup)
pass.dispatchWorkgroups(input.length)
pass.setPipeline(extentDequantPipeline)
pass.setBindGroup(0, extentDequantBindGroup)
pass.dispatchWorkgroups(input.length)
pass.end()
encoder.copyBufferToBuffer(dataBuffer, 0, dataResultBuffer, 0, dataResultBuffer.size)
encoder.copyBufferToBuffer(extentDequantBuffer, 0, extentResultBuffer, 0, extentResultBuffer.size)
device.queue.submit([encoder.finish()])
const scaled = await gpu.readBuffer(dataResultBuffer)
const extent = await gpu.readBuffer(extentResultBuffer)
return htl.html`${[
Inputs.table([['inputs', ...input], ['scaled', ...scaled]]),
Inputs.table([{ '(scaled) min': extent[0], '(scaled) max': extent[1] }], { columns: ['(scaled) min', '(scaled) max'] })
]}`
}