Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inspect support. Fixes #457 #458

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions ziti/edge/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,30 @@ package edge

import (
"encoding/binary"

"github.com/openziti/channel/v2"
"github.com/openziti/foundation/v2/uuidz"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

const (
ContentTypeConnect = 60783
ContentTypeStateConnected = 60784
ContentTypeStateClosed = 60785
ContentTypeData = 60786
ContentTypeDial = 60787
ContentTypeDialSuccess = 60788
ContentTypeDialFailed = 60789
ContentTypeBind = 60790
ContentTypeUnbind = 60791
ContentTypeStateSessionEnded = 60792
ContentTypeProbe = 60793
ContentTypeUpdateBind = 60794
ContentTypeHealthEvent = 60795
ContentTypeTraceRoute = 60796
ContentTypeTraceRouteResponse = 60797
ContentTypeHostingCheck = 60798
ContentTypeHostingStatus = 60799
ContentTypeConnect = 60783
ContentTypeStateConnected = 60784
ContentTypeStateClosed = 60785
ContentTypeData = 60786
ContentTypeDial = 60787
ContentTypeDialSuccess = 60788
ContentTypeDialFailed = 60789
ContentTypeBind = 60790
ContentTypeUnbind = 60791
ContentTypeStateSessionEnded = 60792
ContentTypeProbe = 60793
ContentTypeUpdateBind = 60794
ContentTypeHealthEvent = 60795
ContentTypeTraceRoute = 60796
ContentTypeTraceRouteResponse = 60797
ContentTypeConnInspectRequest = 60798
ContentTypeConnInspectResponse = 60799

ConnIdHeader = 1000
SeqHeader = 1001
Expand All @@ -66,6 +65,8 @@ const (
TraceSourceRequestIdHeader = 1019
TraceError = 1020
ListenerId = 1021
ConnTypeHeader = 1022
SupportsInspectHeader = 1023

ErrorCodeInternal = 1
ErrorCodeInvalidApiSession = 2
Expand Down Expand Up @@ -164,6 +165,13 @@ func NewTraceRouteResponseMsg(connId uint32, hops uint32, timestamp uint64, hopT
return msg
}

func NewConnInspectResponse(connId uint32, connType ConnType, state string) *channel.Message {
msg := channel.NewMessage(ContentTypeConnInspectResponse, []byte(state))
msg.PutUint32Header(ConnIdHeader, connId)
msg.PutByteHeader(ConnTypeHeader, byte(connType))
return msg
}

func NewConnectMsg(connId uint32, token string, pubKey []byte, options *DialOptions) *channel.Message {
msg := newMsg(ContentTypeConnect, connId, 0, []byte(token))
if pubKey != nil {
Expand Down Expand Up @@ -199,6 +207,8 @@ func NewDialMsg(connId uint32, token string, callerId string) *channel.Message {

func NewBindMsg(connId uint32, token string, pubKey []byte, options *ListenOptions) *channel.Message {
msg := newMsg(ContentTypeBind, connId, 0, []byte(token))
msg.PutBoolHeader(SupportsInspectHeader, true)

if pubKey != nil {
msg.Headers[PublicKeyHeader] = pubKey
msg.PutByteHeader(CryptoMethodHeader, byte(CryptoMethodLibsodium))
Expand Down Expand Up @@ -323,3 +333,35 @@ func GetLoggerFields(msg *channel.Message) logrus.Fields {

return fields
}

type ConnType byte

const (
ConnTypeInvalid ConnType = 0
ConnTypeDial ConnType = 1
ConnTypeBind ConnType = 2
ConnTypeUnknown ConnType = 3
)

type InspectResult struct {
ConnId uint32
Type ConnType
Detail string
}

func UnmarshalInspectResult(msg *channel.Message) (*InspectResult, error) {
if msg.ContentType == ContentTypeConnInspectResponse {
connId, _ := msg.GetUint32Header(ConnIdHeader)
connType, found := msg.GetByteHeader(ConnTypeHeader)
if !found {
connType = byte(ConnTypeUnknown)
}
return &InspectResult{
ConnId: connId,
Type: ConnType(connType),
Detail: string(msg.Body),
}, nil
}

return nil, errors.Errorf("unexpected response. received %v instead of inspect result message", msg.ContentType)
}
10 changes: 9 additions & 1 deletion ziti/edge/msg_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package edge

import (
"fmt"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/channel/v2"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"math"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -80,7 +82,7 @@ func (mux *CowMapMsgMux) ContentType() int32 {
return ContentTypeData
}

func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, _ channel.Channel) {
func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) {
connId, found := msg.GetUint32Header(ConnIdHeader)
if !found {
pfxlog.Logger().Errorf("received edge message with no connId header. content type: %v", msg.ContentType)
Expand All @@ -90,6 +92,12 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, _ channel.Channel)
sinks := mux.getSinks()
if sink, found := sinks[connId]; found {
sink.Accept(msg)
} else if msg.ContentType == ContentTypeConnInspectRequest {
resp := NewConnInspectResponse(connId, ConnTypeInvalid, fmt.Sprintf("invalid conn id [%v]", connId))
if err := resp.ReplyTo(msg).Send(ch); err != nil {
logrus.WithFields(GetLoggerFields(msg)).WithError(err).
Error("failed to send inspect response")
}
} else {
pfxlog.Logger().Debugf("unable to dispatch msg received for unknown edge conn id: %v", connId)
}
Expand Down
107 changes: 76 additions & 31 deletions ziti/edge/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
package network

import (
"encoding/json"
"fmt"
"github.com/openziti/edge-api/rest_model"
cmap "github.com/orcaman/concurrent-map/v2"
"io"
"net"
"sync"
"sync/atomic"
"time"

Expand All @@ -41,8 +42,8 @@ var unsupportedCrypto = errors.New("unsupported crypto")
type ConnType byte

const (
ConnTypeDial = 1
ConnTypeBind = 2
ConnTypeDial ConnType = 1
ConnTypeBind ConnType = 2
)

var _ edge.Conn = &edgeConn{}
Expand All @@ -52,11 +53,11 @@ type edgeConn struct {
readQ *noopSeq[*channel.Message]
leftover []byte
msgMux edge.MsgMux
hosting sync.Map
hosting cmap.ConcurrentMap[string, *edgeListener]
closed atomic.Bool
readFIN atomic.Bool
sentFIN atomic.Bool
serviceId string
serviceName string
sourceIdentity string
acceptCompleteHandler *newConnHandler
connType ConnType
Expand Down Expand Up @@ -100,8 +101,51 @@ func (conn *edgeConn) CloseWrite() error {
return nil
}

func (conn *edgeConn) Inspect() string {
result := map[string]interface{}{}
result["id"] = conn.Id()
result["serviceName"] = conn.serviceName
result["closed"] = conn.closed.Load()
result["encryptionRequired"] = conn.crypto

if conn.connType == ConnTypeDial {
result["encrypted"] = conn.rxKey != nil || conn.receiver != nil
result["readFIN"] = conn.readFIN.Load()
result["sentFIN"] = conn.sentFIN.Load()
}

if conn.connType == ConnTypeBind {
hosting := map[string]interface{}{}
for entry := range conn.hosting.IterBuffered() {
hosting[entry.Key] = map[string]interface{}{
"closed": entry.Val.closed.Load(),
"manualStart": entry.Val.manualStart,
"serviceId": *entry.Val.service.ID,
"serviceName": *entry.Val.service.Name,
}
}
result["hosting"] = hosting
}

jsonOutput, err := json.Marshal(result)
if err != nil {
pfxlog.Logger().WithError(err).Error("unable to marshal inspect result")
}
return string(jsonOutput)
}

func (conn *edgeConn) Accept(msg *channel.Message) {
conn.TraceMsg("Accept", msg)

if msg.ContentType == edge.ContentTypeConnInspectRequest {
resp := edge.NewConnInspectResponse(0, edge.ConnType(conn.connType), conn.Inspect())
if err := resp.ReplyTo(msg).Send(conn.Channel); err != nil {
logrus.WithFields(edge.GetLoggerFields(msg)).WithError(err).
Error("failed to send inspect response")
}
return
}

switch conn.connType {
case ConnTypeDial:
if msg.ContentType == edge.ContentTypeStateClosed {
Expand Down Expand Up @@ -144,9 +188,7 @@ func (conn *edgeConn) Accept(msg *channel.Message) {
if msg.ContentType == edge.ContentTypeDial {
logrus.WithFields(edge.GetLoggerFields(msg)).Debug("received dial request")
go conn.newChildConnection(msg)
}

if msg.ContentType == edge.ContentTypeStateClosed {
} else if msg.ContentType == edge.ContentTypeStateClosed {
conn.close(true)
}
default:
Expand All @@ -159,11 +201,11 @@ func (conn *edgeConn) IsClosed() bool {
}

func (conn *edgeConn) Network() string {
return conn.serviceId
return conn.serviceName
}

func (conn *edgeConn) String() string {
return fmt.Sprintf("zitiConn connId=%v svcId=%v sourceIdentity=%v", conn.Id(), conn.serviceId, conn.sourceIdentity)
return fmt.Sprintf("zitiConn connId=%v svcId=%v sourceIdentity=%v", conn.Id(), conn.serviceName, conn.sourceIdentity)
}

func (conn *edgeConn) LocalAddr() net.Addr {
Expand Down Expand Up @@ -300,7 +342,7 @@ func (conn *edgeConn) establishServerCrypto(keypair *kx.KeyPair, peerKey []byte,
}

func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_model.ServiceDetail, options *edge.ListenOptions) (edge.Listener, error) {
logger := pfxlog.Logger().
logger := pfxlog.ContextLogger(conn.Channel.Label()).
WithField("connId", conn.Id()).
WithField("serviceName", *service.Name).
WithField("sessionId", *session.ID)
Expand All @@ -316,20 +358,13 @@ func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_mo
manualStart: options.ManualStart,
}
logger.Debug("adding listener for session")
conn.hosting.Store(*session.Token, listener)
conn.hosting.Set(*session.Token, listener)

success := false
defer func() {
if !success {
logger.Debug("removing listener for session")
conn.hosting.Delete(*session.Token)

unbindRequest := edge.NewUnbindMsg(conn.Id(), listener.token)
listener.edgeChan.TraceMsg("close", unbindRequest)
if err := unbindRequest.WithTimeout(5 * time.Second).SendAndWaitForWire(conn.Channel); err != nil {
logger.WithError(err).Error("unable to unbind session for conn")
}

conn.unbind(logger, listener.token)
}
}()

Expand Down Expand Up @@ -363,6 +398,19 @@ func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_mo
return listener, nil
}

func (conn *edgeConn) unbind(logger *logrus.Entry, token string) {
logger.Debug("starting unbind")

conn.hosting.Remove(token)

unbindRequest := edge.NewUnbindMsg(conn.Id(), token)
if err := unbindRequest.WithTimeout(5 * time.Second).SendAndWaitForWire(conn.Channel); err != nil {
logger.WithError(err).Error("unable to send unbind msg for conn")
} else {
logger.Debug("unbind message sent successfully")
}
}

func (conn *edgeConn) Read(p []byte) (int, error) {
log := pfxlog.Logger().WithField("connId", conn.Id())
if conn.closed.Load() {
Expand Down Expand Up @@ -472,21 +520,18 @@ func (conn *edgeConn) close(closedByRemote bool) {
conn.readQ.Close()
conn.msgMux.RemoveMsgSink(conn) // if we switch back to ChMsgMux will need to be done async again, otherwise we may deadlock

conn.hosting.Range(func(key, value interface{}) bool {
listener := value.(*edgeListener)
if err := listener.close(closedByRemote); err != nil {
log.WithError(err).WithField("serviceName", *listener.service.Name).Error("failed to close listener")
if conn.connType == ConnTypeBind {
for entry := range conn.hosting.IterBuffered() {
listener := entry.Val
if err := listener.close(closedByRemote); err != nil {
log.WithError(err).WithField("serviceName", *listener.service.Name).Error("failed to close listener")
}
}
return true
})
}
}

func (conn *edgeConn) getListener(token string) (*edgeListener, bool) {
if val, found := conn.hosting.Load(token); found {
listener, ok := val.(*edgeListener)
return listener, ok
}
return nil, false
return conn.hosting.Get(token)
}

func (conn *edgeConn) newChildConnection(message *channel.Message) {
Expand Down
16 changes: 8 additions & 8 deletions ziti/edge/network/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ func BenchmarkConnWrite(b *testing.B) {
mux := edge.NewCowMapMsgMux()
testChannel := &NoopTestChannel{}
conn := &edgeConn{
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
readQ: NewNoopSequencer[*channel.Message](4),
msgMux: mux,
serviceId: "test",
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
readQ: NewNoopSequencer[*channel.Message](4),
msgMux: mux,
serviceName: "test",
}

req := require.New(b)
Expand All @@ -55,10 +55,10 @@ func BenchmarkConnRead(b *testing.B) {

readQ := NewNoopSequencer[*channel.Message](4)
conn := &edgeConn{
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
readQ: readQ,
msgMux: mux,
serviceId: "test",
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
readQ: readQ,
msgMux: mux,
serviceName: "test",
}

var stop atomic.Bool
Expand Down
Loading
Loading