Skip to content

Commit

Permalink
feat: support ws client (#2992)
Browse files Browse the repository at this point in the history
Signed-off-by: Song Gao <disxiaofei@163.com>
  • Loading branch information
Yisaer authored Jul 8, 2024
1 parent 4884ab8 commit 461b48b
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 35 deletions.
58 changes: 47 additions & 11 deletions internal/io/http/httpserver/websocketConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,37 @@ package httpserver
import (
"github.com/lf-edge/ekuiper/contract/v2/api"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
"github.com/lf-edge/ekuiper/v2/pkg/cert"
"github.com/lf-edge/ekuiper/v2/pkg/modules"
)

type WebsocketConnection struct {
RecvTopic string
SendTopic string
cfg *connectionCfg
props map[string]any
cfg *wscConfig
isServer bool
client *WebsocketClient
}

type wscConfig struct {
Datasource string `json:"datasource"`
Addr string `json:"addr"`
}

func (w *WebsocketConnection) Ping(ctx api.StreamContext) error {
return nil
}

func (w *WebsocketConnection) DetachSub(ctx api.StreamContext, props map[string]any) {
UnRegisterWebSocketEndpoint(w.cfg.Datasource)
}

func (w *WebsocketConnection) Close(ctx api.StreamContext) error {
if w.isServer {
UnRegisterWebSocketEndpoint(w.cfg.Datasource)
} else {
w.client.Close(ctx)
}
return nil
}

Expand All @@ -43,17 +56,40 @@ func CreateWebsocketConnection(ctx api.StreamContext, props map[string]any) (mod
}

func createWebsocketServerConnection(ctx api.StreamContext, props map[string]any) (*WebsocketConnection, error) {
cfg := &connectionCfg{}
cfg := &wscConfig{}
if err := cast.MapToStruct(props, cfg); err != nil {
return nil, err
}
rTopic, sTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource)
if err != nil {
return nil, err
wc := &WebsocketConnection{
props: props,
cfg: cfg,
isServer: getWsType(cfg),
}
if wc.isServer {
rTopic, sTopic, err := RegisterWebSocketEndpoint(ctx, cfg.Datasource)
if err != nil {
return nil, err
}
wc.RecvTopic = rTopic
wc.SendTopic = sTopic
} else {
tlsConfig, err := cert.GenTLSConfig(props, "websocket")
if err != nil {
return nil, err
}
c := NewWebsocketClient(cfg.Addr, cfg.Datasource, tlsConfig)
if err := c.Connect(); err != nil {
return nil, err
}
wc.client = c
wc.RecvTopic, wc.SendTopic = c.Run(ctx)
}
return wc, nil
}

func getWsType(cfg *wscConfig) bool {
if len(cfg.Addr) < 1 {
return true
}
return &WebsocketConnection{
RecvTopic: rTopic,
SendTopic: sTopic,
cfg: cfg,
}, nil
return false
}
96 changes: 96 additions & 0 deletions internal/io/http/httpserver/websocketConn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
package httpserver

import (
"context"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"

mockContext "github.com/lf-edge/ekuiper/v2/pkg/mock/context"
Expand All @@ -35,5 +41,95 @@ func TestWebsocketConn(t *testing.T) {
require.NoError(t, err)
require.NoError(t, conn.Ping(ctx))
conn.DetachSub(ctx, props)
require.NoError(t, conn.Close(ctx))
}

func TestWebsocketClientConn(t *testing.T) {
tc := newTC()
s := createWServer(tc)
defer func() {
s.Close()
}()
ctx := mockContext.NewMockContext("1", "2")
props := map[string]any{
"datasource": "/ws",
"addr": s.URL[len("http://"):],
}
conn, err := createWebsocketServerConnection(ctx, props)
require.NoError(t, err)
require.NoError(t, conn.Ping(ctx))
conn.DetachSub(ctx, props)
require.NoError(t, conn.Close(ctx))
}

func newTC() *testcase {
ctx, cancel := context.WithCancel(context.Background())
return &testcase{
ctx: ctx,
cancel: cancel,
recvCh: make(chan []byte, 10),
sendCh: make(chan []byte, 10),
}
}

type testcase struct {
ctx context.Context
cancel context.CancelFunc
recvCh chan []byte
sendCh chan []byte
}

func createWServer(tc *testcase) *httptest.Server {
router := http.NewServeMux()
router.HandleFunc("/ws", tc.handler)
server := httptest.NewServer(router)
return server
}

var upgrader = websocket.Upgrader{
ReadBufferSize: 256,
WriteBufferSize: 256,
WriteBufferPool: &sync.Pool{},
}

func (tc *testcase) recvProcess(c *websocket.Conn) {
defer func() {
tc.cancel()
c.Close()
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
return
}
tc.recvCh <- message
}
}

func (tc *testcase) sendProcess(c *websocket.Conn) {
defer func() {
tc.cancel()
c.Close()
}()
for {
select {
case <-tc.ctx.Done():
return
case x := <-tc.sendCh:
err := c.WriteMessage(websocket.TextMessage, x)
if err != nil {
return
}
}
}
}

func (tc *testcase) handler(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
}
go tc.recvProcess(c)
go tc.sendProcess(c)
}
90 changes: 90 additions & 0 deletions internal/io/http/httpserver/websocket_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package httpserver

import (
"context"
"crypto/tls"
"fmt"
"net/url"
"sync"
"time"

"github.com/gorilla/websocket"

"github.com/lf-edge/ekuiper/contract/v2/api"
"github.com/lf-edge/ekuiper/v2/internal/io/memory/pubsub"
)

type WebsocketClient struct {
RecvTopic string
SendTopic string

addr string
path string
tlsConfig *tls.Config
conn *websocket.Conn
wg *sync.WaitGroup
cancel context.CancelFunc
}

func NewWebsocketClient(addr, path string, tlsConfig *tls.Config) *WebsocketClient {
return &WebsocketClient{
addr: addr,
path: path,
tlsConfig: tlsConfig,
wg: &sync.WaitGroup{},
}
}

func (c *WebsocketClient) Connect() error {
d := &websocket.Dialer{
HandshakeTimeout: 3 * time.Second,
TLSClientConfig: c.tlsConfig,
}
if len(c.addr) < 1 {
return fmt.Errorf("addr should be defined")
}
u := url.URL{Scheme: "ws", Host: c.addr, Path: c.path}
conn, _, err := d.Dial(u.String(), nil)
if err != nil {
return err
}
c.conn = conn
return nil
}

func (c *WebsocketClient) Run(ctx api.StreamContext) (string, string) {
c.RecvTopic = recvTopic(c.path, false)
c.SendTopic = sendTopic(c.path, false)
pubsub.CreatePub(c.RecvTopic)
c.handleProcess(ctx)
return c.RecvTopic, c.SendTopic
}

func (c *WebsocketClient) handleProcess(parCtx api.StreamContext) {
ctx, cancel := parCtx.WithCancel()
c.cancel = cancel
c.wg.Add(2)
go recvProcess(ctx, c.RecvTopic, c.conn, cancel, c.wg)
go sendProcess(ctx, c.SendTopic, "", c.conn, cancel, c.wg)
}

func (c *WebsocketClient) Close(ctx api.StreamContext) error {
pubsub.RemovePub(c.RecvTopic)
c.cancel()
c.wg.Wait()
return nil
}
53 changes: 53 additions & 0 deletions internal/io/http/httpserver/websocket_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package httpserver

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/lf-edge/ekuiper/v2/internal/io/memory/pubsub"
mockContext "github.com/lf-edge/ekuiper/v2/pkg/mock/context"
)

func TestWebsocketClient(t *testing.T) {
tc := newTC()
s := createWServer(tc)
defer func() {
s.Close()
}()
ctx := mockContext.NewMockContext("1", "2")
wc := NewWebsocketClient(s.URL[len("http://"):], "/ws", nil)
require.NoError(t, wc.Connect())
rt, st := wc.Run(ctx)
pubsub.CreatePub(st)
defer func() {
pubsub.RemovePub(st)
}()
// wait process start
time.Sleep(100 * time.Millisecond)
data := []byte("123")
pubsub.ProduceAny(ctx, st, data)
require.Equal(t, data, <-tc.recvCh)
ch := pubsub.CreateSub(rt, nil, "", 1024)
defer func() {
pubsub.CloseSourceConsumerChannel(rt, "")
}()
tc.sendCh <- data
require.Equal(t, data, <-ch)
require.NoError(t, wc.Close(ctx))
}
Loading

0 comments on commit 461b48b

Please sign in to comment.