diff --git a/stack/port_tcp.go b/stack/port_tcp.go index e479fee..6820714 100644 --- a/stack/port_tcp.go +++ b/stack/port_tcp.go @@ -2,7 +2,6 @@ package stack import ( "errors" - "io" "strconv" "time" @@ -10,6 +9,12 @@ import ( "github.com/soypat/seqs/eth" ) +// tcphandler represents a user provided function for handling incoming TCP packets on a port. +// Incoming data is sent inside the `pkt` TCPPacket argument when pkt.HasPacket returns true. +// Outgoing data is stored into the `response` byte slice. The function must return the number of +// bytes written to `response` and an error. +// +// See [PortStack] for information on how to use this function and other port handlers. type tcphandler func(response []byte, pkt *TCPPacket) (int, error) type tcpPort struct { @@ -57,7 +62,7 @@ func (u *tcpPort) HandleEth(dst []byte) (n int, err error) { packet := &u.packets[0] n, err = u.handler(dst, &u.packets[0]) - if err == io.ErrNoProgress { + if err == ErrFlagPending { packet.Rx = forcedTime // Mark socket as needing handling but packet having no data. } else { packet.Rx = time.Time{} // Invalidate packet normally. diff --git a/stack/port_udp.go b/stack/port_udp.go index 11af8c5..a483d8a 100644 --- a/stack/port_udp.go +++ b/stack/port_udp.go @@ -1,7 +1,6 @@ package stack import ( - "io" "strconv" "time" @@ -48,7 +47,7 @@ func (u *udpPort) HandleEth(dst []byte) (int, error) { packet := &u.packets[0] n, err := u.handler(dst, &u.packets[0]) - if err == io.ErrNoProgress { + if err == ErrFlagPending { packet.Rx = forcedTime // Mark socket as needing handling but packet having no data. } else { packet.Rx = time.Time{} // Invalidate packet normally. diff --git a/stack/portstack.go b/stack/portstack.go index 05ee7e7..ba23f10 100644 --- a/stack/portstack.go +++ b/stack/portstack.go @@ -39,14 +39,41 @@ func NewPortStack(cfg PortStackConfig) *PortStack { return &s } +var ErrFlagPending = io.ErrNoProgress + // PortStack implements partial TCP/UDP packet muxing to respective sockets with [PortStack.RcvEth]. // This implementation limits itself basic header validation and port matching. // Users of PortStack are expected to implement connection state, packet buffering and retransmission logic. // - In the case of TCP this means implementing the TCP state machine. // - In the case of UDP PortStack should be enough to build most applications. +// +// # Notes on PortStack handlers +// +// - While PortStack.HandleEth has yet to find a outgoing packet it will look for +// a port that has a pending packet or has been flagged as pending and call its handler. +// +// - A call to a handler may or may not have an incoming packet ready to process. +// When pkt.HasPacket() returns true then pkt contains an incoming packet to the port. +// +// - When pkt.HasPacket() returns false the contents are undefined. +// +// - Users can safely use pkt even if pkt.HasPacket() returns false. +// +// - If the handler returns an error that is not ErrFlagPending then the port +// is immediately closed and written data is discarded. +// +// - ErrFlagPending: When returned by the handler then the port is flagged as +// pending and the written data is handled normally if there is any. If no data is written +// the call to HandleEth proceeds looking for another port to handle. +// +// - ErrFlagPending: When returned by the handler then for UDP/TCP implementations the +// incoming packet argument `pkt` is flagged as not present in future calls to the handler in pkt.HasPacket calls. +// The handler however can be aware of this fact and still use the pkt argument since the header+payload contents +// are not modified by the stack. type PortStack struct { lastRx time.Time lastRxSuccess time.Time + lastTx time.Time mac [6]byte // Set IP to non-nil to ignore packets not meant for us. IP netip.Addr @@ -106,7 +133,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { return errPacketSmol } ps.debug("Stack.RecvEth:start", slog.Int("plen", len(payload))) - ps.lastRx = time.Now() + ps.lastRx = ps.now() // Ethernet parsing block ehdr = eth.DecodeEthernetHeader(payload) @@ -124,7 +151,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { } switch ahdr.Operation { case 1: // ARP request. - if ps.hasPendingARP() || ahdr.ProtoTarget != ps.IP.As4() { + if ps.pendingReplyToARP() || ahdr.ProtoTarget != ps.IP.As4() { return nil // ARP reply pending or not for us. } // We need to respond to this ARP request. @@ -259,16 +286,21 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { // HandleEth searches for a socket with a pending packet and writes the response // into the dst argument. The length written to dst is returned. -// [io.ErrNoProgress] can be returned by value by a handler to indicate the packet was +// [ErrFlagPending] can be returned by value by a handler to indicate the packet was // not processed and that a future call to HandleEth is required to complete. // // If a handler returns any other error the port is closed. -func (ps *PortStack) HandleEth(dst []byte) (int, error) { +func (ps *PortStack) HandleEth(dst []byte) (n int, err error) { + defer func() { + if n > 0 && err == nil { + ps.lastTx = ps.now() + } + }() switch { case len(dst) < _MTU: return 0, io.ErrShortBuffer - case ps.ARPresult.Operation == 1: + case ps.pendingRequestARP(): // We have a pending request from user to perform ARP. ehdr := eth.EthernetHeader{ Destination: broadcastMAC, @@ -280,7 +312,7 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) { ps.ARPresult.Operation = arpOpWait // Clear pending ARP to not loop. return eth.SizeEthernetHeader + eth.SizeARPv4Header, nil - case ps.hasPendingARP(): + case ps.pendingReplyToARP(): // We need to respond to an ARP request that queries our address. ehdr := eth.EthernetHeader{ Destination: ps.pendingARPresponse.HardwareSender, @@ -304,28 +336,29 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) { HandleEth(dst []byte) (int, error) } - handleSocket := func(dst []byte, pending *uint32, sock Socket) (int, error) { + handleSocket := func(dst []byte, sock Socket) (int, bool, error) { if !sock.IsPendingHandling() { - return 0, nil // Nothing to handle, just skip. + return 0, false, nil // Nothing to handle, just skip. } // Socket has an unhandled packet. n, err := sock.HandleEth(dst) - if err == io.ErrNoProgress { - // Special case: Socket may have written data but needs future handling, flagged with the io.ErrNoProgress error. - return n, nil + if err == ErrFlagPending { + // Special case: Socket may have written data but needs future handling, flagged with the ErrFlagPending error. + return n, true, nil } - // If we get here the socket has been handled, so we decrement the pending counter. - *pending-- if err != nil { sock.Close() - return 0, err + n = 0 } - return n, nil + return n, false, err } if ps.pendingUDPv4 > 0 { for i := range ps.UDPv4 { - n, err := handleSocket(dst, &ps.pendingUDPv4, &ps.UDPv4[i]) + n, pending, err := handleSocket(dst, &ps.UDPv4[i]) + if !pending { + ps.pendingUDPv4-- + } if err != nil { return 0, err } else if n > 0 { @@ -337,7 +370,10 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) { if ps.pendingTCPv4 > 0 { for i := range ps.TCPv4 { - n, err := handleSocket(dst, &ps.pendingTCPv4, &ps.TCPv4[i]) + n, pending, err := handleSocket(dst, &ps.TCPv4[i]) + if !pending { + ps.pendingTCPv4-- + } if err != nil { return 0, err } else if n > 0 { @@ -371,10 +407,19 @@ func (ps *PortStack) ARPv4Result() (eth.ARPv4Header, bool) { return ps.ARPresult, ps.ARPresult.Operation == 2 } -func (ps *PortStack) hasPendingARP() bool { +func (ps *PortStack) pendingReplyToARP() bool { return ps.pendingARPresponse.Operation == 2 // 2 means reply. } +func (ps *PortStack) pendingRequestARP() bool { + return ps.ARPresult.Operation == 1 // User asked for a ARP request. +} + +// IsPendingHandling checks if a call to HandleEth could possibly result in a packet being generated by the PortStack. +func (ps *PortStack) IsPendingHandling() bool { + return ps.pendingUDPv4 > 0 || ps.pendingTCPv4 > 0 || ps.pendingRequestARP() || ps.pendingReplyToARP() +} + // OpenUDP opens a UDP port and sets the handler. If the port is already open // or if there is no socket available it returns an error. func (ps *PortStack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, error)) error { @@ -510,6 +555,10 @@ func (ps *PortStack) getTCP(port uint16) *tcpPort { return nil } +func (ps *PortStack) now() time.Time { + return time.Now() +} + func (ps *PortStack) info(msg string, attrs ...slog.Attr) { ps.logAttrsPrint(slog.LevelInfo, msg, attrs...) } diff --git a/stack/socket_tcp.go b/stack/socket_tcp.go index 601ccbf..05843a0 100644 --- a/stack/socket_tcp.go +++ b/stack/socket_tcp.go @@ -139,7 +139,7 @@ func (t *TCPSocket) Close() error { func (t *TCPSocket) handleMain(response []byte, pkt *TCPPacket) (n int, err error) { defer func() { - if err != nil && t.abortErr == nil && err != io.ErrNoProgress { + if err != nil && t.abortErr == nil && err != ErrFlagPending { err = nil // Only close socket if socket is aborted. } else if err != nil { t.stack.error("tcp socket", slog.Int("port", int(t.localPort)), slog.String("err", err.Error())) @@ -225,7 +225,7 @@ func (t *TCPSocket) handleSend(response []byte, pkt *TCPPacket) (n int, err erro pkt.PutHeaders(response) if t.scb.HasPending() { - err = io.ErrNoProgress // Flag to PortStack that we have pending data to send. + err = ErrFlagPending // Flag to PortStack that we have pending data to send. } return sizeTCPNoOptions + n, err } diff --git a/stack/stack_test.go b/stack/stack_test.go index a0528e1..1fb4edc 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -230,6 +230,7 @@ func exchangeStacks(t *testing.T, maxExchanges int, stacks ...*stack.PortStack) if sentInTx == 0 { break // No more data being sent. } + for isend := 0; isend < len(stacks); isend++ { // We deliver each in-flight packet to all stacks, except the one that sent it. payload := getPayload(isend)