/**
 * Multiplies two dense 2-D matrices (C = A x B), using a WebGPU compute shader
 * when WebGPU is available and falling back to `cpuMatrixMultiply2D(A, B)`
 * otherwise (no WebGPU, no adapter, or any GPU-side failure).
 *
 * The GPU path uploads both operands as 32-bit floats and accumulates in f32,
 * so its results are f32-rounded and may differ from a float64 CPU result.
 *
 * @param {number[][]} A - Left operand: `aRows` rows, each with `aCols` numbers.
 * @param {number[][]} B - Right operand: `aCols` rows, each with `bCols` numbers.
 * @returns {Promise<number[][]>} The `aRows x bCols` product as an array of rows.
 * @throws {Error} If either matrix has no rows/columns, is ragged, or the
 *   inner dimensions (`aCols` vs. `bRows`) differ.
 */
async function matrixMultiply2D(A, B) {
  // Returns [rows, cols] for a non-empty rectangular matrix; throws otherwise.
  function checkShape(M, name) {
    const rows = M.length;
    if (rows === 0) throw new Error(`Matrix ${name} has no rows.`);
    const cols = M[0].length;
    if (cols === 0) throw new Error(`Matrix ${name} has no columns.`);
    for (let r = 1; r < rows; r++) {
      if (M[r].length !== cols) {
        throw new Error(`All rows of ${name} must have the same length.`);
      }
    }
    return [rows, cols];
  }

  // Packs a rows x cols matrix into a row-major Float32Array (values rounded to f32).
  function toFloat32RowMajor(M, rows, cols) {
    const flat = new Float32Array(rows * cols);
    for (let i = 0; i < rows; i++) {
      flat.set(M[i], i * cols);
    }
    return flat;
  }

  // -- 1) Validate shapes
  const [aRows, aCols] = checkShape(A, 'A');
  const [bRows, bCols] = checkShape(B, 'B');
  if (aCols !== bRows) {
    throw new Error(
      `Dimension mismatch: A is ${aRows}x${aCols}, B is ${bRows}x${bCols}.`
    );
  }

  // -- 2) Check WebGPU availability. The `typeof` guard avoids a ReferenceError
  // in runtimes that define no global `navigator` at all.
  if (typeof navigator === 'undefined' || !navigator.gpu) {
    console.error("WebGPU not supported. Falling back to CPU multiplication.");
    return cpuMatrixMultiply2D(A, B);
  }

  let device = null;
  try {
    const adapter = await navigator.gpu.requestAdapter();
    // requestAdapter() resolves to null (rather than rejecting) when no adapter exists.
    if (adapter === null) {
      throw new Error('No WebGPU adapter available.');
    }
    device = await adapter.requestDevice();

    // WebGPU reports validation / allocation failures through error scopes,
    // not by throwing. Without these scopes an invalid submission (e.g. a matrix
    // larger than maxStorageBufferBindingSize) would silently produce zeros.
    device.pushErrorScope('out-of-memory');
    device.pushErrorScope('validation');

    const Adata = toFloat32RowMajor(A, aRows, aCols);
    const Bdata = toFloat32RowMajor(B, bRows, bCols);

    // Prepare buffers
    const bufferA = device.createBuffer({
      size: Adata.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    });
    const bufferB = device.createBuffer({
      size: Bdata.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    });
    const cRows = aRows;
    const cCols = bCols;
    const cSizeBytes = Float32Array.BYTES_PER_ELEMENT * cRows * cCols;
    const bufferC = device.createBuffer({
      size: cSizeBytes,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
    });
    // Uniform buffer {aRows, aCols, bCols, 0}; padded to a full vec4<u32>.
    const uniformBuffer = device.createBuffer({
      size: 16, // 4 x 4 bytes
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });

    // Upload data
    device.queue.writeBuffer(bufferA, 0, Adata);
    device.queue.writeBuffer(bufferB, 0, Bdata);
    device.queue.writeBuffer(uniformBuffer, 0, new Uint32Array([aRows, aCols, bCols, 0]));

    // WGSL shader: one invocation per output element C[row][col].
    const shaderCode = /* wgsl */`
@group(0) @binding(0) var<storage, read> A : array<f32>;
@group(0) @binding(1) var<storage, read> B : array<f32>;
@group(0) @binding(2) var<storage, read_write> C : array<f32>;
@group(0) @binding(3) var<uniform> dims : vec4<u32>;
// dims.x = aRows, dims.y = aCols, dims.z = bCols
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
let aRows = dims.x;
let aCols = dims.y;
let bCols = dims.z;
let row = gid.y;
let col = gid.x;
if (row < aRows && col < bCols) {
var sum = 0.0;
for (var k = 0u; k < aCols; k++) {
sum += A[row * aCols + k] * B[k * bCols + col];
}
C[row * bCols + col] = sum;
}
}
`;
    const shaderModule = device.createShaderModule({ code: shaderCode });

    // Pipeline
    const pipeline = device.createComputePipeline({
      layout: 'auto',
      compute: {
        module: shaderModule,
        entryPoint: 'main',
      },
    });

    // Bind group
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: bufferA } },
        { binding: 1, resource: { buffer: bufferB } },
        { binding: 2, resource: { buffer: bufferC } },
        { binding: 3, resource: { buffer: uniformBuffer } },
      ],
    });

    // Encode commands; must match @workgroup_size(16, 16) in the shader.
    const workgroupSize = 16;
    const commandEncoder = device.createCommandEncoder();
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(pipeline);
    passEncoder.setBindGroup(0, bindGroup);
    passEncoder.dispatchWorkgroups(
      Math.ceil(cCols / workgroupSize),
      Math.ceil(cRows / workgroupSize)
    );
    passEncoder.end();

    // Copy back to a CPU-readable buffer
    const readBuffer = device.createBuffer({
      size: cSizeBytes,
      usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
    });
    commandEncoder.copyBufferToBuffer(bufferC, 0, readBuffer, 0, cSizeBytes);

    // Submit
    device.queue.submit([commandEncoder.finish()]);

    // Scopes pop in LIFO order: validation first, then out-of-memory.
    const validationError = await device.popErrorScope();
    const oomError = await device.popErrorScope();
    const gpuError = validationError || oomError;
    if (gpuError) {
      throw new Error(`WebGPU reported an error: ${gpuError.message}`, { cause: gpuError });
    }

    await readBuffer.mapAsync(GPUMapMode.READ);
    // Copy out of the mapped range before unmap() detaches it.
    const result = new Float32Array(readBuffer.getMappedRange().slice(0));
    readBuffer.unmap();

    // Convert the row-major result back to an array of rows.
    const C = [];
    for (let i = 0; i < cRows; i++) {
      C.push(Array.from(result.subarray(i * cCols, (i + 1) * cCols)));
    }
    return C;
  } catch (err) {
    console.error("WebGPU error:", err, "Falling back to CPU multiplication.");
    // If anything fails on GPU side, fallback:
    return cpuMatrixMultiply2D(A, B);
  } finally {
    // The device is created per call and never reused; destroying it lets the
    // implementation release the buffers allocated from it.
    if (device) {
      device.destroy();
    }
  }
}