From d418285e12046e6359ff36e7c23452f9b5f41304 Mon Sep 17 00:00:00 2001 From: HarrisChu <1726587+HarrisChu@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:37:24 +0800 Subject: [PATCH] support to get result response byte size --- connection.go | 12 ++++++++++-- result_set.go | 35 +++++++++++++++++++++++++++++++++-- result_set_test.go | 28 +++++++++++++++++++++++++--- session.go | 3 ++- session_pool.go | 3 ++- 5 files changed, 72 insertions(+), 9 deletions(-) diff --git a/connection.go b/connection.go index fdfe188e..797536cc 100644 --- a/connection.go +++ b/connection.go @@ -58,6 +58,8 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo transport thrift.Transport pf thrift.ProtocolFactory ) + pf = cn.getProtocolFactory() + if useHTTP2 { if sslConfig != nil { transport, err = thrift.NewHTTPPostClientWithOptions("https://"+newAdd, thrift.HTTPClientOptions{ @@ -86,7 +88,6 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo if err != nil { return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error()) } - pf = thrift.NewBinaryProtocolFactoryDefault() if httpHeader != nil { client, ok := transport.(*thrift.HTTPClient) if !ok { @@ -118,7 +119,6 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo // Set transport bufferedTranFactory := thrift.NewBufferedTransportFactory(bufferSize) transport = thrift.NewHeaderTransport(bufferedTranFactory.GetTransport(sock)) - pf = thrift.NewHeaderProtocolFactory() } cn.graph = graph.NewGraphServiceClientFactory(transport, pf) @@ -224,3 +224,11 @@ func (cn *connection) release() { func (cn *connection) close() { cn.graph.Close() } + +func (cn *connection) getProtocolFactory() thrift.ProtocolFactory { + if cn.useHTTP2 { + return thrift.NewBinaryProtocolFactoryDefault() + } else { + return thrift.NewHeaderProtocolFactory() + } +} diff --git a/result_set.go b/result_set.go index ab5cd676..ff5db3df 100644 --- a/result_set.go +++ b/result_set.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift" "github.com/vesoft-inc/nebula-go/v3/nebula" "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) @@ -26,6 +27,7 @@ type ResultSet struct { columnNames []string colNameIndexMap map[string]int timezoneInfo timezoneInfo + factory thrift.ProtocolFactory } type Record struct { @@ -96,18 +98,22 @@ const ( func GenResultSet(resp *graph.ExecutionResponse) (*ResultSet, error) { var defaultTimezone timezoneInfo = timezoneInfo{0, []byte("UTC")} - return genResultSet(resp, defaultTimezone) + return genResultSet(resp, defaultTimezone, nil) } -func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo) (*ResultSet, error) { +func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo, factory thrift.ProtocolFactory) (*ResultSet, error) { var colNames []string var colNameIndexMap = make(map[string]int) + if factory == nil { + factory = thrift.NewHeaderProtocolFactory() + } if resp.Data == nil { // if resp.Data != nil then resp.Data.row and resp.Data.colNames wont be nil return &ResultSet{ resp: resp, columnNames: colNames, colNameIndexMap: colNameIndexMap, + factory: factory, }, nil } for i, name := range resp.Data.ColumnNames { @@ -120,6 +126,7 @@ func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo) (*Re columnNames: colNames, colNameIndexMap: colNameIndexMap, timezoneInfo: timezoneInfo, + factory: factory, }, nil } @@ -1381,3 +1388,27 @@ func (res ResultSet) MakePlanByTck() [][]interface{} { } return rows } + +func (res *ResultSet) GetByteSize() int { + if res.resp == nil { + return 0 + } + if res.factory == nil { + return 0 + } + var pf thrift.ProtocolFactory + buf := thrift.NewMemoryBuffer() + switch res.factory.(type) { + case *thrift.BinaryProtocolFactory: + pf = res.factory + case *thrift.HeaderProtocolFactory: + pf = thrift.NewCompactProtocolFactory() + } + protocal := pf.GetProtocol(buf) + if err := res.resp.Write(protocal); err != nil { + return 0 + } + bs := make([]byte, buf.Len()) + copy(bs, buf.Bytes()) + return len(bs) +} diff --git a/result_set_test.go b/result_set_test.go index ed417352..ffdb7587 100644 --- a/result_set_test.go +++ b/result_set_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift" "github.com/vesoft-inc/nebula-go/v3/nebula" "github.com/vesoft-inc/nebula-go/v3/nebula/graph" ) @@ -550,7 +551,7 @@ func TestResultSet(t *testing.T) { nil, nil, nil} - resultSetWithNil, err := genResultSet(respWithNil, testTimezone) + resultSetWithNil, err := genResultSet(respWithNil, testTimezone, nil) if err != nil { t.Error(err) } @@ -595,7 +596,7 @@ func TestResultSet(t *testing.T) { &planDesc, []byte("test_comment")} - resultSet, err := genResultSet(resp, testTimezone) + resultSet, err := genResultSet(resp, testTimezone, nil) if err != nil { t.Error(err) } @@ -673,7 +674,7 @@ func TestAsStringTable(t *testing.T) { []byte("test"), graph.NewPlanDescription(), []byte("test_comment")} - resultSet, err := genResultSet(resp, testTimezone) + resultSet, err := genResultSet(resp, testTimezone, nil) if err != nil { t.Error(err) } @@ -878,3 +879,24 @@ func setIVal(ival int) *nebula.Value { value.IVal = newNum return value } + +func TestGetByteSize(t *testing.T) { + resp := &graph.ExecutionResponse{ + nebula.ErrorCode_SUCCEEDED, + 1000, + getDateset(), + []byte("test_space"), + []byte("test"), + graph.NewPlanDescription(), + []byte("test_comment")} + resultSet, err := genResultSet(resp, testTimezone, thrift.NewBinaryProtocolFactoryDefault()) + if err != nil { + t.Error(err) + } + assert.Equal(t, 2899, resultSet.GetByteSize()) + resultSet, err = genResultSet(resp, testTimezone, thrift.NewHeaderProtocolFactory()) + if err != nil { + t.Error(err) + } + assert.Equal(t, 1297, resultSet.GetByteSize()) +} diff --git a/session.go b/session.go index 6a7da2af..45d65b8b 100644 --- a/session.go +++ b/session.go @@ -69,7 +69,8 @@ func (session *Session) ExecuteWithParameter(stmt string, params map[string]inte if err != nil { return nil, err } - resSet, err := genResultSet(resp, session.timezoneInfo) + pf := session.connection.getProtocolFactory() + resSet, err := genResultSet(resp, session.timezoneInfo, pf) if err != nil { return nil, err } diff --git a/session_pool.go b/session_pool.go index 877dcab6..d15e6618 100644 --- a/session_pool.go +++ b/session_pool.go @@ -576,7 +576,8 @@ func (session *pureSession) executeWithParameter(stmt string, params map[string] if err != nil { return nil, err } - rs, err := genResultSet(resp, session.timezoneInfo) + pf := session.connection.getProtocolFactory() + rs, err := genResultSet(resp, session.timezoneInfo, pf) if err != nil { return nil, err }