diff --git a/stack/port_tcp.go b/stack/port_tcp.go index cb6e902..ebb78a4 100644 --- a/stack/port_tcp.go +++ b/stack/port_tcp.go @@ -78,6 +78,8 @@ func (port *tcpPort) freePacket() *TCPPacket { func (port *tcpPort) Open(portNum uint16, handler tcphandler) { if portNum == 0 || handler == nil { panic("invalid port or nil handler" + strconv.Itoa(int(port.port))) + } else if port.port != 0 { + panic("port already open") } port.handler = handler port.port = portNum @@ -110,7 +112,7 @@ func (port *tcpPort) forceResponse() (added bool) { return false } -const tcpMTU = _MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader +const tcpMTU = defaultMTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader type TCPPacket struct { Rx time.Time diff --git a/stack/port_udp.go b/stack/port_udp.go index b1bd3eb..7bd038e 100644 --- a/stack/port_udp.go +++ b/stack/port_udp.go @@ -50,6 +50,8 @@ func (port *udpPort) HandleEth(dst []byte) (int, error) { func (port *udpPort) Open(portNum uint16, h udphandler) { if portNum == 0 || h == nil { panic("invalid port or nil handler" + strconv.Itoa(int(port.port))) + } else if port.port != 0 { + panic("port already open") } port.handler = h port.port = portNum @@ -110,7 +112,7 @@ type UDPPacket struct { Eth eth.EthernetHeader IP eth.IPv4Header UDP eth.UDPHeader - payload [_MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte + payload [defaultMTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte } func (pkt *UDPPacket) HasPacket() bool { return pkt.Rx != forcedTime && !pkt.Rx.IsZero() } diff --git a/stack/portstack.go b/stack/portstack.go index 41e18cc..75bd34e 100644 --- a/stack/portstack.go +++ b/stack/portstack.go @@ -10,22 +10,29 @@ import ( "net" "net/netip" "slices" + "strconv" "time" "github.com/soypat/seqs/eth" ) const ( - _MTU = 1500 - arpOpWait = 0xffff + defaultMTU = 2048 + arpOpWait = 0xffff ) +type ethernethandler = func(ehdr *eth.EthernetHeader, ethPayload []byte) error + type PortStackConfig struct { - MAC [6]byte - // IP netip.Addr MaxOpenPortsUDP int MaxOpenPortsTCP int - Logger *slog.Logger + // GlobalHandler processes all incoming ethernet frames before they reach the port handlers. + // If GlobalHandler returns an error the frame is discarded and PortStack.HandleEth returns the error. + GlobalHandler ethernethandler + Logger *slog.Logger + MAC [6]byte + // MTU is the maximum transmission unit of the ethernet interface. + MTU uint16 } // NewPortStack creates a ready to use TCP/UDP Stack instance. @@ -36,6 +43,10 @@ func NewPortStack(cfg PortStackConfig) *PortStack { s.portsUDP = make([]udpPort, cfg.MaxOpenPortsUDP) s.portsTCP = make([]tcpPort, cfg.MaxOpenPortsTCP) s.logger = cfg.Logger + if cfg.MTU < defaultMTU { + panic("please use a larger MTU. min=" + strconv.Itoa(defaultMTU)) + } + s.mtu = cfg.MTU return &s } @@ -77,7 +88,7 @@ type PortStack struct { lastRx time.Time lastRxSuccess time.Time lastTx time.Time - glob func([]byte) + glob ethernethandler logger *slog.Logger portsUDP []udpPort portsTCP []tcpPort @@ -91,8 +102,11 @@ type PortStack struct { // pending ARP reply that must be sent out. pendingARPresponse eth.ARPv4Header ARPresult eth.ARPv4Header - mac [6]byte - ip [4]byte + // Auxiliary struct to avoid allocations passed to global handler. + auxEth eth.EthernetHeader + mac [6]byte + ip [4]byte + mtu uint16 } // Common errors. @@ -123,6 +137,7 @@ func (ps *PortStack) SetAddr(addr netip.Addr) { ps.ip = addr.As4() } +func (ps *PortStack) MTU() uint16 { return ps.mtu } func (ps *PortStack) MAC() net.HardwareAddr { return slices.Clone(ps.mac[:]) } func (ps *PortStack) MACAs6() [6]byte { return ps.mac } @@ -132,16 +147,13 @@ func (ps *PortStack) MACAs6() [6]byte { return ps.mac } // If [Stack.HandleEth] is not called often enough prevent packet queue from // filling up on a socket RecvEth will start to return [ErrDroppedPacket]. func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { - var ehdr eth.EthernetHeader var ihdr eth.IPv4Header + var ehdr eth.EthernetHeader defer func() { if err != nil { ps.error("Stack.RecvEth", slog.String("err", err.Error()), slog.Any("IP", ihdr)) } else { ps.lastRxSuccess = ps.lastRx - if ps.glob != nil { - ps.glob(ethernetFrame) - } } }() payload := ethernetFrame @@ -153,6 +165,13 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { // Ethernet parsing block ehdr = eth.DecodeEthernetHeader(payload) + if ps.glob != nil { + ps.auxEth = ehdr // Need auxiliary struct since glob is an arbitrary function, the escape analysis can't determine if argument escapes. + err = ps.glob(&ps.auxEth, payload[eth.SizeEthernetHeader:]) + if err != nil { + return err + } + } etype := ehdr.AssertType() if !eth.IsBroadcastHW(ehdr.Destination[:]) && !bytes.Equal(ehdr.Destination[:], ps.mac[:]) { return nil // Ignore packet, is not for us. @@ -202,7 +221,7 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { return nil // Not for us. case uint16(offset) > end || int(offset) > len(payload) || int(end) > len(payload): return errors.New("bad IP TotalLength/IHL") - case end > _MTU: + case end > ps.mtu: return errPacketExceedsMTU } ipOptions := payload[eth.SizeEthernetHeader+eth.SizeIPv4Header : offset] // TODO add IPv4 options. @@ -233,7 +252,10 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { port := findPort(ps.portsUDP, uhdr.DestinationPort) if port == nil { break // No socket listening on this port. - } else if port.NeedsHandling() { + } + + pkt := port.freePacket() + if pkt == nil { ps.error("UDP packet dropped") ps.droppedPackets++ return ErrDroppedPacket // Our socket needs handling before admitting more packets. @@ -244,12 +266,11 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { ps.pendingUDPv4++ port.LastRx = ps.lastRx // set as unhandled here. - port.packets[0].Rx = ps.lastRx - port.packets[0].Eth = ehdr - port.packets[0].IP = ihdr - port.packets[0].UDP = uhdr - - copy(port.packets[0].payload[:], payload) + pkt.Rx = ps.lastRx + pkt.Eth = ehdr + pkt.IP = ihdr // TODO(soypat): Don't ignore IP options. + pkt.UDP = uhdr + copy(pkt.payload[:], payload) case 6: ps.info("TCP packet received", slog.Int("plen", len(payload))) @@ -278,23 +299,27 @@ func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { port := findPort(ps.portsTCP, thdr.DestinationPort) if port == nil { break // No socket listening on this port. - } else if port.NeedsHandling() { + } + + pkt := port.freePacket() + if pkt == nil { ps.error("TCP packet dropped") ps.droppedPackets++ return ErrDroppedPacket // Our socket needs handling before admitting more packets. } + ps.info("TCP packet stored", slog.Int("plen", len(payload))) // Flag packets as needing processing. ps.pendingTCPv4++ port.LastRx = ps.lastRx // set as unhandled here. - port.packets[0].Rx = ps.lastRx - port.packets[0].Eth = ehdr - port.packets[0].IP = ihdr - port.packets[0].TCP = thdr - n := copy(port.packets[0].data[:], ipOptions) - n += copy(port.packets[0].data[n:], tcpOptions) - copy(port.packets[0].data[n:], payload) + pkt.Rx = ps.lastRx + pkt.Eth = ehdr + pkt.IP = ihdr + pkt.TCP = thdr + n := copy(pkt.data[:], ipOptions) + n += copy(pkt.data[n:], tcpOptions) + copy(pkt.data[n:], payload) } return nil } @@ -312,7 +337,7 @@ func (ps *PortStack) HandleEth(dst []byte) (n int, err error) { } }() switch { - case len(dst) < _MTU: + case len(dst) < int(ps.mtu): return 0, io.ErrShortBuffer case ps.pendingRequestARP(): @@ -595,20 +620,20 @@ type porter interface { Port() uint16 } -func findPort[T porter](list []T, port uint16) *T { +func findPort[T porter](list []T, portNum uint16) *T { for i := range list { - if list[i].Port() == port { + if list[i].Port() == portNum { return &list[i] } } return nil } -func findAvailPort[T porter](list []T, port uint16) (*T, error) { +func findAvailPort[T porter](list []T, portNum uint16) (*T, error) { availableIdx := -1 for i := range list { got := list[i].Port() - if got == port { + if got == portNum { availableIdx = -2 break } else if got == 0 { // Port==0 means port is unused. diff --git a/stack/socket_tcp.go b/stack/socket_tcp.go index 1f24b12..63d6964 100644 --- a/stack/socket_tcp.go +++ b/stack/socket_tcp.go @@ -114,12 +114,12 @@ func DialTCP(stack *PortStack, localPort uint16, remoteMAC [6]byte, remote netip // ListenTCP opens a passive TCP connection that listens on the given port. // ListenTCP only handles one connection at a time, so API may change in future to accomodate multiple connections. -func ListenTCP(stack *PortStack, port uint16, iss seqs.Value, window seqs.Size) (*TCPSocket, error) { +func ListenTCP(stack *PortStack, portNum uint16, iss seqs.Value, window seqs.Size) (*TCPSocket, error) { t := TCPSocket{ stack: stack, - localPort: port, + localPort: portNum, } - err := stack.OpenTCP(port, t.handleMain) + err := stack.OpenTCP(portNum, t.handleMain) if err != nil { return nil, err } @@ -240,6 +240,7 @@ func (t *TCPSocket) handleSend(response []byte, pkt *TCPPacket) (n int, err erro err = ErrFlagPending // Flag to PortStack that we have pending data to send. } else if t.scb.State() == seqs.StateClosed { err = io.EOF + t.close() } return sizeTCPNoOptions + n, err } @@ -289,8 +290,8 @@ func (t *TCPSocket) synsentSegment() seqs.Segment { } } -func (t *TCPSocket) abort(err error) error { - t.close() - t.abortErr = err - return err -} +// func (t *TCPSocket) abort(err error) error { +// t.close() +// t.abortErr = err +// return err +// } diff --git a/stack/stack_test.go b/stack/stack_test.go index 979f04a..39f9b62 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -306,6 +306,7 @@ func createPortStacks(t *testing.T, n int) (stacks []*stack.PortStack) { Stack := stack.NewPortStack(stack.PortStackConfig{ MAC: MAC, MaxOpenPortsTCP: 1, + MTU: 2048, }) Stack.SetAddr(ip) stacks = append(stacks, Stack)