Skip to content

Commit

Permalink
add MTU semantics to PortStack
Browse files Browse the repository at this point in the history
  • Loading branch information
soypat committed Nov 20, 2023
1 parent a04addc commit 2cc1e3a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 43 deletions.
4 changes: 3 additions & 1 deletion stack/port_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func (port *tcpPort) freePacket() *TCPPacket {
func (port *tcpPort) Open(portNum uint16, handler tcphandler) {
if portNum == 0 || handler == nil {
panic("invalid port or nil handler" + strconv.Itoa(int(port.port)))
} else if port.port != 0 {
panic("port already open")
}
port.handler = handler
port.port = portNum
Expand Down Expand Up @@ -110,7 +112,7 @@ func (port *tcpPort) forceResponse() (added bool) {
return false
}

const tcpMTU = _MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader
const tcpMTU = defaultMTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader

type TCPPacket struct {
Rx time.Time
Expand Down
4 changes: 3 additions & 1 deletion stack/port_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func (port *udpPort) HandleEth(dst []byte) (int, error) {
func (port *udpPort) Open(portNum uint16, h udphandler) {
if portNum == 0 || h == nil {
panic("invalid port or nil handler" + strconv.Itoa(int(port.port)))
} else if port.port != 0 {
panic("port already open")
}
port.handler = h
port.port = portNum
Expand Down Expand Up @@ -110,7 +112,7 @@ type UDPPacket struct {
Eth eth.EthernetHeader
IP eth.IPv4Header
UDP eth.UDPHeader
payload [_MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte
payload [defaultMTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte
}

func (pkt *UDPPacket) HasPacket() bool { return pkt.Rx != forcedTime && !pkt.Rx.IsZero() }
Expand Down
91 changes: 58 additions & 33 deletions stack/portstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,29 @@ import (
"net"
"net/netip"
"slices"

Check failure on line 12 in stack/portstack.go

View workflow job for this annotation

GitHub Actions / build

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.19.13/x64/src/slices)
"strconv"
"time"

"github.com/soypat/seqs/eth"
)

const (
_MTU = 1500
arpOpWait = 0xffff
defaultMTU = 2048
arpOpWait = 0xffff
)

type ethernethandler = func(ehdr *eth.EthernetHeader, ethPayload []byte) error

type PortStackConfig struct {
MAC [6]byte
// IP netip.Addr
MaxOpenPortsUDP int
MaxOpenPortsTCP int
Logger *slog.Logger
// GlobalHandler processes all incoming ethernet frames before they reach the port handlers.
// If GlobalHandler returns an error the frame is discarded and PortStack.HandleEth returns the error.
GlobalHandler ethernethandler
Logger *slog.Logger
MAC [6]byte
// MTU is the maximum transmission unit of the ethernet interface.
MTU uint16
}

// NewPortStack creates a ready to use TCP/UDP Stack instance.
Expand All @@ -36,6 +43,10 @@ func NewPortStack(cfg PortStackConfig) *PortStack {
s.portsUDP = make([]udpPort, cfg.MaxOpenPortsUDP)
s.portsTCP = make([]tcpPort, cfg.MaxOpenPortsTCP)
s.logger = cfg.Logger
if cfg.MTU < defaultMTU {
panic("please use a larger MTU. min=" + strconv.Itoa(defaultMTU))
}
s.mtu = cfg.MTU
return &s
}

Expand Down Expand Up @@ -77,7 +88,7 @@ type PortStack struct {
lastRx time.Time
lastRxSuccess time.Time
lastTx time.Time
glob func([]byte)
glob ethernethandler
logger *slog.Logger
portsUDP []udpPort
portsTCP []tcpPort
Expand All @@ -91,8 +102,11 @@ type PortStack struct {
// pending ARP reply that must be sent out.
pendingARPresponse eth.ARPv4Header
ARPresult eth.ARPv4Header
mac [6]byte
ip [4]byte
// Auxiliary struct to avoid allocations passed to global handler.
auxEth eth.EthernetHeader
mac [6]byte
ip [4]byte
mtu uint16
}

// Common errors.
Expand Down Expand Up @@ -123,6 +137,7 @@ func (ps *PortStack) SetAddr(addr netip.Addr) {
ps.ip = addr.As4()
}

func (ps *PortStack) MTU() uint16 { return ps.mtu }
func (ps *PortStack) MAC() net.HardwareAddr { return slices.Clone(ps.mac[:]) }
func (ps *PortStack) MACAs6() [6]byte { return ps.mac }

Expand All @@ -132,16 +147,13 @@ func (ps *PortStack) MACAs6() [6]byte { return ps.mac }
// If [Stack.HandleEth] is not called often enough prevent packet queue from
// filling up on a socket RecvEth will start to return [ErrDroppedPacket].
func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
var ehdr eth.EthernetHeader
var ihdr eth.IPv4Header
var ehdr eth.EthernetHeader
defer func() {
if err != nil {
ps.error("Stack.RecvEth", slog.String("err", err.Error()), slog.Any("IP", ihdr))
} else {
ps.lastRxSuccess = ps.lastRx
if ps.glob != nil {
ps.glob(ethernetFrame)
}
}
}()
payload := ethernetFrame
Expand All @@ -153,6 +165,13 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {

// Ethernet parsing block
ehdr = eth.DecodeEthernetHeader(payload)
if ps.glob != nil {
ps.auxEth = ehdr // Need auxiliary struct since glob is an arbitrary function, the escape analysis can't determine if argument escapes.
err = ps.glob(&ps.auxEth, payload[eth.SizeEthernetHeader:])
if err != nil {
return err
}
}
etype := ehdr.AssertType()
if !eth.IsBroadcastHW(ehdr.Destination[:]) && !bytes.Equal(ehdr.Destination[:], ps.mac[:]) {
return nil // Ignore packet, is not for us.
Expand Down Expand Up @@ -202,7 +221,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
return nil // Not for us.
case uint16(offset) > end || int(offset) > len(payload) || int(end) > len(payload):
return errors.New("bad IP TotalLength/IHL")
case end > _MTU:
case end > ps.mtu:
return errPacketExceedsMTU
}
ipOptions := payload[eth.SizeEthernetHeader+eth.SizeIPv4Header : offset] // TODO add IPv4 options.
Expand Down Expand Up @@ -233,7 +252,10 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
port := findPort(ps.portsUDP, uhdr.DestinationPort)
if port == nil {
break // No socket listening on this port.
} else if port.NeedsHandling() {
}

pkt := port.freePacket()
if pkt == nil {
ps.error("UDP packet dropped")
ps.droppedPackets++
return ErrDroppedPacket // Our socket needs handling before admitting more packets.
Expand All @@ -244,12 +266,11 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
ps.pendingUDPv4++
port.LastRx = ps.lastRx // set as unhandled here.

port.packets[0].Rx = ps.lastRx
port.packets[0].Eth = ehdr
port.packets[0].IP = ihdr
port.packets[0].UDP = uhdr

copy(port.packets[0].payload[:], payload)
pkt.Rx = ps.lastRx
pkt.Eth = ehdr
pkt.IP = ihdr // TODO(soypat): Don't ignore IP options.
pkt.UDP = uhdr
copy(pkt.payload[:], payload)

case 6:
ps.info("TCP packet received", slog.Int("plen", len(payload)))
Expand Down Expand Up @@ -278,23 +299,27 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) {
port := findPort(ps.portsTCP, thdr.DestinationPort)
if port == nil {
break // No socket listening on this port.
} else if port.NeedsHandling() {
}

pkt := port.freePacket()
if pkt == nil {
ps.error("TCP packet dropped")
ps.droppedPackets++
return ErrDroppedPacket // Our socket needs handling before admitting more packets.
}

ps.info("TCP packet stored", slog.Int("plen", len(payload)))
// Flag packets as needing processing.
ps.pendingTCPv4++
port.LastRx = ps.lastRx // set as unhandled here.

port.packets[0].Rx = ps.lastRx
port.packets[0].Eth = ehdr
port.packets[0].IP = ihdr
port.packets[0].TCP = thdr
n := copy(port.packets[0].data[:], ipOptions)
n += copy(port.packets[0].data[n:], tcpOptions)
copy(port.packets[0].data[n:], payload)
pkt.Rx = ps.lastRx
pkt.Eth = ehdr
pkt.IP = ihdr
pkt.TCP = thdr
n := copy(pkt.data[:], ipOptions)
n += copy(pkt.data[n:], tcpOptions)
copy(pkt.data[n:], payload)
}
return nil
}
Expand All @@ -312,7 +337,7 @@ func (ps *PortStack) HandleEth(dst []byte) (n int, err error) {
}
}()
switch {
case len(dst) < _MTU:
case len(dst) < int(ps.mtu):
return 0, io.ErrShortBuffer

case ps.pendingRequestARP():
Expand Down Expand Up @@ -595,20 +620,20 @@ type porter interface {
Port() uint16
}

func findPort[T porter](list []T, port uint16) *T {
func findPort[T porter](list []T, portNum uint16) *T {
for i := range list {
if list[i].Port() == port {
if list[i].Port() == portNum {
return &list[i]
}
}
return nil
}

func findAvailPort[T porter](list []T, port uint16) (*T, error) {
func findAvailPort[T porter](list []T, portNum uint16) (*T, error) {
availableIdx := -1
for i := range list {
got := list[i].Port()
if got == port {
if got == portNum {
availableIdx = -2
break
} else if got == 0 { // Port==0 means port is unused.
Expand Down
17 changes: 9 additions & 8 deletions stack/socket_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ func DialTCP(stack *PortStack, localPort uint16, remoteMAC [6]byte, remote netip

// ListenTCP opens a passive TCP connection that listens on the given port.
// ListenTCP only handles one connection at a time, so API may change in future to accomodate multiple connections.
func ListenTCP(stack *PortStack, port uint16, iss seqs.Value, window seqs.Size) (*TCPSocket, error) {
func ListenTCP(stack *PortStack, portNum uint16, iss seqs.Value, window seqs.Size) (*TCPSocket, error) {
t := TCPSocket{
stack: stack,
localPort: port,
localPort: portNum,
}
err := stack.OpenTCP(port, t.handleMain)
err := stack.OpenTCP(portNum, t.handleMain)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -240,6 +240,7 @@ func (t *TCPSocket) handleSend(response []byte, pkt *TCPPacket) (n int, err erro
err = ErrFlagPending // Flag to PortStack that we have pending data to send.
} else if t.scb.State() == seqs.StateClosed {
err = io.EOF
t.close()
}
return sizeTCPNoOptions + n, err
}
Expand Down Expand Up @@ -289,8 +290,8 @@ func (t *TCPSocket) synsentSegment() seqs.Segment {
}
}

func (t *TCPSocket) abort(err error) error {
t.close()
t.abortErr = err
return err
}
// func (t *TCPSocket) abort(err error) error {
// t.close()
// t.abortErr = err
// return err
// }
1 change: 1 addition & 0 deletions stack/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ func createPortStacks(t *testing.T, n int) (stacks []*stack.PortStack) {
Stack := stack.NewPortStack(stack.PortStackConfig{
MAC: MAC,
MaxOpenPortsTCP: 1,
MTU: 2048,
})
Stack.SetAddr(ip)
stacks = append(stacks, Stack)
Expand Down

0 comments on commit 2cc1e3a

Please sign in to comment.