Public
Edited
Nov 10, 2024
Insert cell
Insert cell
Insert cell
Insert cell
{
const { context, adapter, device } = await gpu.init(1, 1)
// ==================================================
// shaders
// ==================================================
// this shader uses atomics which only work with u32 and i32. to
// get the extent of float values, we have to calculate them using
// int types, output them in a buffer, then pass that buffer
// to another shader to convert it to f32s in a subsequent shader pass.
const extentShaderModule = device.createShaderModule({
label: 'extent quantization compute module',
code: `
@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read_write> quantized_extent: array<atomic<i32>>;

const QUANTIZE_FACTOR = 32768.0;
fn updateExtent(index: u32, value: f32) {
let quantizedValue = i32(value * QUANTIZE_FACTOR);
if (index == 0) {
atomicStore(&quantized_extent[0], quantizedValue);
atomicStore(&quantized_extent[1], quantizedValue);
}
atomicMin(&quantized_extent[0], quantizedValue);
atomicMax(&quantized_extent[1], quantizedValue);
}

@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
updateExtent(i, data[i]);
}
`,
})

const extentDequantShaderModule = device.createShaderModule({
label: 'extent dequantization compute module',
code: `
@group(0) @binding(0) var<storage, read> quantized_extent: array<i32>;
@group(0) @binding(1) var<storage, read_write> extent: array<f32, 2>;

const DEQUANTIZE_FACTOR = 1.0 / 32768.0;
fn getExtent(index: u32) -> vec2f {
// Loads the quantized normal values into a vector and dequantizes them.
return vec2f(f32(quantized_extent[0]), f32(quantized_extent[1])) * DEQUANTIZE_FACTOR;
}

@compute @workgroup_size(1)
fn cs(
@builtin(global_invocation_id) id: vec3<u32>
) {
let i = id.x;
let ext = getExtent(i);
extent = array(ext.x, ext.y);
}
`,
})

// ==================================================
// pipelines
// ==================================================
const extentPipeline = device.createComputePipeline({
label: 'extent quant compute pipeline',
layout: 'auto',
compute: {
module: extentShaderModule,
entryPoint: 'cs',
},
})

const extentDequantPipeline = device.createComputePipeline({
label: 'extent dequant compute pipeline',
layout: 'auto',
compute: {
module: extentDequantShaderModule,
entryPoint: 'cs',
},
})

// ==================================================
// buffers
// ==================================================
const dataBuffer = device.createBuffer({
label: 'data buffer',
size: input.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(dataBuffer, 0, input);

// this buffer contains quantized u32 values that represent f32s.
// we will convert them to f32s in a subsequent shader pass.
const extentQuantBuffer = device.createBuffer({
size: 2 * 4, // min and max * 4 bytes per (u32),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})

// this contains f32 values after they've been converted in the
// dequantization shader
const extentDequantBuffer = device.createBuffer({
size: 2 * 4, // min and max * 4 bytes per (f32),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})

// mapped extent results buffer that lets us read gpu-derived values
// on the cpu
const resultBuffer = device.createBuffer({
label: 'extent dequant result buffer',
size: extentDequantBuffer.size,
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
})

// ==================================================
// bind groups
// ==================================================
const extentBindGroup = device.createBindGroup({
label: 'extent quant bind group',
layout: extentPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: dataBuffer } },
{ binding: 1, resource: { buffer: extentQuantBuffer } },
],
})

const extentDequantBindGroup = device.createBindGroup({
label: 'extent dequant bind group',
layout: extentDequantPipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: extentQuantBuffer } },
{ binding: 1, resource: { buffer: extentDequantBuffer } },
],
})

// ==================================================
// render
// ==================================================
const encoder = device.createCommandEncoder({ label: 'encoder' })
const pass = encoder.beginComputePass({ label: 'compute pass' })

pass.setPipeline(extentPipeline)
pass.setBindGroup(0, extentBindGroup)
pass.dispatchWorkgroups(input.length)

pass.setPipeline(extentDequantPipeline)
pass.setBindGroup(0, extentDequantBindGroup)
pass.dispatchWorkgroups(input.length)
pass.end()

encoder.copyBufferToBuffer(extentDequantBuffer, 0, resultBuffer, 0, resultBuffer.size)

device.queue.submit([encoder.finish()])

// ==================================================
// gpu -> cpu
// ==================================================

const output = await gpu.readBuffer(resultBuffer)
return JSON.stringify({ input, min_and_max_output_from_compute_shader: output }, null, 2)
}
Insert cell
gpu = ({
  /**
   * Creates a detached canvas of the given size, requests a WebGPU adapter and
   * device, and configures the canvas context with the preferred canvas format.
   *
   * @param {number} [width=512] - Canvas width in pixels.
   * @param {number} [height=512] - Canvas height in pixels.
   * @returns {Promise<{context: GPUCanvasContext, adapter: GPUAdapter, device: GPUDevice}>}
   * @throws {Error} If WebGPU is unavailable, the canvas cannot provide a
   *   'webgpu' context, or no adapter is returned.
   */
  init: async (width = 512, height = 512) => {
    if (typeof navigator === 'undefined' || !navigator.gpu) {
      throw new Error('WebGPU is not supported in this environment (navigator.gpu is missing).');
    }

    const canvas = document.createElement('canvas');
    canvas.width = width;
    canvas.height = height;
    const context = canvas.getContext('webgpu');
    if (!context) {
      throw new Error('WebGPU canvas context could not be created.');
    }

    // requestAdapter() resolves to null when no suitable adapter exists.
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error('WebGPU adapter could not be acquired.');
    }
    const device = await adapter.requestDevice();
    const format = navigator.gpu.getPreferredCanvasFormat();
    context.configure({ device, format });
    return { context, adapter, device };
  },

  /**
   * Maps `buffer` for reading and returns a copy of its contents as a typed
   * array. The mapped range is copied before unmapping, so the result does not
   * alias the buffer's mapped memory.
   *
   * @param {GPUBuffer} buffer - Buffer to map with GPUMapMode.READ.
   * @param {Function} [ArrayType=Float32Array] - Typed-array constructor used
   *   to view the copied bytes.
   * @returns {Promise<ArrayBufferView>} A new `ArrayType` over the copied bytes.
   */
  readBuffer: async (buffer, ArrayType = Float32Array) => {
    await buffer.mapAsync(GPUMapMode.READ);
    try {
      return new ArrayType(buffer.getMappedRange().slice(0));
    } finally {
      // Always unmap, even if reading the range or building the view throws,
      // so the buffer is not left stuck in the mapped state.
      buffer.unmap();
    }
  },
})
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more