diff --git a/ziti/edge/messages.go b/ziti/edge/messages.go index a29b7f72..5b263c59 100644 --- a/ziti/edge/messages.go +++ b/ziti/edge/messages.go @@ -18,7 +18,6 @@ package edge import ( "encoding/binary" - "github.com/openziti/channel/v2" "github.com/openziti/foundation/v2/uuidz" "github.com/pkg/errors" @@ -26,23 +25,23 @@ import ( ) const ( - ContentTypeConnect = 60783 - ContentTypeStateConnected = 60784 - ContentTypeStateClosed = 60785 - ContentTypeData = 60786 - ContentTypeDial = 60787 - ContentTypeDialSuccess = 60788 - ContentTypeDialFailed = 60789 - ContentTypeBind = 60790 - ContentTypeUnbind = 60791 - ContentTypeStateSessionEnded = 60792 - ContentTypeProbe = 60793 - ContentTypeUpdateBind = 60794 - ContentTypeHealthEvent = 60795 - ContentTypeTraceRoute = 60796 - ContentTypeTraceRouteResponse = 60797 - ContentTypeHostingCheck = 60798 - ContentTypeHostingStatus = 60799 + ContentTypeConnect = 60783 + ContentTypeStateConnected = 60784 + ContentTypeStateClosed = 60785 + ContentTypeData = 60786 + ContentTypeDial = 60787 + ContentTypeDialSuccess = 60788 + ContentTypeDialFailed = 60789 + ContentTypeBind = 60790 + ContentTypeUnbind = 60791 + ContentTypeStateSessionEnded = 60792 + ContentTypeProbe = 60793 + ContentTypeUpdateBind = 60794 + ContentTypeHealthEvent = 60795 + ContentTypeTraceRoute = 60796 + ContentTypeTraceRouteResponse = 60797 + ContentTypeConnInspectRequest = 60798 + ContentTypeConnInspectResponse = 60799 ConnIdHeader = 1000 SeqHeader = 1001 @@ -66,6 +65,8 @@ const ( TraceSourceRequestIdHeader = 1019 TraceError = 1020 ListenerId = 1021 + ConnTypeHeader = 1022 + SupportsInspectHeader = 1023 ErrorCodeInternal = 1 ErrorCodeInvalidApiSession = 2 @@ -164,6 +165,13 @@ func NewTraceRouteResponseMsg(connId uint32, hops uint32, timestamp uint64, hopT return msg } +func NewConnInspectResponse(connId uint32, connType ConnType, state string) *channel.Message { + msg := channel.NewMessage(ContentTypeConnInspectResponse, []byte(state)) + msg.PutUint32Header(ConnIdHeader, connId) + msg.PutByteHeader(ConnTypeHeader, byte(connType)) + return msg +} + func NewConnectMsg(connId uint32, token string, pubKey []byte, options *DialOptions) *channel.Message { msg := newMsg(ContentTypeConnect, connId, 0, []byte(token)) if pubKey != nil { @@ -199,6 +207,8 @@ func NewDialMsg(connId uint32, token string, callerId string) *channel.Message { func NewBindMsg(connId uint32, token string, pubKey []byte, options *ListenOptions) *channel.Message { msg := newMsg(ContentTypeBind, connId, 0, []byte(token)) + msg.PutBoolHeader(SupportsInspectHeader, true) + if pubKey != nil { msg.Headers[PublicKeyHeader] = pubKey msg.PutByteHeader(CryptoMethodHeader, byte(CryptoMethodLibsodium)) @@ -323,3 +333,35 @@ func GetLoggerFields(msg *channel.Message) logrus.Fields { return fields } + +type ConnType byte + +const ( + ConnTypeInvalid ConnType = 0 + ConnTypeDial ConnType = 1 + ConnTypeBind ConnType = 2 + ConnTypeUnknown ConnType = 3 +) + +type InspectResult struct { + ConnId uint32 + Type ConnType + Detail string +} + +func UnmarshalInspectResult(msg *channel.Message) (*InspectResult, error) { + if msg.ContentType == ContentTypeConnInspectResponse { + connId, _ := msg.GetUint32Header(ConnIdHeader) + connType, found := msg.GetByteHeader(ConnTypeHeader) + if !found { + connType = byte(ConnTypeUnknown) + } + return &InspectResult{ + ConnId: connId, + Type: ConnType(connType), + Detail: string(msg.Body), + }, nil + } + + return nil, errors.Errorf("unexpected response. received %v instead of inspect result message", msg.ContentType) +} diff --git a/ziti/edge/msg_mux.go b/ziti/edge/msg_mux.go index 8e2a64e9..3ed424c8 100644 --- a/ziti/edge/msg_mux.go +++ b/ziti/edge/msg_mux.go @@ -17,9 +17,11 @@ package edge import ( + "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" "github.com/pkg/errors" + "github.com/sirupsen/logrus" "math" "sync" "sync/atomic" @@ -80,7 +82,7 @@ func (mux *CowMapMsgMux) ContentType() int32 { return ContentTypeData } -func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, _ channel.Channel) { +func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) { connId, found := msg.GetUint32Header(ConnIdHeader) if !found { pfxlog.Logger().Errorf("received edge message with no connId header. content type: %v", msg.ContentType) @@ -90,6 +92,12 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, _ channel.Channel) sinks := mux.getSinks() if sink, found := sinks[connId]; found { sink.Accept(msg) + } else if msg.ContentType == ContentTypeConnInspectRequest { + resp := NewConnInspectResponse(connId, ConnTypeInvalid, fmt.Sprintf("invalid conn id [%v]", connId)) + if err := resp.ReplyTo(msg).Send(ch); err != nil { + logrus.WithFields(GetLoggerFields(msg)).WithError(err). + Error("failed to send inspect response") + } } else { pfxlog.Logger().Debugf("unable to dispatch msg received for unknown edge conn id: %v", connId) } diff --git a/ziti/edge/network/conn.go b/ziti/edge/network/conn.go index 431103dd..ab42eab8 100644 --- a/ziti/edge/network/conn.go +++ b/ziti/edge/network/conn.go @@ -17,11 +17,12 @@ package network import ( + "encoding/json" "fmt" "github.com/openziti/edge-api/rest_model" + cmap "github.com/orcaman/concurrent-map/v2" "io" "net" - "sync" "sync/atomic" "time" @@ -41,8 +42,8 @@ var unsupportedCrypto = errors.New("unsupported crypto") type ConnType byte const ( - ConnTypeDial = 1 - ConnTypeBind = 2 + ConnTypeDial ConnType = 1 + ConnTypeBind ConnType = 2 ) var _ edge.Conn = &edgeConn{} @@ -52,11 +53,11 @@ type edgeConn struct { readQ *noopSeq[*channel.Message] leftover []byte msgMux edge.MsgMux - hosting sync.Map + hosting cmap.ConcurrentMap[string, *edgeListener] closed atomic.Bool readFIN atomic.Bool sentFIN atomic.Bool - serviceId string + serviceName string sourceIdentity string acceptCompleteHandler *newConnHandler connType ConnType @@ -100,8 +101,51 @@ func (conn *edgeConn) CloseWrite() error { return nil } +func (conn *edgeConn) Inspect() string { + result := map[string]interface{}{} + result["id"] = conn.Id() + result["serviceName"] = conn.serviceName + result["closed"] = conn.closed.Load() + result["encryptionRequired"] = conn.crypto + + if conn.connType == ConnTypeDial { + result["encrypted"] = conn.rxKey != nil || conn.receiver != nil + result["readFIN"] = conn.readFIN.Load() + result["sentFIN"] = conn.sentFIN.Load() + } + + if conn.connType == ConnTypeBind { + hosting := map[string]interface{}{} + for entry := range conn.hosting.IterBuffered() { + hosting[entry.Key] = map[string]interface{}{ + "closed": entry.Val.closed.Load(), + "manualStart": entry.Val.manualStart, + "serviceId": *entry.Val.service.ID, + "serviceName": *entry.Val.service.Name, + } + } + result["hosting"] = hosting + } + + jsonOutput, err := json.Marshal(result) + if err != nil { + pfxlog.Logger().WithError(err).Error("unable to marshal inspect result") + } + return string(jsonOutput) +} + func (conn *edgeConn) Accept(msg *channel.Message) { conn.TraceMsg("Accept", msg) + + if msg.ContentType == edge.ContentTypeConnInspectRequest { + resp := edge.NewConnInspectResponse(0, edge.ConnType(conn.connType), conn.Inspect()) + if err := resp.ReplyTo(msg).Send(conn.Channel); err != nil { + logrus.WithFields(edge.GetLoggerFields(msg)).WithError(err). + Error("failed to send inspect response") + } + return + } + switch conn.connType { case ConnTypeDial: if msg.ContentType == edge.ContentTypeStateClosed { @@ -144,9 +188,7 @@ func (conn *edgeConn) Accept(msg *channel.Message) { if msg.ContentType == edge.ContentTypeDial { logrus.WithFields(edge.GetLoggerFields(msg)).Debug("received dial request") go conn.newChildConnection(msg) - } - - if msg.ContentType == edge.ContentTypeStateClosed { + } else if msg.ContentType == edge.ContentTypeStateClosed { conn.close(true) } default: @@ -159,11 +201,11 @@ func (conn *edgeConn) IsClosed() bool { } func (conn *edgeConn) Network() string { - return conn.serviceId + return conn.serviceName } func (conn *edgeConn) String() string { - return fmt.Sprintf("zitiConn connId=%v svcId=%v sourceIdentity=%v", conn.Id(), conn.serviceId, conn.sourceIdentity) + return fmt.Sprintf("zitiConn connId=%v svcId=%v sourceIdentity=%v", conn.Id(), conn.serviceName, conn.sourceIdentity) } func (conn *edgeConn) LocalAddr() net.Addr { @@ -300,7 +342,7 @@ func (conn *edgeConn) establishServerCrypto(keypair *kx.KeyPair, peerKey []byte, } func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_model.ServiceDetail, options *edge.ListenOptions) (edge.Listener, error) { - logger := pfxlog.Logger(). + logger := pfxlog.ContextLogger(conn.Channel.Label()). WithField("connId", conn.Id()). WithField("serviceName", *service.Name). WithField("sessionId", *session.ID) @@ -316,20 +358,13 @@ func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_mo manualStart: options.ManualStart, } logger.Debug("adding listener for session") - conn.hosting.Store(*session.Token, listener) + conn.hosting.Set(*session.Token, listener) success := false defer func() { if !success { logger.Debug("removing listener for session") - conn.hosting.Delete(*session.Token) - - unbindRequest := edge.NewUnbindMsg(conn.Id(), listener.token) - listener.edgeChan.TraceMsg("close", unbindRequest) - if err := unbindRequest.WithTimeout(5 * time.Second).SendAndWaitForWire(conn.Channel); err != nil { - logger.WithError(err).Error("unable to unbind session for conn") - } - + conn.unbind(logger, listener.token) } }() @@ -363,6 +398,19 @@ func (conn *edgeConn) Listen(session *rest_model.SessionDetail, service *rest_mo return listener, nil } +func (conn *edgeConn) unbind(logger *logrus.Entry, token string) { + logger.Debug("starting unbind") + + conn.hosting.Remove(token) + + unbindRequest := edge.NewUnbindMsg(conn.Id(), token) + if err := unbindRequest.WithTimeout(5 * time.Second).SendAndWaitForWire(conn.Channel); err != nil { + logger.WithError(err).Error("unable to send unbind msg for conn") + } else { + logger.Debug("unbind message sent successfully") + } +} + func (conn *edgeConn) Read(p []byte) (int, error) { log := pfxlog.Logger().WithField("connId", conn.Id()) if conn.closed.Load() { @@ -472,21 +520,18 @@ func (conn *edgeConn) close(closedByRemote bool) { conn.readQ.Close() conn.msgMux.RemoveMsgSink(conn) // if we switch back to ChMsgMux will need to be done async again, otherwise we may deadlock - conn.hosting.Range(func(key, value interface{}) bool { - listener := value.(*edgeListener) - if err := listener.close(closedByRemote); err != nil { - log.WithError(err).WithField("serviceName", *listener.service.Name).Error("failed to close listener") + if conn.connType == ConnTypeBind { + for entry := range conn.hosting.IterBuffered() { + listener := entry.Val + if err := listener.close(closedByRemote); err != nil { + log.WithError(err).WithField("serviceName", *listener.service.Name).Error("failed to close listener") + } } - return true - }) + } } func (conn *edgeConn) getListener(token string) (*edgeListener, bool) { - if val, found := conn.hosting.Load(token); found { - listener, ok := val.(*edgeListener) - return listener, ok - } - return nil, false + return conn.hosting.Get(token) } func (conn *edgeConn) newChildConnection(message *channel.Message) { diff --git a/ziti/edge/network/conn_test.go b/ziti/edge/network/conn_test.go index 68fbaaac..4b5449af 100644 --- a/ziti/edge/network/conn_test.go +++ b/ziti/edge/network/conn_test.go @@ -30,10 +30,10 @@ func BenchmarkConnWrite(b *testing.B) { mux := edge.NewCowMapMsgMux() testChannel := &NoopTestChannel{} conn := &edgeConn{ - MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), - readQ: NewNoopSequencer[*channel.Message](4), - msgMux: mux, - serviceId: "test", + MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), + readQ: NewNoopSequencer[*channel.Message](4), + msgMux: mux, + serviceName: "test", } req := require.New(b) @@ -55,10 +55,10 @@ func BenchmarkConnRead(b *testing.B) { readQ := NewNoopSequencer[*channel.Message](4) conn := &edgeConn{ - MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), - readQ: readQ, - msgMux: mux, - serviceId: "test", + MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), + readQ: readQ, + msgMux: mux, + serviceName: "test", } var stop atomic.Bool diff --git a/ziti/edge/network/factory.go b/ziti/edge/network/factory.go index 5b7f8905..28c79244 100644 --- a/ziti/edge/network/factory.go +++ b/ziti/edge/network/factory.go @@ -22,6 +22,7 @@ import ( "github.com/openziti/edge-api/rest_model" "github.com/openziti/sdk-golang/ziti/edge" "github.com/openziti/secretstream/kx" + cmap "github.com/orcaman/concurrent-map/v2" ) type RouterConnOwner interface { @@ -67,6 +68,7 @@ func (conn *routerConn) BindChannel(binding channel.Binding) error { binding.AddReceiveHandlerF(edge.ContentTypeDial, conn.msgMux.HandleReceive) binding.AddReceiveHandlerF(edge.ContentTypeStateClosed, conn.msgMux.HandleReceive) binding.AddReceiveHandlerF(edge.ContentTypeTraceRoute, conn.msgMux.HandleReceive) + binding.AddReceiveHandlerF(edge.ContentTypeConnInspectRequest, conn.msgMux.HandleReceive) // Since data is the common message type, it gets to be dispatched directly binding.AddTypedReceiveHandler(conn.msgMux) @@ -80,11 +82,11 @@ func (conn *routerConn) NewDialConn(service *rest_model.ServiceDetail) *edgeConn id := conn.msgMux.GetNextId() edgeCh := &edgeConn{ - MsgChannel: *edge.NewEdgeMsgChannel(conn.ch, id), - readQ: NewNoopSequencer[*channel.Message](4), - msgMux: conn.msgMux, - serviceId: *service.Name, - connType: ConnTypeDial, + MsgChannel: *edge.NewEdgeMsgChannel(conn.ch, id), + readQ: NewNoopSequencer[*channel.Message](4), + msgMux: conn.msgMux, + serviceName: *service.Name, + connType: ConnTypeDial, } var err error @@ -107,19 +109,25 @@ func (conn *routerConn) NewListenConn(service *rest_model.ServiceDetail, keyPair id := conn.msgMux.GetNextId() edgeCh := &edgeConn{ - MsgChannel: *edge.NewEdgeMsgChannel(conn.ch, id), - readQ: NewNoopSequencer[*channel.Message](4), - msgMux: conn.msgMux, - serviceId: *service.Name, - connType: ConnTypeBind, - keyPair: keyPair, - crypto: keyPair != nil, + MsgChannel: *edge.NewEdgeMsgChannel(conn.ch, id), + readQ: NewNoopSequencer[*channel.Message](4), + msgMux: conn.msgMux, + serviceName: *service.Name, + connType: ConnTypeBind, + keyPair: keyPair, + crypto: keyPair != nil, + hosting: cmap.New[*edgeListener](), } // duplicate errors only happen on the server side, since client controls ids if err := conn.msgMux.AddMsgSink(edgeCh); err != nil { pfxlog.Logger().Warnf("error adding message sink %s[%d]: %v", *service.Name, id, err) } + pfxlog.Logger().WithField("connId", id). + WithField("routerName", conn.routerName). + WithField("serviceId", *service.ID). + WithField("serviceName", *service.Name). + Debug("created new listener connection") return edgeCh } @@ -136,13 +144,23 @@ func (conn *routerConn) Connect(service *rest_model.ServiceDetail, session *rest func (conn *routerConn) Listen(service *rest_model.ServiceDetail, session *rest_model.SessionDetail, options *edge.ListenOptions) (edge.Listener, error) { ec := conn.NewListenConn(service, options.KeyPair) + + log := pfxlog.Logger(). + WithField("connId", ec.Id()). + WithField("router", conn.routerName). + WithField("serviceId", *service.ID). + WithField("serviceName", *service.Name) + listener, err := ec.Listen(session, service, options) if err != nil { + log.WithError(err).Error("failed to establish listener") + if err2 := ec.Close(); err2 != nil { - pfxlog.Logger().WithError(err2). - WithField("serviceName", *service.Name). + log.WithError(err2). Error("failed to cleanup listener for service after failed bind") } + } else { + log.Debug("established listener") } return listener, err } diff --git a/ziti/edge/network/listener.go b/ziti/edge/network/listener.go index 0a650b1a..75f2baaf 100644 --- a/ziti/edge/network/listener.go +++ b/ziti/edge/network/listener.go @@ -108,7 +108,7 @@ func (listener *edgeListener) UpdateCostAndPrecedence(cost uint16, precedence ed func (listener *edgeListener) updateCostAndPrecedence(cost *uint16, precedence *edge.Precedence) error { logger := pfxlog.Logger(). WithField("connId", listener.edgeChan.Id()). - WithField("service", listener.edgeChan.serviceId). + WithField("serviceName", listener.edgeChan.serviceName). WithField("session", listener.token) logger.Debug("sending update bind request to edge router") @@ -120,7 +120,7 @@ func (listener *edgeListener) updateCostAndPrecedence(cost *uint16, precedence * func (listener *edgeListener) SendHealthEvent(pass bool) error { logger := pfxlog.Logger(). WithField("connId", listener.edgeChan.Id()). - WithField("service", listener.edgeChan.serviceId). + WithField("serviceName", listener.edgeChan.serviceName). WithField("session", listener.token). WithField("health.status", pass) @@ -147,7 +147,7 @@ func (listener *edgeListener) close(closedByRemote bool) error { WithField("sessionId", listener.token) logger.Debug("removing listener for session") - edgeChan.hosting.Delete(listener.token) + edgeChan.hosting.Remove(listener.token) defer func() { edgeChan.close(closedByRemote)