Skip to content

Latest commit

 

History

History
187 lines (145 loc) · 6.38 KB

ext-zmv.adoc

File metadata and controls

187 lines (145 loc) · 6.38 KB

Zmv: Matrix for Vector operations

The Zmv extension is defined to provide matrix support with the RISC-V Vector "V" extension.

The Zmv extension allows matrix tile slices to be loaded into, and stored from, vector registers, and allows data to be moved between slices of a matrix register and vector registers.

The data layout examples of registers and memory in Zmv are shown below.

Data layout examples of registers and memory

Load Instructions

# vd destination, rs1 base address, rs2 row byte stride
# lmul / (eew/sew) rows or columns

# for left matrix, a
mlae8.v    vd, (rs1), rs2 #  8-bit tile slices load to vregs
mlae16.v   vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlae32.v   vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlae64.v   vd, (rs1), rs2 # 64-bit tile slices load to vregs

# for right matrix, b
mlbe8.v    vd, (rs1), rs2 #  8-bit tile slices load to vregs
mlbe16.v   vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlbe32.v   vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlbe64.v   vd, (rs1), rs2 # 64-bit tile slices load to vregs

# for output matrix, c
mlce8.v    vd, (rs1), rs2 #  8-bit tile slices load to vregs
mlce16.v   vd, (rs1), rs2 # 16-bit tile slices load to vregs
mlce32.v   vd, (rs1), rs2 # 32-bit tile slices load to vregs
mlce64.v   vd, (rs1), rs2 # 64-bit tile slices load to vregs

Store Instructions

# vs3 store data, rs1 base address, rs2 row byte stride
# lmul / (eew/sew) rows or columns

# for left matrix, a
msae8.v    vs3, (rs1), rs2 #  8-bit tile slices store from vregs
msae16.v   vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msae32.v   vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msae64.v   vs3, (rs1), rs2 # 64-bit tile slices store from vregs

# for right matrix, b
msbe8.v    vs3, (rs1), rs2 #  8-bit tile slices store from vregs
msbe16.v   vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msbe32.v   vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msbe64.v   vs3, (rs1), rs2 # 64-bit tile slices store from vregs

# for output matrix, c
msce8.v    vs3, (rs1), rs2 #  8-bit tile slices store from vregs
msce16.v   vs3, (rs1), rs2 # 16-bit tile slices store from vregs
msce32.v   vs3, (rs1), rs2 # 32-bit tile slices store from vregs
msce64.v   vs3, (rs1), rs2 # 64-bit tile slices store from vregs

Data Move Instructions

For data moves between vector and matrix registers, the vsew of the vector must be equal to the msew of the matrix.

The number of elements moved is min(VLEN/SEW * VLMUL, matrix_size).

  • For matrix A, matrix_size = mtilem * mtilek.

  • For matrix B, matrix_size = mtilek * mtilen.

  • For matrix C, matrix_size = mtilem * mtilen.

# Data move between matrix register rows and vector registers.

# vd[(i - rs2) * mtilek + j] = ms1[i, j], i = rs2 .. rs2 + mtilem - 1
mmvare8.v.m   vd, ms1, rs2
mmvare16.v.m  vd, ms1, rs2
mmvare32.v.m  vd, ms1, rs2
mmvare64.v.m  vd, ms1, rs2

# vd[(i - rs2) * mtilen + j] = ms1[i, j], i = rs2 .. rs2 + mtilek - 1
mmvbre8.v.m   vd, ms1, rs2
mmvbre16.v.m  vd, ms1, rs2
mmvbre32.v.m  vd, ms1, rs2
mmvbre64.v.m  vd, ms1, rs2

# vd[(i - rs2) * mtilen + j] = ms1[i, j], i = rs2 .. rs2 + mtilem - 1
mmvcre8.v.m   vd, ms1, rs2
mmvcre16.v.m  vd, ms1, rs2
mmvcre32.v.m  vd, ms1, rs2
mmvcre64.v.m  vd, ms1, rs2

# md[i, j] = vs1[(i - rs2) * mtilek + j], i = rs2 .. rs2 + mtilem - 1
mmvare8.m.v   md, vs1, rs2
mmvare16.m.v  md, vs1, rs2
mmvare32.m.v  md, vs1, rs2
mmvare64.m.v  md, vs1, rs2

# md[i, j] = vs1[(i - rs2) * mtilen + j], i = rs2 .. rs2 + mtilek - 1
mmvbre8.m.v   md, vs1, rs2
mmvbre16.m.v  md, vs1, rs2
mmvbre32.m.v  md, vs1, rs2
mmvbre64.m.v  md, vs1, rs2

# md[i, j] = vs1[(i - rs2) * mtilen + j], i = rs2 .. rs2 + mtilem - 1
mmvcre8.m.v   md, vs1, rs2
mmvcre16.m.v  md, vs1, rs2
mmvcre32.m.v  md, vs1, rs2
mmvcre64.m.v  md, vs1, rs2

# Data move between matrix register columns and vector registers.

# vd[(j - rs2) * mtilem + i] = ms1[i, j], j = rs2 .. rs2 + mtilek - 1
mmvace8.v.m   vd, ms1, rs2
mmvace16.v.m  vd, ms1, rs2
mmvace32.v.m  vd, ms1, rs2
mmvace64.v.m  vd, ms1, rs2

# vd[(j - rs2) * mtilek + i] = ms1[i, j], j = rs2 .. rs2 + mtilen - 1
mmvbce8.v.m   vd, ms1, rs2
mmvbce16.v.m  vd, ms1, rs2
mmvbce32.v.m  vd, ms1, rs2
mmvbce64.v.m  vd, ms1, rs2

# vd[(j - rs2) * mtilem + i] = ms1[i, j], j = rs2 .. rs2 + mtilen - 1
mmvcce8.v.m   vd, ms1, rs2
mmvcce16.v.m  vd, ms1, rs2
mmvcce32.v.m  vd, ms1, rs2
mmvcce64.v.m  vd, ms1, rs2

# md[i, j] = vs1[(j - rs2) * mtilem + i], j = rs2 .. rs2 + mtilek - 1
mmvace8.m.v   md, vs1, rs2
mmvace16.m.v  md, vs1, rs2
mmvace32.m.v  md, vs1, rs2
mmvace64.m.v  md, vs1, rs2

# md[i, j] = vs1[(j - rs2) * mtilek + i], j = rs2 .. rs2 + mtilen - 1
mmvbce8.m.v   md, vs1, rs2
mmvbce16.m.v  md, vs1, rs2
mmvbce32.m.v  md, vs1, rs2
mmvbce64.m.v  md, vs1, rs2

# md[i, j] = vs1[(j - rs2) * mtilem + i], j = rs2 .. rs2 + mtilen - 1
mmvcce8.m.v   md, vs1, rs2
mmvcce16.m.v  md, vs1, rs2
mmvcce32.m.v  md, vs1, rs2
mmvcce64.m.v  md, vs1, rs2

Intrinsic Example: Matrix multiplication fused with element-wise vector operation

void fused_matmul_relu_float16(c, a, b, m, k, n) {
    msettype(e16);                              // use 16bit input matrix element
    for (i = 0; i < m; i += tile_m) {           // loop at dim m with tiling
        tile_m = msettilem(m-i);
        for (j = 0; j < n; j += tile_n) {       // loop at dim n with tiling
            tile_n = msettilen(n-j);

            out = mwsub_mm(out, out)            // clear acc reg
            for (s = 0; s < k; s += tile_k) {   // loop at dim k with tiling
                tile_k = msettilek(k-s);

                tr1 = mlae16_m(&a[i][s]);       // load left matrix a
                tr2 = mlbe16_m(&b[s][j]);       // load right matrix b
                out = mfwma_mm(tr1, tr2);       // tiled matrix multiply,
                                                // double widen output
            }

            out = mfncvt_f_fw_m(out, m2);       // convert widen result to single

            for (s = 0; s < tile_m; s += rows) {
                // max rows could move into 8 vregs
                rows = min(tile_m - s, 8*vlenb/rlenb);
                vsetvl(tile_n*rows, e16, m8);

                v1 = mmvcr_v_m(out, s);         // move out rows to vreg
                v1 = vfmax_vf(0.f, v1);         // vfmax.vf for relu

                msce16_v(v1, &c[i+s][j], n);    // store output tile slices
            }
        }
    }
}