Because hardware resources are limited, a common way to handle large matrix multiplication is "tiling": each loop iteration processes a subset of the elements, and iteration continues until all elements are processed. The Matrix extension provides direct, portable support for this approach.
Tiled matrix multiplication requires three nested loops that iterate over the number of rows of the left matrix (m), the number of columns of the left matrix (k, also the number of rows of the right matrix), and the number of columns of the right matrix (n). These dimensions are given by the application.
The shapes of the matrix tiles to be processed, m (application tile length m, or ATM), k (ATK), and n (ATN), are used as candidates for mtilem/mtilek/mtilen. Based on the microarchitecture implementation, hardware returns a new mtilem/mtilek/mtilen value (usually smaller) via a general-purpose register and also stores it in the mtilem/mtilek/mtilen CSR. This value is the shape of the tile that hardware handles per iteration.
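As a rough illustration of the three-level tiling loop, the Python sketch below models the flow in software. The names `msettile` and `tiled_matmul` are hypothetical, and the model assumes that hardware grants `min(AT, TMAX)` on each call, which is only one behavior an implementation might choose.

```python
def msettile(at, tmax):
    """Hypothetical model of msettile{m|k|n}: tile size granted for a request `at`.

    Assumption: the implementation grants min(at, tmax).
    """
    return min(at, tmax)


def tiled_matmul(a, b, tmmax, tkmax, tnmax):
    """Multiply a (m x k) by b (k x n) tile by tile, using lists of lists."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]
    i = 0
    while i < m:                                  # loop over rows of A (m)
        tile_m = msettile(m - i, tmmax)           # ATM = rows still to process
        j = 0
        while j < n:                              # loop over columns of B (n)
            tile_n = msettile(n - j, tnmax)       # ATN = columns still to process
            p = 0
            while p < k:                          # loop over the shared dimension (k)
                tile_k = msettile(k - p, tkmax)   # ATK = depth still to process
                # One tile_m x tile_k x tile_n block, standing in for a
                # single matrix multiply-accumulate instruction.
                for ii in range(i, i + tile_m):
                    for jj in range(j, j + tile_n):
                        for pp in range(p, p + tile_k):
                            c[ii][jj] += a[ii][pp] * b[pp][jj]
                p += tile_k
            j += tile_n
        i += tile_m
    return c
```

Each loop advances by the granted tile size rather than by a fixed step, so the final, partial tiles in every dimension are handled without special-case code.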
For a simple matrix multiplication example, see the Intrinsic Example section, which describes how the code keeps track of the matrices processed by the hardware in each iteration.
A set of instructions is provided to allow rapid configuration of the values in mtile* and mtype to match application needs.
The msettype[i|hi] instructions set the mtype CSR based on their arguments, and write the new value of mtype into rd.
msettypei rd, imm # rd = new mtype, imm = new mtype[9:0] setting.
msettypehi rd, imm # rd = new mtype, imm = new mtype[19:10] setting.
msettype rd, rs1 # rd = new mtype, rs1 = new mtype value.
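The split between the low and high immediates can be pictured with a small Python model. This is a sketch under an assumption the text does not confirm: that msettypei and msettypehi each replace only their own 10-bit slice of mtype and leave the other bits unchanged. The function names mirror the mnemonics but are otherwise hypothetical.

```python
LOW_MASK = 0x3FF            # mtype[9:0]
HIGH_MASK = 0x3FF << 10     # mtype[19:10]


def msettypei(mtype, imm):
    """Model of msettypei: imm becomes mtype[9:0].

    Assumption: bits outside [9:0] are preserved.
    """
    return (mtype & ~LOW_MASK) | (imm & 0x3FF)


def msettypehi(mtype, imm):
    """Model of msettypehi: imm becomes mtype[19:10].

    Assumption: bits outside [19:10] are preserved.
    """
    return (mtype & ~HIGH_MASK) | ((imm & 0x3FF) << 10)


def msettype(mtype, rs1_value):
    """Model of msettype: the whole new mtype value comes from rs1."""
    return rs1_value
```

Under that assumption, a full 20-bit setting can be built with one msettypei followed by one msettypehi, in either order.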
The mset* instructions set the specified field of mtype without affecting other fields.
# Set msew field.
msetsew rd, imm # rd = new mtype, set msew to imm[2:0].
msetsew rd, e8 # rd = new mtype, imm = 0.
msetsew rd, e16 # rd = new mtype, imm = 1.
msetsew rd, e32 # rd = new mtype, imm = 2.
msetsew rd, e64 # rd = new mtype, imm = 3.
# Set mba field.
msetba rd, imm # rd = new mtype, set mba to imm[0].
msetba rd, bu # rd = new mtype, imm = 0.
msetba rd, ba # rd = new mtype, imm = 1.
# Set integer type fields.
msetint rd, int4 # rd = new mtype, set mint4 = 1 to enable INT4 type.
msetint rd, int8 # rd = new mtype, set mint8 = 1 to enable INT8 type.
msetint rd, int16 # rd = new mtype, set mint16 = 1 to enable INT16 type.
msetint rd, int32 # rd = new mtype, set mint32 = 1 to enable INT32 type.
msetint rd, int64 # rd = new mtype, set mint64 = 1 to enable INT64 type.
# Set floating-point type fields.
msetfp rd, e4m3 # rd = new mtype, set mfp8 = 01 to enable FP8 E4M3 type.
msetfp rd, e5m2 # rd = new mtype, set mfp8 = 10 to enable FP8 E5M2 type.
msetfp rd, e3m4 # rd = new mtype, set mfp8 = 11 to enable FP8 E3M4 type.
msetfp rd, fp16 # rd = new mtype, set mfp16 = 01 to enable FP16 E5M10 type.
msetfp rd, bf16 # rd = new mtype, set mfp16 = 10 to enable BF16 E8M7 type.
msetfp rd, fp32 # rd = new mtype, set mfp32 = 01 to enable FP32 E8M23 type.
msetfp rd, tf32 # rd = new mtype, set mfp32 = 10 to enable TF32 E8M10 type.
msetfp rd, fp64 # rd = new mtype, set mfp64 = 1 to enable FP64 type.
The munset* instructions unset the specified field of mtype without affecting other fields.
munsetint rd, int4 # rd = new mtype, set mint4 = 0 to disable INT4 type.
munsetint rd, int8 # rd = new mtype, set mint8 = 0 to disable INT8 type.
munsetint rd, int16 # rd = new mtype, set mint16 = 0 to disable INT16 type.
munsetint rd, int32 # rd = new mtype, set mint32 = 0 to disable INT32 type.
munsetint rd, int64 # rd = new mtype, set mint64 = 0 to disable INT64 type.
munsetfp rd, fp8 # rd = new mtype, set mfp8 = 00 to disable FP8 type.
munsetfp rd, fp16 # rd = new mtype, set mfp16 = 00 to disable FP16 type.
munsetfp rd, fp32 # rd = new mtype, set mfp32 = 00 to disable FP32 type.
munsetfp rd, fp64 # rd = new mtype, set mfp64 = 0 to disable FP64 type.
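The "without affecting other fields" behavior of the floating-point set and unset instructions can be sketched in Python by treating mtype as a dictionary of named fields. The dictionary representation and the function names are illustrative assumptions; the field values are the ones given in the listings above.

```python
# Mnemonic -> (mtype field, value written), taken from the msetfp listing.
SET_FP = {
    "e4m3": ("mfp8", 0b01),
    "e5m2": ("mfp8", 0b10),
    "e3m4": ("mfp8", 0b11),
    "fp16": ("mfp16", 0b01),
    "bf16": ("mfp16", 0b10),
    "fp32": ("mfp32", 0b01),
    "tf32": ("mfp32", 0b10),
    "fp64": ("mfp64", 0b1),
}


def msetfp(mtype, name):
    """Return a copy of mtype with one floating-point field set."""
    field, value = SET_FP[name]
    new_mtype = dict(mtype)     # every other field is carried over unchanged
    new_mtype[field] = value
    return new_mtype


def munsetfp(mtype, name):
    """Return a copy of mtype with one floating-point field cleared.

    `name` is one of "fp8", "fp16", "fp32", "fp64".
    """
    new_mtype = dict(mtype)
    new_mtype["m" + name] = 0
    return new_mtype
```

For example, `msetfp(mtype, "bf16")` changes only the `mfp16` entry, and `munsetfp(..., "fp16")` restores it to zero.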
The field to be set or unset is specified by inst[18:15] and the value is specified by inst[24:20].
inst[18:15] | field
0000 | msew
0001 | mint4
0010 | mint8
0011 | mint16
0100 | mint32
0101 | mint64
0110 | mfp8
0111 | mfp16
1000 | mfp32
1001 | mfp64
1010 | mba
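The field selection above can be expressed as a short Python decoder. This is a minimal sketch: `decode_mset` is a hypothetical helper, it looks only at inst[18:15] and inst[24:20], and it returns `None` for selector values that the table does not list (how hardware treats those is not stated here).

```python
# inst[18:15] selector -> mtype field name, copied from the table.
FIELD_BY_SELECTOR = {
    0b0000: "msew",
    0b0001: "mint4",
    0b0010: "mint8",
    0b0011: "mint16",
    0b0100: "mint32",
    0b0101: "mint64",
    0b0110: "mfp8",
    0b0111: "mfp16",
    0b1000: "mfp32",
    0b1001: "mfp64",
    0b1010: "mba",
}


def decode_mset(inst):
    """Extract (field name, value) from a 32-bit mset*/munset* encoding."""
    selector = (inst >> 15) & 0xF     # inst[18:15]
    value = (inst >> 20) & 0x1F       # inst[24:20]
    return FIELD_BY_SELECTOR.get(selector), value
```

The remaining instruction bits (opcode, rd, and so on) are deliberately ignored, since their layout is not described in this section.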
The msettile{m|k|n}[i] instructions set the mtilem/mtilek/mtilen CSRs based on their arguments, and write the new value into rd.
msettilemi rd, imm # rd = new mtilem, imm = ATM
msettilem rd, rs1 # rd = new mtilem, rs1 = ATM
msettileki rd, imm # rd = new mtilek, imm = ATK
msettilek rd, rs1 # rd = new mtilek, rs1 = ATK
msettileni rd, imm # rd = new mtilen, imm = ATN
msettilen rd, rs1 # rd = new mtilen, rs1 = ATN
The new mtype value is encoded in the immediate fields of msettypei/msettypehi, and in the rs1 register for msettype. Each field can be set or unset independently with the msetsew, msetba, msetfp, msetint, munsetfp and munsetint instructions.
Three values, TMMAX, TKMAX and TNMAX, represent the maximum shapes of the matrix tiles that can be stored in matrix registers and operated on by a single matrix instruction under the current SEW setting.
The values of TMMAX, TKMAX and TNMAX are derived from MLEN, RLEN and SEW:
- TMMAX = MLEN / RLEN
- TKMAX = min(MLEN / RLEN, RLEN / SEW)
- TNMAX = RLEN / SEW
For example, with MLEN=256 and RLEN=64, the TMMAX, TKMAX and TNMAX values are shown below.
SEW=8, TMMAX=4, TKMAX=4, TNMAX=8 # 4x4x8 8-bit matmul
SEW=16, TMMAX=4, TKMAX=4, TNMAX=4 # 4x4x4 16-bit matmul
SEW=32, TMMAX=4, TKMAX=2, TNMAX=2 # 4x2x2 32-bit matmul
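The three formulas can be checked with a few lines of Python. The helper name `tile_max` is hypothetical, and integer division is used on the assumption that MLEN, RLEN and SEW divide evenly, as they do in the example.

```python
def tile_max(mlen, rlen, sew):
    """Return (TMMAX, TKMAX, TNMAX) for the given MLEN, RLEN and SEW in bits."""
    tmmax = mlen // rlen            # TMMAX = MLEN / RLEN
    tnmax = rlen // sew             # TNMAX = RLEN / SEW
    tkmax = min(tmmax, tnmax)       # TKMAX = min(MLEN / RLEN, RLEN / SEW)
    return tmmax, tkmax, tnmax


for sew in (8, 16, 32):
    print(sew, tile_max(256, 64, sew))
```

With MLEN=256 and RLEN=64 this reproduces the three rows listed above.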
The new tile shape settings are based on the ATM/ATK/ATN values, which for msettile{m|k|n} are encoded in the rs1 and rd fields.
rd | rs1 | | Effect on
- | !x0 | Value in | Normal tiling
!x0 | x0 | ~0 | Set
x0 | x0 | Value in | Keep existing
For the msettile{m|k|n}i instructions, the ATM/ATK/ATN value is encoded as a 10-bit unsigned immediate in the rs1.
The msettile{m|k|n}[i] instructions first set TMMAX/TKMAX/TNMAX according to the mtype CSR, then set mtilem/mtilek/mtilen obeying the following constraints (using mtilem, ATM and TMMAX as an example; the same applies to mtilek, ATK and TKMAX, and to mtilen, ATN and TNMAX):
- mtilem = ATM if ATM <= TMMAX
- ceil(ATM / 2) <= mtilem <= TMMAX if ATM < (2 * TMMAX)
- mtilem = TMMAX if ATM >= (2 * TMMAX)
- Deterministic on any given implementation for same input ATM and TMMAX values
- These specific properties follow from the prior rules:
  - mtilem = 0 if ATM = 0
  - mtilem > 0 if ATM > 0
  - mtilem <= TMMAX
  - mtilem <= ATM
  - a value read from mtilem, when used as the ATM argument to msettile{m|k|n}[i], results in the same value in mtilem, provided the resultant TMMAX equals the value of TMMAX at the time that mtilem was read.
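The constraints leave implementations some freedom when ATM lies strictly between TMMAX and 2 * TMMAX. The Python sketch below encodes the three rules as a checker and shows two hypothetical grant policies that both satisfy them. The names `is_legal_mtile`, `grant_min` and `grant_balanced` are illustrative and do not come from the specification.

```python
def is_legal_mtile(mtile, at, tmax):
    """Check a granted tile size against the three msettile constraints."""
    if at <= tmax:
        return mtile == at                       # rule 1: request fits, grant it all
    if at < 2 * tmax:
        return (at + 1) // 2 <= mtile <= tmax    # rule 2: ceil(AT / 2) <= mtile <= TMAX
    return mtile == tmax                         # rule 3: grant the maximum


def grant_min(at, tmax):
    """Policy A (assumed): always grant as much as possible."""
    return min(at, tmax)


def grant_balanced(at, tmax):
    """Policy B (assumed): split the last two tiles evenly."""
    if tmax < at < 2 * tmax:
        return (at + 1) // 2
    return min(at, tmax)
```

For ATM=5 and TMMAX=4, policy A grants 4 (leaving a final tile of 1), while policy B grants 3 (leaving a final tile of 2). Both are permitted by the rules, and each policy is deterministic for a given ATM and TMMAX.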
Continuing with MLEN=256 and RLEN=64 as an example: when SEW=16, TMMAX=4, TKMAX=4, and TNMAX=4.
If A is a 7 x 8 matrix and B is an 8 x 14 matrix, we get the mtilem/mtilek/mtilen values shown below in the last loop of tiling.