diff --git a/litespi/__init__.py b/litespi/__init__.py index c3e7a27..85bfe82 100644 --- a/litespi/__init__.py +++ b/litespi/__init__.py @@ -50,6 +50,18 @@ class LiteSPI(Module, AutoCSR, AutoDoc): mmap_endianness : string If endianness is set to ``small`` then byte order of each 32-bit word comming MMAP core will be reversed. + with_csr : bool + The number of dummy bits can be configure when set to True. + + with_mmap_write : bool or string + MMAP writes are supported when set to True or "csr". When set to "csr", they are disabled by default but + can be enabled on demand using a CSR. + + Please note that only False and "csr" should be used with flash chips! True is only meant for RAM. + + When using "csr" with a flash chip, make sure to erase the corresponding pages of the flash beforehand + using the LiteSPI master. It is also recommended to disable mmap writing once it is not required anymore. + Attributes ---------- bus : Interface(), out @@ -59,7 +71,7 @@ class LiteSPI(Module, AutoCSR, AutoDoc): def __init__(self, phy, clock_domain="sys", with_mmap=True, mmap_endianness="big", with_master=True, master_tx_fifo_depth=1, master_rx_fifo_depth=1, - with_csr=True): + with_csr=True, with_mmap_write=False): self.submodules.crossbar = crossbar = LiteSPICrossbar(clock_domain) self.comb += phy.cs.eq(crossbar.cs) @@ -67,7 +79,8 @@ def __init__(self, phy, clock_domain="sys", if with_mmap: self.submodules.mmap = mmap = LiteSPIMMAP(flash=phy.flash, endianness=mmap_endianness, - with_csr=with_csr) + with_csr=with_csr, + with_write=with_mmap_write) port_mmap = crossbar.get_port(mmap.cs) self.bus = mmap.bus self.comb += [ diff --git a/litespi/core/mmap.py b/litespi/core/mmap.py index 074cbc1..bdd96ae 100644 --- a/litespi/core/mmap.py +++ b/litespi/core/mmap.py @@ -2,6 +2,7 @@ # This file is part of LiteSPI # # Copyright (c) 2020 Antmicro +# Copyright (c) 2024 Matthias Breithaupt # SPDX-License-Identifier: BSD-2-Clause from migen import * @@ -40,6 +41,18 @@ class LiteSPIMMAP(Module, AutoCSR): endianness : string If endianness is set to ``little`` then byte order of each 32-bit word coming from flash will be reversed. + with_csr : bool + The number of dummy bits can be configure when set to True. + + with_write : bool or string "csr" + MMAP writes are supported when set to True or "csr". When set to "csr", they are disabled by default but + can be enabled on demand using a CSR. + + Please note that only False and "csr" should be used with flash chips! True is only meant for RAM. + + When using "csr" with a flash chip, make sure to erase the corresponding pages of the flash beforehand + using the LiteSPI master. It is also recommended to disable mmap writing once it is not required anymore. + Attributes ---------- source : Endpoint(spi_core2phy_layout), out @@ -55,13 +68,17 @@ class LiteSPIMMAP(Module, AutoCSR): CS signal for the flash chip, should be connected to cs signal of the PHY. dummy_bits : CSRStorage - Register which hold a number of dummy bits to send during transmission. + Register which holds the number of dummy bits to send during transmission. + + write_config : CSRStorage + Optional register holding configuration bits for the write mode. """ - def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): + def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True, with_write=False): self.source = source = stream.Endpoint(spi_core2phy_layout) self.sink = sink = stream.Endpoint(spi_phy2core_layout) self.bus = bus = wishbone.Interface() self.cs = cs = Signal() + self.offset = offset = Signal(len(bus.adr)) # Burst Control. burst_cs = Signal() @@ -69,6 +86,10 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): burst_timeout = WaitTimer(MMAP_DEFAULT_TIMEOUT) self.submodules += burst_timeout + write = Signal() + write_enabled = Signal() + write_mask = Signal(len(bus.sel)) + cmd_bits = 8 data_bits = 32 @@ -92,6 +113,20 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): dummy = Signal(data_bits, reset=0xdead) + if with_write and with_write == "csr": + self.write_config = write_config = CSRStorage(fields=[ + CSRField("write_enable", size=1, reset=0, description="MMAP write enable"), + ]) + if clock_domain != "sys": + self.specials += MultiReg(write_config.fields.write_enable, write_enabled, clock_domain) + else: + self.comb += write_enabled.eq(write_config.fields.write_enable) + else: + self.comb += write_enabled.eq(Constant(with_write == True)) + + self.byte_count = byte_count = Signal(2, reset_less=True) + self.data_write = Signal(32) + # FSM. self.submodules.fsm = fsm = FSM(reset_state="IDLE") fsm.act("IDLE", @@ -99,24 +134,62 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): burst_timeout.wait.eq(1), NextValue(burst_cs, burst_cs & ~burst_timeout.done), cs.eq(burst_cs), - # On Bus Read access... - If(bus.cyc & bus.stb & ~bus.we, - # If CS is still active and Bus address matches previous Burst address: - # Just continue the current Burst. - If(burst_cs & (bus.adr == burst_adr), - NextState("BURST-REQ") - # Otherwise initialize a new Burst. - ).Else( - cs.eq(0), - NextState("BURST-CMD") + If(bus.cyc & bus.stb, + NextValue(byte_count, 0), + # On Bus Read access... + If(~bus.we, + # If CS is still active, Bus address matches previous Burst address and previous access was reading: + # Just continue the current Burst. + If(burst_cs & (bus.adr == burst_adr) & (~write_enabled | ~write), + NextState("BURST-REQ") + # Otherwise initialize a new Burst. + ).Else( + cs.eq(0), + NextState("BURST-CMD") + ), + NextValue(write, 0) + # On Bus Write access (if enabled)... + ).Elif(write_enabled, + # If CS is still active, Bus address matches previous Burst address and previous access was writing: + # Just continue the current Burst. + NextValue(write_mask, bus.sel), + NextValue(self.data_write, bus.dat_w), + If(burst_cs & (bus.adr == burst_adr) & bus.sel[0] & write, + NextState("WRITE") + # Otherwise initialize a new Burst. + ).Else( + cs.eq(0), + NextState("PRE-BURST-CMD-WRITE"), + ), + NextValue(write, 1) ) ) ) + fsm.act("PRE-BURST-CMD-WRITE", + cs.eq(0), + If(write_mask[0], + NextState("BURST-CMD"), + NextValue(write, 1) + ).Elif(byte_count == 3, + bus.ack.eq(1), + NextValue(burst_adr, burst_adr + 1), + NextState("IDLE"), + NextValue(write, 0) + ).Else( + NextValue(byte_count, byte_count + 1), + NextValue(write_mask, Cat(write_mask[1:len(bus.sel)], Signal(1))), + ) + ) + fsm.act("BURST-CMD", cs.eq(1), source.valid.eq(1), - source.data.eq(flash.read_opcode.code), # send command. + If(write_enabled & write, + source.data.eq(flash.program_opcode.code), # send command. + ).Else( + source.data.eq(flash.read_opcode.code), # send command. + ), source.len.eq(cmd_bits), source.width.eq(flash.cmd_width), source.mask.eq(cmd_oe_mask[flash.cmd_width]), @@ -139,7 +212,7 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): source.valid.eq(1), source.width.eq(flash.addr_width), source.mask.eq(addr_oe_mask[flash.addr_width]), - source.data.eq(Cat(Signal(2), bus.adr)), # send address. + source.data.eq(Cat(byte_count, bus.adr - offset)), # send address. source.len.eq(flash.addr_bits), NextValue(burst_cs, 1), NextValue(burst_adr, bus.adr), @@ -152,7 +225,9 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): cs.eq(1), sink.ready.eq(1), If(sink.valid, - If(spi_dummy_bits == 0, + If(write_enabled & write, + NextState("WRITE"), + ).Elif(spi_dummy_bits == 0, NextState("BURST-REQ"), ).Else( NextState("DUMMY"), @@ -164,7 +239,7 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): cs.eq(1), source.valid.eq(1), source.width.eq(flash.addr_width), - source.mask.eq(addr_oe_mask[flash.addr_width]), + source.mask.eq(0), source.data.eq(dummy), source.len.eq(spi_dummy_bits), If(source.ready, @@ -202,3 +277,39 @@ def __init__(self, flash, clock_domain="sys", endianness="big", with_csr=True): NextState("IDLE"), ) ) + + fsm.act("WRITE", + cs.eq(1), + source.valid.eq(1), + source.width.eq(flash.addr_width), + source.mask.eq(addr_oe_mask[flash.bus_width]), + source.data.eq(self.data_write), + source.len.eq(8), + If(source.ready, + NextState("WRITE-RET"), + ) + ) + + fsm.act("WRITE-RET", + cs.eq(1), + sink.ready.eq(1), + If(sink.valid, + If(byte_count != 3, + NextValue(write_mask, Cat(write_mask[1:len(bus.sel)], Signal(1))), + NextValue(byte_count, byte_count + 1), + NextValue(self.data_write, self.data_write >> 8), + If(write_mask[1], + NextState("WRITE"), + ).Else( + cs.eq(0), + NextValue(write, 0), + NextState("PRE-BURST-CMD-WRITE"), + ), + + ).Else( + bus.ack.eq(1), + NextValue(burst_adr, burst_adr + 1), + NextState("IDLE"), + ), + ) + ) diff --git a/test/test_spi_mmap.py b/test/test_spi_mmap.py index bb56c46..b575fc8 100644 --- a/test/test_spi_mmap.py +++ b/test/test_spi_mmap.py @@ -34,6 +34,80 @@ class DummyChip(SpiNorFlashModule): def test_spi_mmap_core_syntax(self): spi_mmap = LiteSPIMMAP(flash=self.DummyChip(Codes.READ_1_1_1, [])) + spi_write_mmap = LiteSPIMMAP(flash=self.DummyChip(Codes.READ_1_1_1, [], program_cmd=Codes.PP_1_1_1), with_write=True) + + def test_spi_mmap_write_test(self): + opcode = Codes.PP_1_1_1 + dut = LiteSPIMMAP(flash=self.DummyChip(Codes.READ_1_1_1, [], program_cmd=opcode), with_write=True) + + def wb_gen(dut, addr, data, offset): + dut.done = 0 + + yield dut.offset.eq(offset) + yield dut.bus.adr.eq(addr + offset) + print((yield dut.bus.adr)) + yield dut.bus.we.eq(1) + yield dut.bus.cyc.eq(1) + yield dut.bus.stb.eq(1) + yield dut.bus.dat_w.eq(data) + + while (yield dut.bus.ack) == 0: + yield + + dut.done = 1 + + def phy_gen(dut, addr, data): + dut.addr_ok = 0 + dut.opcode_ok = 0 + dut.cmd_ok = 0 + dut.data_ok = 0 + yield dut.sink.valid.eq(0) + yield dut.source.ready.eq(1) + + while (yield dut.source.valid) == 0: + yield + + + # WRITE CMD + if (yield dut.source.data) == opcode.code: # cmd ok + dut.opcode_ok = 1 + + yield + yield dut.sink.valid.eq(1) + while (yield dut.source.valid) == 0: + yield + yield dut.sink.valid.eq(0) + + # WRITE ADDR + print((yield dut.source.data)) + if (yield dut.source.data) == (addr<<2): # address cmd + dut.addr_ok = 1 + + yield + yield dut.sink.valid.eq(1) + while (yield dut.source.valid) == 0: + yield + yield dut.sink.valid.eq(0) + + # WRITE DATA + if (yield dut.source.data) == (data): # data ok + dut.data_ok = 1 + + yield + yield dut.sink.valid.eq(1) + while (yield dut.source.valid) == 0: + yield + yield dut.sink.valid.eq(0) + yield + addr = 0xcafe + data = 0xdeadbeef + offset = 0x10000 + + run_simulation(dut, [wb_gen(dut, addr, data, offset), phy_gen(dut, addr, data)]) + self.assertEqual(dut.done, 1) + self.assertEqual(dut.addr_ok, 1) + self.assertEqual(dut.opcode_ok, 1) + self.assertEqual(dut.data_ok, 1) def test_spi_mmap_read_test(self): opcode = Codes.READ_1_1_1