The Zmv extension is defined to provide matrix support with the RISC-V Vector "V" extension.
The Zmv extension allows matrix tile slices to be loaded into vector registers and stored from them, and allows data to be moved between slices of a matrix register and vector registers.
The data layout examples of registers and memory in Zmv are shown below.
# vd destination, rs1 base address, rs2 row byte stride
# lmul / (eew/sew) rows or columns
# for left matrix, a
mlae8.v vd, (rs1), rs2 # 8-bit tile slices load to vregs
mlae16.v vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlae32.v vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlae64.v vd, (rs1), rs2 # 64-bit tile slices load to vregs
# for right matrix, b
mlbe8.v vd, (rs1), rs2 # 8-bit tile slices load to vregs
mlbe16.v vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlbe32.v vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlbe64.v vd, (rs1), rs2 # 64-bit tile slices load to vregs
# for output matrix, c
mlce8.v vd, (rs1), rs2 # 8-bit tile slices load to vregs
mlce16.v vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlce32.v vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlce64.v vd, (rs1), rs2 # 64-bit tile slices load to vregs
# vs3 store data, rs1 base address, rs2 row byte stride
# lmul / (eew/sew) rows or columns
# for left matrix, a
msae8.v vs3, (rs1), rs2 # 8-bit tile slices store from vregs
msae16.v vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msae32.v vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msae64.v vs3, (rs1), rs2 # 64-bit tile slices store from vregs
# for right matrix, b
msbe8.v vs3, (rs1), rs2 # 8-bit tile slices store from vregs
msbe16.v vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msbe32.v vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msbe64.v vs3, (rs1), rs2 # 64-bit tile slices store from vregs
# for output matrix, c
msce8.v vs3, (rs1), rs2 # 8-bit tile slices store from vregs
msce16.v vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msce32.v vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msce64.v vs3, (rs1), rs2 # 64-bit tile slices store from vregs
For data movement between vector and matrix registers, the vector's vsew must equal the matrix's msew.
The number of elements moved is min(VLEN/SEW * VLMUL, matrix_size).
- For matrix A, matrix_size = mtilem * mtilek.
- For matrix B, matrix_size = mtilek * mtilen.
- For matrix C, matrix_size = mtilem * mtilen.
# Data move between matrix register rows and vector registers.
# vd[(i - rs2) * mtilek + j] = ms1[i, j], i = rs2 .. rs2 + mtilem - 1
mmvare8.v.m vd, ms1, rs2
mmvare16.v.m vd, ms1, rs2
mmvare32.v.m vd, ms1, rs2
mmvare64.v.m vd, ms1, rs2
# vd[(i - rs2) * mtilen + j] = ms1[i, j], i = rs2 .. rs2 + mtilek - 1
mmvbre8.v.m vd, ms1, rs2
mmvbre16.v.m vd, ms1, rs2
mmvbre32.v.m vd, ms1, rs2
mmvbre64.v.m vd, ms1, rs2
# vd[(i - rs2) * mtilen + j] = ms1[i, j], i = rs2 .. rs2 + mtilem - 1
mmvcre8.v.m vd, ms1, rs2
mmvcre16.v.m vd, ms1, rs2
mmvcre32.v.m vd, ms1, rs2
mmvcre64.v.m vd, ms1, rs2
# md[i, j] = vs1[(i - rs2) * mtilek + j], i = rs2 .. rs2 + mtilem - 1
mmvare8.m.v md, vs1, rs2
mmvare16.m.v md, vs1, rs2
mmvare32.m.v md, vs1, rs2
mmvare64.m.v md, vs1, rs2
# md[i, j] = vs1[(i - rs2) * mtilen + j], i = rs2 .. rs2 + mtilek - 1
mmvbre8.m.v md, vs1, rs2
mmvbre16.m.v md, vs1, rs2
mmvbre32.m.v md, vs1, rs2
mmvbre64.m.v md, vs1, rs2
# md[i, j] = vs1[(i - rs2) * mtilen + j], i = rs2 .. rs2 + mtilem - 1
mmvcre8.m.v md, vs1, rs2
mmvcre16.m.v md, vs1, rs2
mmvcre32.m.v md, vs1, rs2
mmvcre64.m.v md, vs1, rs2
# Data move between matrix register columns and vector registers.
# vd[(j - rs2) * mtilem + i] = ms1[i, j], j = rs2 .. rs2 + mtilek - 1
mmvace8.v.m vd, ms1, rs2
mmvace16.v.m vd, ms1, rs2
mmvace32.v.m vd, ms1, rs2
mmvace64.v.m vd, ms1, rs2
# vd[(j - rs2) * mtilek + i] = ms1[i, j], j = rs2 .. rs2 + mtilen - 1
mmvbce8.v.m vd, ms1, rs2
mmvbce16.v.m vd, ms1, rs2
mmvbce32.v.m vd, ms1, rs2
mmvbce64.v.m vd, ms1, rs2
# vd[(j - rs2) * mtilem + i] = ms1[i, j], j = rs2 .. rs2 + mtilen - 1
mmvcce8.v.m vd, ms1, rs2
mmvcce16.v.m vd, ms1, rs2
mmvcce32.v.m vd, ms1, rs2
mmvcce64.v.m vd, ms1, rs2
# md[i, j] = vs1[(j - rs2) * mtilem + i], j = rs2 .. rs2 + mtilek - 1
mmvace8.m.v md, vs1, rs2
mmvace16.m.v md, vs1, rs2
mmvace32.m.v md, vs1, rs2
mmvace64.m.v md, vs1, rs2
# md[i, j] = vs1[(j - rs2) * mtilek + i], j = rs2 .. rs2 + mtilen - 1
mmvbce8.m.v md, vs1, rs2
mmvbce16.m.v md, vs1, rs2
mmvbce32.m.v md, vs1, rs2
mmvbce64.m.v md, vs1, rs2
# md[i, j] = vs1[(j - rs2) * mtilem + i], j = rs2 .. rs2 + mtilen - 1
mmvcce8.m.v md, vs1, rs2
mmvcce16.m.v md, vs1, rs2
mmvcce32.m.v md, vs1, rs2
mmvcce64.m.v md, vs1, rs2
void fused_matmul_relu_float16(c, a, b, m, k, n) {
msettype(e16); // use 16bit input matrix element
for (i = 0; i < m; i += tile_m) { // loop at dim m with tiling
tile_m = msettilem(m-i);
for (j = 0; j < n; j += tile_n) { // loop at dim n with tiling
tile_n = msettilen(n-j);
out = mwsub_mm(out, out) // clear acc reg
for (s = 0; s < k; s += tile_k) { // loop at dim k with tiling
tile_k = msettilek(k-s);
tr1 = mlae16_m(&a[i][s]); // load left matrix a
tr2 = mlbe16_m(&b[s][j]); // load right matrix b
out = mfwma_mm(tr1, tr2); // tiled matrix multiply,
// double widen output
}
out = mfncvt_f_fw_m(out, m2); // convert widen result to single
for (s = 0; s < tile_m; s += rows) {
// max rows could move into 8 vregs
rows = min(tile_m - s, 8*vlenb/rlenb);
vsetvl(tile_n*rows, e16, m8);
v1 = mmvcr_v_m(out, s); // move out rows to vreg
v1 = vfmax_vf(0.f, v1); // vfmax.vf for relu
msce16_v(v1, &c[i+s][j], n); // store output tile slices
}
}
}
}