Skip to content

Commit

Permalink
PortStack docs revamp; add ErrFlagPending
Browse files Browse the repository at this point in the history
  • Loading branch information
soypat committed Nov 20, 2023
1 parent ae863b1 commit e6396dc
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 24 deletions.
9 changes: 7 additions & 2 deletions stack/port_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@ package stack

import (
"errors"
"io"
"strconv"
"time"

"github.com/soypat/seqs"
"github.com/soypat/seqs/eth"
)

// tcphandler represents a user provided function for handling incoming TCP packets on a port.
// Incoming data is sent inside the `pkt` TCPPacket argument when pkt.HasPacket returns true.
// Outgoing data is stored into the `response` byte slice. The function must return the number of
// bytes written to `response` and an error.
//
// See [PortStack] for information on how to use this function and other port handlers.
type tcphandler func(response []byte, pkt *TCPPacket) (int, error)

type tcpPort struct {
Expand Down Expand Up @@ -57,7 +62,7 @@ func (u *tcpPort) HandleEth(dst []byte) (n int, err error) {
packet := &u.packets[0]

n, err = u.handler(dst, &u.packets[0])
if err == io.ErrNoProgress {
if err == ErrFlagPending {
packet.Rx = forcedTime // Mark socket as needing handling but packet having no data.
} else {
packet.Rx = time.Time{} // Invalidate packet normally.
Expand Down
3 changes: 1 addition & 2 deletions stack/port_udp.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package stack

import (
"io"
"strconv"
"time"

Expand Down Expand Up @@ -48,7 +47,7 @@ func (u *udpPort) HandleEth(dst []byte) (int, error) {
packet := &u.packets[0]

n, err := u.handler(dst, &u.packets[0])
if err == io.ErrNoProgress {
if err == ErrFlagPending {
packet.Rx = forcedTime // Mark socket as needing handling but packet having no data.
} else {
packet.Rx = time.Time{} // Invalidate packet normally.
Expand Down
85 changes: 67 additions & 18 deletions stack/portstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,41 @@ func NewPortStack(cfg PortStackConfig) *PortStack {
return &s
}

var ErrFlagPending = io.ErrNoProgress

// PortStack implements partial TCP/UDP packet muxing to respective sockets with [PortStack.RcvEth].
// This implementation limits itself basic header validation and port matching.
// Users of PortStack are expected to implement connection state, packet buffering and retransmission logic.
// - In the case of TCP this means implementing the TCP state machine.
// - In the case of UDP PortStack should be enough to build most applications.
//
// # Notes on PortStack handlers
//
// - While PortStack.HandleEth has yet to find a outgoing packet it will look for
// a port that has a pending packet or has been flagged as pending and call its handler.
//
// - A call to a handler may or may not have an incoming packet ready to process.
// When pkt.HasPacket() returns true then pkt contains an incoming packet to the port.
//
// - When pkt.HasPacket() returns false the contents are undefined.
//
// - Users can safely use pkt even if pkt.HasPacket() returns false.
//
// - If the handler returns an error that is not ErrFlagPending then the port
// is immediately closed and written data is discarded.
//
// - ErrFlagPending: When returned by the handler then the port is flagged as
// pending and the written data is handled normally if there is any. If no data is written
// the call to HandleEth proceeds looking for another port to handle.
//
// - ErrFlagPending: When returned by the handler then for UDP/TCP implementations the
// incoming packet argument `pkt` is flagged as not present in future calls to the handler in pkt.HasPacket calls.
// The handler however can be aware of this fact and still use the pkt argument since the header+payload contents
// are not modified by the stack.
type PortStack struct {
lastRx time.Time
lastRxSuccess time.Time
lastTx time.Time
mac [6]byte
// Set IP to non-nil to ignore packets not meant for us.
IP netip.Addr
Expand Down Expand Up @@ -106,7 +133,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
return errPacketSmol
}
ps.debug("Stack.RecvEth:start", slog.Int("plen", len(payload)))
ps.lastRx = time.Now()
ps.lastRx = ps.now()

// Ethernet parsing block
ehdr = eth.DecodeEthernetHeader(payload)
Expand All @@ -124,7 +151,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
}
switch ahdr.Operation {
case 1: // ARP request.
if ps.hasPendingARP() || ahdr.ProtoTarget != ps.IP.As4() {
if ps.pendingReplyToARP() || ahdr.ProtoTarget != ps.IP.As4() {
return nil // ARP reply pending or not for us.
}
// We need to respond to this ARP request.
Expand Down Expand Up @@ -259,16 +286,21 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {

// HandleEth searches for a socket with a pending packet and writes the response
// into the dst argument. The length written to dst is returned.
// [io.ErrNoProgress] can be returned by value by a handler to indicate the packet was
// [ErrFlagPending] can be returned by value by a handler to indicate the packet was
// not processed and that a future call to HandleEth is required to complete.
//
// If a handler returns any other error the port is closed.
func (ps *PortStack) HandleEth(dst []byte) (int, error) {
func (ps *PortStack) HandleEth(dst []byte) (n int, err error) {
defer func() {
if n > 0 && err == nil {
ps.lastTx = ps.now()
}
}()
switch {
case len(dst) < _MTU:
return 0, io.ErrShortBuffer

case ps.ARPresult.Operation == 1:
case ps.pendingRequestARP():
// We have a pending request from user to perform ARP.
ehdr := eth.EthernetHeader{
Destination: broadcastMAC,
Expand All @@ -280,7 +312,7 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) {
ps.ARPresult.Operation = arpOpWait // Clear pending ARP to not loop.
return eth.SizeEthernetHeader + eth.SizeARPv4Header, nil

case ps.hasPendingARP():
case ps.pendingReplyToARP():
// We need to respond to an ARP request that queries our address.
ehdr := eth.EthernetHeader{
Destination: ps.pendingARPresponse.HardwareSender,
Expand All @@ -304,28 +336,29 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) {
HandleEth(dst []byte) (int, error)
}

handleSocket := func(dst []byte, pending *uint32, sock Socket) (int, error) {
handleSocket := func(dst []byte, sock Socket) (int, bool, error) {
if !sock.IsPendingHandling() {
return 0, nil // Nothing to handle, just skip.
return 0, false, nil // Nothing to handle, just skip.
}
// Socket has an unhandled packet.
n, err := sock.HandleEth(dst)
if err == io.ErrNoProgress {
// Special case: Socket may have written data but needs future handling, flagged with the io.ErrNoProgress error.
return n, nil
if err == ErrFlagPending {
// Special case: Socket may have written data but needs future handling, flagged with the ErrFlagPending error.
return n, true, nil
}
// If we get here the socket has been handled, so we decrement the pending counter.
*pending--
if err != nil {
sock.Close()
return 0, err
n = 0
}
return n, nil
return n, false, err
}

if ps.pendingUDPv4 > 0 {
for i := range ps.UDPv4 {
n, err := handleSocket(dst, &ps.pendingUDPv4, &ps.UDPv4[i])
n, pending, err := handleSocket(dst, &ps.UDPv4[i])
if !pending {
ps.pendingUDPv4--
}
if err != nil {
return 0, err
} else if n > 0 {
Expand All @@ -337,7 +370,10 @@ func (ps *PortStack) HandleEth(dst []byte) (int, error) {

if ps.pendingTCPv4 > 0 {
for i := range ps.TCPv4 {
n, err := handleSocket(dst, &ps.pendingTCPv4, &ps.TCPv4[i])
n, pending, err := handleSocket(dst, &ps.TCPv4[i])
if !pending {
ps.pendingTCPv4--
}
if err != nil {
return 0, err
} else if n > 0 {
Expand Down Expand Up @@ -371,10 +407,19 @@ func (ps *PortStack) ARPv4Result() (eth.ARPv4Header, bool) {
return ps.ARPresult, ps.ARPresult.Operation == 2
}

func (ps *PortStack) hasPendingARP() bool {
func (ps *PortStack) pendingReplyToARP() bool {
return ps.pendingARPresponse.Operation == 2 // 2 means reply.
}

func (ps *PortStack) pendingRequestARP() bool {
return ps.ARPresult.Operation == 1 // User asked for a ARP request.
}

// IsPendingHandling checks if a call to HandleEth could possibly result in a packet being generated by the PortStack.
func (ps *PortStack) IsPendingHandling() bool {
return ps.pendingUDPv4 > 0 || ps.pendingTCPv4 > 0 || ps.pendingRequestARP() || ps.pendingReplyToARP()
}

// OpenUDP opens a UDP port and sets the handler. If the port is already open
// or if there is no socket available it returns an error.
func (ps *PortStack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, error)) error {
Expand Down Expand Up @@ -510,6 +555,10 @@ func (ps *PortStack) getTCP(port uint16) *tcpPort {
return nil
}

func (ps *PortStack) now() time.Time {
return time.Now()
}

func (ps *PortStack) info(msg string, attrs ...slog.Attr) {
ps.logAttrsPrint(slog.LevelInfo, msg, attrs...)
}
Expand Down
4 changes: 2 additions & 2 deletions stack/socket_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (t *TCPSocket) Close() error {

func (t *TCPSocket) handleMain(response []byte, pkt *TCPPacket) (n int, err error) {
defer func() {
if err != nil && t.abortErr == nil && err != io.ErrNoProgress {
if err != nil && t.abortErr == nil && err != ErrFlagPending {
err = nil // Only close socket if socket is aborted.
} else if err != nil {
t.stack.error("tcp socket", slog.Int("port", int(t.localPort)), slog.String("err", err.Error()))
Expand Down Expand Up @@ -225,7 +225,7 @@ func (t *TCPSocket) handleSend(response []byte, pkt *TCPPacket) (n int, err erro
pkt.PutHeaders(response)

if t.scb.HasPending() {
err = io.ErrNoProgress // Flag to PortStack that we have pending data to send.
err = ErrFlagPending // Flag to PortStack that we have pending data to send.
}
return sizeTCPNoOptions + n, err
}
Expand Down
1 change: 1 addition & 0 deletions stack/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ func exchangeStacks(t *testing.T, maxExchanges int, stacks ...*stack.PortStack)
if sentInTx == 0 {
break // No more data being sent.
}

for isend := 0; isend < len(stacks); isend++ {
// We deliver each in-flight packet to all stacks, except the one that sent it.
payload := getPayload(isend)
Expand Down

0 comments on commit e6396dc

Please sign in to comment.