/*
 * MultiD: autograd node for the matrix product out = x · y.
 * Built inside an immediately-invoked function so the constructor and its
 * prototype are set up once and the constructor itself is what gets assigned.
 */
MultiD = (function () {
  /**
   * Creates the product node and runs the forward pass eagerly.
   *
   * Operands are read through their fields n (rows), d (cols) and a flat
   * row-major buffer `out` (element [r, c] at index d * r + c), plus
   * `dout` / `require_grad` / optional `backward` during backprop.
   *
   * @param {Object} x - left operand (x.n × x.d).
   * @param {Object} y - right operand (y.n × y.d); y.n must equal x.d.
   * @throws via the external `assert` when x.d !== y.n.
   */
  function MultiD(x, y) {
    assert(x.d === y.n, "matmul dimension misaligned");
    this.n = x.n;
    this.d = y.d;
    this.x = x;
    this.y = y;
    this.require_grad = true;
    this.items = new Mat(this.n, this.d);
    // out/dout alias the backing Mat's buffers; they must stay shared.
    this.out = this.items.out;
    this.dout = this.items.dout;
    this.func_name = "<Multiply>";
    for (var i = 0; i < x.n; i++) {
      for (var j = 0; j < y.d; j++) {
        var dot = 0.0;
        for (var k = 0; k < x.d; k++) {
          dot += x.out[x.d * i + k] * y.out[y.d * k + j];
        }
        this.out[this.d * i + j] = dot;
      }
    }
  }

  MultiD.prototype = {
    /**
     * Accumulates this node's upstream gradient G (= this.dout) into the
     * operands, then recurses into operands that define backward():
     *   x.dout[i,k] += sum_j G[i,j] * y[k,j]   (G · yᵀ)
     *   y.dout[k,j] += sum_i x[i,k] * G[i,j]   (xᵀ · G)
     * Both gradients are accumulated before any recursion. When x and y are
     * the same node (A·A), backward() is called on it only once, so the
     * combined gradient is propagated exactly once.
     */
    backward: function () {
      var x = this.x;
      var y = this.y;
      var G = this.dout;
      var i, j, k, g;
      if (x.require_grad) {
        for (i = 0; i < x.n; i++) {
          for (j = 0; j < y.d; j++) {
            g = G[y.d * i + j]; // invariant over k
            for (k = 0; k < x.d; k++) {
              x.dout[x.d * i + k] += y.out[y.d * k + j] * g;
            }
          }
        }
      }
      if (y.require_grad) {
        for (i = 0; i < x.n; i++) {
          for (j = 0; j < y.d; j++) {
            g = G[y.d * i + j]; // invariant over k
            for (k = 0; k < x.d; k++) {
              y.dout[y.d * k + j] += x.out[x.d * i + k] * g;
            }
          }
        }
      }
      if (x.require_grad && "backward" in x) {
        x.backward();
      }
      // Skip the second call for a self-product; x already propagated it all.
      if (y.require_grad && y !== x && "backward" in y) {
        y.backward();
      }
    },

    /**
     * Seeds this node's upstream gradient with the values of g.
     * Copies element-wise into the existing buffer (instead of rebinding
     * this.dout) so this.dout and this.items.dout remain the same array.
     *
     * @param {ArrayLike<number>} g - gradient, same length as this.dout.
     */
    grad: function (g) {
      assert(this.dout.length === g.length, "grad length mismatch");
      for (var i = 0; i < g.length; i++) {
        this.dout[i] = g[i];
      }
    }
  };

  return MultiD;
})();