renderable = {
const { device, context, format } = await gpu.init(image_dims[0], image_dims[1])
const module = device.createShaderModule({
code: `
struct Uniforms {
mat: mat4x4f,
slice_index: f32,
}
struct VertexOut {
@builtin(position) position: vec4f,
@location(0) texcoord: vec2f,
};
@group(0) @binding(0) var texSampler: sampler;
@group(0) @binding(1) var tex: texture_3d<f32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
@group(0) @binding(3) var colorPalette: texture_2d<f32>;
@vertex
fn vs(
@builtin(vertex_index) vertexIndex : u32
) -> VertexOut {
let pos = array(
vec2f(0.0, 0.0),
vec2f(1.0, 0.0),
vec2f(0.0, 1.0),
vec2f(0.0, 1.0),
vec2f(1.0, 0.0),
vec2f(1.0, 1.0),
);
let texcoord = pos[vertexIndex];
return VertexOut(
vec4f((texcoord - 0.5) * 2, 0.0, 1.0),
texcoord
);
}
@fragment
fn fs(vout: VertexOut) -> @location(0) vec4f {
// derive the 3d texcoord using 2d texcoord & slice index
let texCoord3d = vec3(vout.texcoord, uniforms.slice_index);
// get (normalized) color from brain scan
let c = textureSample(tex, texSampler, texCoord3d);
// if no data, discard
if (c.r == 0.0) {
discard;
}
// use red channel for (normalized) x-position for palette lookup
let pvec = vec2(c.r, 1.0);
// sample the color from the palette
let color = textureSample(colorPalette, texSampler, pvec);
return color;
}
`,
});
//////////////////////////////////////////////////////
// color palette
//////////////////////////////////////////////////////
const colorTexture = device.createTexture({
format: 'rgba8unorm',
size: [colorPaletteImage.width, colorPaletteImage.height],
usage:
GPUTextureUsage.TEXTURE_BINDING |
GPUTextureUsage.COPY_DST |
GPUTextureUsage.RENDER_ATTACHMENT
})
device.queue.copyExternalImageToTexture(
{ source: colorPaletteImage },
{ texture: colorTexture },
{ width: colorPaletteImage.width, height: colorPaletteImage.height }
);
//////////////////////////////////////////////////////
// 3D texture
//////////////////////////////////////////////////////
const texture = device.createTexture({
dimension: '3d',
size: image_dims,
format: 'r8unorm',
usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST,
});
device.queue.writeTexture(
{ texture, flipY: true },
new Uint8Array(imageArrayBuffer),
{ bytesPerRow: image_dims[0], rowsPerImage: image_dims[1] },
image_dims
);
//////////////////////////////////////////////////////
// pipeline
//////////////////////////////////////////////////////
const pipeline = device.createRenderPipeline({
layout: 'auto',
vertex: {
module,
entryPoint: 'vs',
},
fragment: {
module,
entryPoint: 'fs',
targets: [{ format }],
},
primitive: {
topology: 'triangle-list',
cullMode: 'back',
},
});
const sampler = device.createSampler({
addressModeU: 'clamp-to-edge',
addressMoveV: 'clamp-to-edge',
magFilter: 'nearest'
});
//////////////////////////////////////////////////////
// uniforms
//////////////////////////////////////////////////////
const uniformBufferSize =
4 * 16 + // 4x4 matrix
4 * 1 + // slice index
4 * 3; // padding (min size is 80, at least on my gpu/browser)
// transform matrix just a placeholder for now
const uniforms = new Float32Array([...util.arr(16, 0), INITIAL_SLICE_INDEX])
const uniformBuffer = device.createBuffer({
size: uniformBufferSize,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
})
const bindGroup = device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: sampler },
{ binding: 1, resource: texture.createView() },
{ binding: 2, resource: { buffer: uniformBuffer }},
{ binding: 3, resource: colorTexture.createView() },
]
});
const renderPassDescriptor = {
colorAttachments: [
{
clearValue: [0, 0, 0, 1],
loadOp: 'clear',
storeOp: 'store',
},
],
};
function render(nextUniforms = {}) {
if (nextUniforms.slice_index) {
uniforms.set([nextUniforms.slice_index / (image_dims[2] - 1)], 16)
device.queue.writeBuffer(uniformBuffer, 0, uniforms);
}
renderPassDescriptor.colorAttachments[0].view = context.getCurrentTexture().createView();
const encoder = device.createCommandEncoder();
const pass = encoder.beginRenderPass(renderPassDescriptor);
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.draw(6); // call vert shader 6 times
pass.end();
device.queue.submit([encoder.finish()]);
}
render()
return { render, canvas: context.canvas }
}