diff --git a/pkg/interpreter/graph/graph_optimizer.go b/pkg/interpreter/graph/graph_optimizer.go new file mode 100644 index 00000000..7a2dea74 --- /dev/null +++ b/pkg/interpreter/graph/graph_optimizer.go @@ -0,0 +1,282 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package graph + +import ( + "fmt" + "strconv" + + proto "github.com/secretflow/scql/pkg/proto-gen/scql" + "github.com/secretflow/scql/pkg/util/stringutil" +) + +var ( + _ optimizeGraphRule = &optConstantCast{} +) + +type optimizeGraphRule interface { + optimize(*Graph) error +} + +type GraphOptimizer struct { + rules []optimizeGraphRule +} + +func NewGraphOptimizer() *GraphOptimizer { + rules := []optimizeGraphRule{&optConstantCast{}} + return &GraphOptimizer{rules: rules} +} + +func (g *GraphOptimizer) Optimize(graph *Graph) error { + for _, rule := range g.rules { + if err := rule.optimize(graph); err != nil { + return err + } + } + return nil +} + +type optConstantCast struct { +} + +func (rule optConstantCast) optimize(graph *Graph) error { + for _, pipeline := range graph.Pipelines { + for node := range pipeline.Nodes { + if node.OpType != "Constant" { + continue + } + + // find broadcast node + var broadCastNode *ExecutionNode + for edge := range node.Edges { + if edge.To.OpType == "BroadcastTo" { + broadCastNode = edge.To + } + } + if broadCastNode == nil { + continue + } + + // find cast node + var castNode *ExecutionNode + for edge := range broadCastNode.Edges { + if edge.To.OpType == "Cast" { + castNode = edge.To + } + } + if castNode == nil { + continue + } + + // check whether cast is valid + originType := node.Outputs["Out"][0].DType + castType := castNode.Outputs["Out"][0].DType + if !isValidCast(originType, castType) { + return fmt.Errorf("GraphOptimizer: invalid cast from %v to %v", originType, castType) + } + + // cast value + scalarAttr := node.Attributes["scalar"] + err := castValue(scalarAttr, originType, castType) + if err != nil { + return fmt.Errorf("GraphOptimizer: failed to cast value: %v", err) + } + + // change tensor type + if castType == proto.PrimitiveDataType_DATETIME || castType == proto.PrimitiveDataType_TIMESTAMP { + node.Outputs["Out"][0].DType = proto.PrimitiveDataType_INT64 + broadCastNode.Outputs["Out"][0].DType = proto.PrimitiveDataType_INT64 + } else { + node.Outputs["Out"][0].DType = castType + broadCastNode.Outputs["Out"][0].DType = castType + } + + // rearrange edges + for edge := range broadCastNode.Edges { + delete(broadCastNode.Edges, edge) + } + castNodeOutTs := castNode.Outputs["Out"][0] + for edge := range castNode.Edges { + edge.From = broadCastNode + edge.Value = broadCastNode.Outputs["Out"][0] + broadCastNode.Edges[edge] = true + + for _, input := range edge.To.Inputs { + for i := range input { + if input[i].ID == castNodeOutTs.ID { + input[i] = edge.Value + } + } + } + } + + // remove castNode + delete(pipeline.Nodes, castNode) + } + } + return nil +} + +func isValidCast(originType, castType proto.PrimitiveDataType) bool { + validCasts := map[proto.PrimitiveDataType]map[proto.PrimitiveDataType]bool{ + proto.PrimitiveDataType_STRING: { + proto.PrimitiveDataType_INT64: true, + proto.PrimitiveDataType_FLOAT64: true, + proto.PrimitiveDataType_DATETIME: true, + proto.PrimitiveDataType_TIMESTAMP: true, + }, + proto.PrimitiveDataType_INT32: { + proto.PrimitiveDataType_FLOAT32: true, + proto.PrimitiveDataType_FLOAT64: true, + proto.PrimitiveDataType_STRING: true, + }, + proto.PrimitiveDataType_INT64: { + proto.PrimitiveDataType_FLOAT32: true, + proto.PrimitiveDataType_FLOAT64: true, + proto.PrimitiveDataType_STRING: true, + }, + proto.PrimitiveDataType_FLOAT32: { + proto.PrimitiveDataType_INT64: true, + proto.PrimitiveDataType_STRING: true, + }, + proto.PrimitiveDataType_FLOAT64: { + proto.PrimitiveDataType_INT64: true, + proto.PrimitiveDataType_STRING: true, + }, + proto.PrimitiveDataType_BOOL: { + proto.PrimitiveDataType_INT32: true, + proto.PrimitiveDataType_STRING: true, + }, + proto.PrimitiveDataType_DATETIME: { + proto.PrimitiveDataType_STRING: true, + proto.PrimitiveDataType_INT64: true, + }, + proto.PrimitiveDataType_TIMESTAMP: { + proto.PrimitiveDataType_STRING: true, + proto.PrimitiveDataType_INT64: true, + }, + } + + if validCastMap, ok := validCasts[originType]; ok { + return validCastMap[castType] + } + + return originType == castType +} + +func castValue(scalarAttr *Attribute, originType, castType proto.PrimitiveDataType) error { + if scalarAttr == nil { + return fmt.Errorf("constant node doesn't have scalar attribute") + } + + originalValue := scalarAttr.GetAttrValue() + if originalValue == nil { + return fmt.Errorf("constant node doesn't have value") + } + + switch originType { + case proto.PrimitiveDataType_STRING: + strVal, ok := originalValue.(string) + if !ok { + return fmt.Errorf("expected string value") + } + if castType == proto.PrimitiveDataType_INT64 { + castValue, err := strconv.ParseInt(strVal, 10, 64) + if err != nil { + return err + } + scalarAttr.SetInt64(castValue) + return nil + } else if castType == proto.PrimitiveDataType_FLOAT64 { + castValue, err := strconv.ParseFloat(strVal, 64) + if err != nil { + return err + } + scalarAttr.SetDouble(castValue) + return nil + } else if castType == proto.PrimitiveDataType_DATETIME || castType == proto.PrimitiveDataType_TIMESTAMP { // return int64 value + if stringutil.IsDateString(strVal) { + tsMilli, err := stringutil.StringToUnixMilli(strVal) + if err != nil { + return fmt.Errorf("failed to parse date/time constant %q: %v", strVal, err) + } + scalarAttr.SetInt64(tsMilli) + return nil + } + return fmt.Errorf("date/time constant format should be 'YYYY-MM-DD hh:mm:ss'") + } + case proto.PrimitiveDataType_INT32, proto.PrimitiveDataType_INT64: + intVal, ok := originalValue.(int64) + if !ok { + return fmt.Errorf("expected int64 value") + } + if castType == proto.PrimitiveDataType_FLOAT64 { + scalarAttr.SetDouble(float64(intVal)) + return nil + } else if castType == proto.PrimitiveDataType_STRING { + scalarAttr.SetString(strconv.FormatInt(intVal, 10)) + return nil + } + case proto.PrimitiveDataType_FLOAT32, proto.PrimitiveDataType_FLOAT64: + floatVal, ok := originalValue.(float64) + if !ok { + return fmt.Errorf("expected float64 value") + } + if castType == proto.PrimitiveDataType_INT64 { + scalarAttr.SetInt64(int64(floatVal)) + return nil + } else if castType == proto.PrimitiveDataType_STRING { + scalarAttr.SetString(fmt.Sprintf("%f", floatVal)) + return nil + } + case proto.PrimitiveDataType_BOOL: + boolVal, ok := originalValue.(bool) + if !ok { + return fmt.Errorf("expected bool value") + } + if castType == proto.PrimitiveDataType_INT32 { + if boolVal { + scalarAttr.SetInt(1) + return nil + } else { + scalarAttr.SetInt(0) + return nil + } + } else if castType == proto.PrimitiveDataType_STRING { + scalarAttr.SetString(strconv.FormatBool(boolVal)) + return nil + } + case proto.PrimitiveDataType_DATETIME, proto.PrimitiveDataType_TIMESTAMP: + strVal, ok := originalValue.(string) + if !ok { + return fmt.Errorf("expected datetime string value") + } + if castType == proto.PrimitiveDataType_STRING { + scalarAttr.SetString(strVal) + return nil + } else if castType == proto.PrimitiveDataType_INT64 { + if stringutil.IsDateString(strVal) { + tsMilli, err := stringutil.StringToUnixMilli(strVal) + if err != nil { + return fmt.Errorf("failed to parse date/time constant %q: %v", strVal, err) + } + scalarAttr.SetInt64(tsMilli) + return nil + } + return fmt.Errorf("date/time constant format should be 'YYYY-MM-DD hh:mm:ss'") + } + } + return fmt.Errorf("invalid cast from %v to %v", originType, castType) +} diff --git a/pkg/interpreter/graph/graph_optimizer_test.go b/pkg/interpreter/graph/graph_optimizer_test.go new file mode 100644 index 00000000..f5c47162 --- /dev/null +++ b/pkg/interpreter/graph/graph_optimizer_test.go @@ -0,0 +1,111 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package graph + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/secretflow/scql/pkg/interpreter/ccl" + proto "github.com/secretflow/scql/pkg/proto-gen/scql" + "github.com/secretflow/scql/pkg/types" +) + +func testCastConversion(t *testing.T, originValue interface{}, originType proto.PrimitiveDataType, castType proto.PrimitiveDataType, expectedValue interface{}) { + r := require.New(t) + + participants := []*Participant{ + {PartyCode: "party1", Endpoints: []string{"party1.net"}, Token: "party1_credential"}, + {PartyCode: "party2", Endpoints: []string{"party2.net"}, Token: "party2_credential"}, + } + partyInfo := NewPartyInfo(participants) + e1 := NewGraphBuilder(partyInfo, false) + + t1 := e1.AddTensor("alice.date") + t1.SetStatus(proto.TensorStatus_TENSORSTATUS_PRIVATE) + t1.DType = castType + e1.AddRunSQLNode("RunSQLOp1", []*Tensor{t1}, "select f1 from alice.t1", []string{"alice.t1"}, "party1") + t1.CC = ccl.CreateAllPlainCCL([]string{"party1", "party2"}) + + // create constant node according to originValue + var constantTensor *Tensor + switch originType { + case proto.PrimitiveDataType_STRING: + strDatum := types.NewStringDatum(originValue.(string)) + t2, err := e1.AddConstantNode("make_constant", &strDatum, []string{"party1"}) + r.NoError(err) + r.NotNil(t2) + constantTensor = t2 + case proto.PrimitiveDataType_INT64: + intDatum := types.NewIntDatum(originValue.(int64)) + t2, err := e1.AddConstantNode("make_constant", &intDatum, []string{"party1"}) + r.NoError(err) + r.NotNil(t2) + constantTensor = t2 + case proto.PrimitiveDataType_FLOAT64: + floatDatum := types.NewFloat64Datum(originValue.(float64)) + t2, err := e1.AddConstantNode("make_constant", &floatDatum, []string{"party1"}) + r.NoError(err) + r.NotNil(t2) + constantTensor = t2 + // TODO: add more test cases + default: + t.Fatalf("Unsupported origin type: %v", originType) + } + + t3s, err := e1.AddBroadcastToNode("broadcast", []*Tensor{constantTensor}, t1) + r.NoError(err) + r.NotNil(t3s) + t3 := t3s[0] + r.NotNil(t3) + + t4, err := e1.AddCastNode("cast", t1.DType, t3, []string{"party1"}) + r.NoError(err) + r.NotNil(t4) + + graph := e1.Build() + pipelineNodes, err := graph.TopologicalSort() + r.NoError(err) + r.NotNil(pipelineNodes) + + r.Equal(4, len(pipelineNodes[0])) + + graphOptimizer := NewGraphOptimizer() + err = graphOptimizer.Optimize(graph) + r.NoError(err) + pipelineNodes, err = graph.TopologicalSort() + r.NoError(err) + r.NotNil(pipelineNodes) + + // cast node should be removed + r.Equal(3, len(pipelineNodes[0])) + for _, node := range pipelineNodes[0] { + if node.Name == "make_constant" { + r.Equal(expectedValue, node.Attributes["scalar"].GetAttrValue()) + } + } +} + +func TestOptConstantCast(t *testing.T) { + // Test for string to datetime conversion + testCastConversion(t, "2025-05-08", proto.PrimitiveDataType_STRING, proto.PrimitiveDataType_DATETIME, int64(1746662400000)) + + // Test for int to float conversion + testCastConversion(t, int64(12), proto.PrimitiveDataType_INT64, proto.PrimitiveDataType_FLOAT64, float64(12)) + + // Test for float to int conversion + testCastConversion(t, float64(12.2), proto.PrimitiveDataType_FLOAT64, proto.PrimitiveDataType_INT64, int64(12)) +} diff --git a/pkg/interpreter/interpreter.go b/pkg/interpreter/interpreter.go index d24e1e88..967cc678 100644 --- a/pkg/interpreter/interpreter.go +++ b/pkg/interpreter/interpreter.go @@ -109,6 +109,11 @@ func (*Interpreter) compileCore(enginesInfo *graph.EnginesInfo, req *pb.CompileQ return nil, err } + graphOptimizer := graph.NewGraphOptimizer() + if err := graphOptimizer.Optimize(ep); err != nil { + return nil, err + } + graphChecker := graph.NewGraphChecker() if err := graphChecker.Check(ep); err != nil { return nil, err diff --git a/pkg/interpreter/interpreter_test.go b/pkg/interpreter/interpreter_test.go index e722c0b4..9ecddbce 100644 --- a/pkg/interpreter/interpreter_test.go +++ b/pkg/interpreter/interpreter_test.go @@ -123,6 +123,20 @@ var commonSecurityConf = &proto.SecurityConfig{ TableName: "ta", ColumnName: "income", }, + { + PartyCode: "alice", + Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT, + DatabaseName: "", + TableName: "ta", + ColumnName: "date", + }, + { + PartyCode: "bob", + Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT, + DatabaseName: "", + TableName: "ta", + ColumnName: "date", + }, }, } @@ -147,6 +161,10 @@ var commonCatalog = &proto.Catalog{ Name: "age", Type: "int", }, + { + Name: "date", + Type: "datetime", + }, }, IsView: false, RefTable: "alice.user_credit", @@ -182,6 +200,22 @@ var commonCatalog = &proto.Catalog{ } var testCases = []compileTestCase{ + { + req: &proto.CompileQueryRequest{ + Query: "SELECT ta.ID, ta.date FROM ta right join tb on ta.ID = tb.ID where ta.date > '2025-04-23 12:25:42'", + DbName: "", + Issuer: &proto.PartyId{ + Code: "alice", + }, + IssuerAsParticipant: true, + SecurityConf: commonSecurityConf, + Catalog: commonCatalog, + CompileOpts: &proto.CompileOptions{SecurityCompromise: &proto.SecurityCompromiseConfig{GroupByThreshold: 4}}, + CreatedAt: timestamppb.New(time.Now()), + }, + ok: true, + workPartyNum: 2, + }, { req: &proto.CompileQueryRequest{ Query: "SELECT ta.ID, ta.income FROM ta WHERE ta.income > ALL(SELECT ta.income AS joined_income FROM ta INNER JOIN tb ON ta.ID = tb.ID)", diff --git a/pkg/interpreter/translator/translator.go b/pkg/interpreter/translator/translator.go index 2368e02f..4c3c76af 100644 --- a/pkg/interpreter/translator/translator.go +++ b/pkg/interpreter/translator/translator.go @@ -1192,10 +1192,47 @@ func inferBinaryOpOutputType(opType string, left, right *graph.Tensor) (proto.Pr return proto.PrimitiveDataType_PrimitiveDataType_UNDEFINED, fmt.Errorf("cannot infer output type for opType=%s", opType) } +var constTensorNeedCastOp = map[string]bool{ + operator.OpNameLess: true, + operator.OpNameLessEqual: true, + operator.OpNameGreater: true, + operator.OpNameGreaterEqual: true, + operator.OpNameEqual: true, + operator.OpNameAdd: true, + operator.OpNameMinus: true, +} + func (t *translator) addBinaryNode(opName string, opType string, left *graph.Tensor, right *graph.Tensor) (*graph.Tensor, error) { if ok := slices.Contains(operator.BinaryOps, opType); !ok { return nil, fmt.Errorf("failed to check op type AddBinaryNode: invalid opType %v", opType) } + + // only support string to time currently + if _, ok := constTensorNeedCastOp[opType]; ok { + var err error + const ConstantDataName = "constant_data" + + castIfNeeded := func(constTensor, otherTensor *graph.Tensor, targetType proto.PrimitiveDataType) (*graph.Tensor, error) { + if (constTensor.Name == ConstantDataName && constTensor.DType == proto.PrimitiveDataType_STRING) && (otherTensor.DType == proto.PrimitiveDataType_DATETIME || otherTensor.DType == proto.PrimitiveDataType_TIMESTAMP) { + inTensorPartyCodes := t.extractPartyCodeFromTensor(constTensor) + return t.ep.AddCastNode("cast", targetType, constTensor, inTensorPartyCodes) + } + return constTensor, nil + } + + if left.Name == ConstantDataName { + left, err = castIfNeeded(left, right, right.DType) + if err != nil { + return nil, err + } + } else if right.Name == ConstantDataName { + right, err = castIfNeeded(right, left, left.DType) + if err != nil { + return nil, err + } + } + } + if err := graph.CheckBinaryOpInputType(opType, left, right); err != nil { return nil, fmt.Errorf("addBinaryNode: %w", err) } diff --git a/pkg/interpreter/translator/translator_ccl_input_for_test.go b/pkg/interpreter/translator/translator_ccl_input_for_test.go index 0986fa3f..3f947237 100644 --- a/pkg/interpreter/translator/translator_ccl_input_for_test.go +++ b/pkg/interpreter/translator/translator_ccl_input_for_test.go @@ -74,6 +74,32 @@ var translateNumericTestCases = []sPair{ } var translateWithCCLTestCases = []sPair{ + {`SELECT ta.join_int_0, ta.plain_datetime_0 from alice.tbl_0 as ta right join bob.tbl_0 as tb on ta.join_int_0 = tb.join_int_0 where ta.plain_datetime_0 > '2025-04-23 12:25:42'`, `digraph G { +0 [label="runsql:{in:[],out:[Out:{t_0,t_1,},],attr:[sql:select ta.join_int_0,ta.plain_datetime_0 from alice.tbl_0 as ta;,table_refs:[alice.tbl_0],],party:[alice,]}"] +1 [label="runsql:{in:[],out:[Out:{t_2,},],attr:[sql:select tb.join_int_0 from bob.tbl_0 as tb;,table_refs:[bob.tbl_0],],party:[bob,]}"] +2 [label="join:{in:[Left:{t_0,},Right:{t_2,},],out:[LeftJoinIndex:{t_3,},RightJoinIndex:{t_4,},],attr:[input_party_codes:[alice bob],join_type:2,psi_algorithm:0,],party:[alice,bob,]}"] +3 [label="filter_by_index:{in:[Data:{t_0,t_1,},RowsIndexFilter:{t_3,},],out:[Out:{t_5,t_6,},],attr:[],party:[alice,]}"] +4 [label="make_constant:{in:[],out:[Out:{t_7,},],attr:[scalar:1745411142000,],party:[alice,bob,carol,]}"] +5 [label="broadcast:{in:[In:{t_7,},ShapeRefTensor:{t_6,},],out:[Out:{t_8,},],attr:[],party:[alice,]}"] +7 [label="Greater:{in:[Left:{t_6,},Right:{t_8,},],out:[Out:{t_10,},],attr:[],party:[alice,]}"] +8 [label="apply_filter:{in:[Filter:{t_10,},In:{t_5,t_6,},],out:[Out:{t_11,t_12,},],attr:[],party:[alice,]}"] +9 [label="publish:{in:[In:{t_11,t_12,},],out:[Out:{t_13,t_14,},],attr:[],party:[alice,]}"] +0 -> 2 [label = "t_0:{join_int_0:PRIVATE:INT64}"] +0 -> 3 [label = "t_0:{join_int_0:PRIVATE:INT64}"] +0 -> 3 [label = "t_1:{plain_datetime_0:PRIVATE:DATETIME}"] +1 -> 2 [label = "t_2:{join_int_0:PRIVATE:INT64}"] +2 -> 3 [label = "t_3:{join_int_0:PRIVATE:INT64}"] +3 -> 5 [label = "t_6:{plain_datetime_0:PRIVATE:DATETIME}"] +3 -> 7 [label = "t_6:{plain_datetime_0:PRIVATE:DATETIME}"] +3 -> 8 [label = "t_5:{join_int_0:PRIVATE:INT64}"] +3 -> 8 [label = "t_6:{plain_datetime_0:PRIVATE:DATETIME}"] +4 -> 5 [label = "t_7:{constant_data:PUBLIC:INT64}"] +5 -> 7 [label = "t_8:{constant_data:PRIVATE:INT64}"] +7 -> 8 [label = "t_10:{Greater_out:PRIVATE:BOOL}"] +8 -> 9 [label = "t_11:{join_int_0:PRIVATE:INT64}"] +8 -> 9 [label = "t_12:{plain_datetime_0:PRIVATE:DATETIME}"] +}`, ``, testConf{groupThreshold: 0, batched: false}, + }, {`SELECT PERCENTILE_DISC(u.int_0, 0.3) AS _30percent FROM (SELECT ta.plain_int_0 as int_0 FROM alice.tbl_0 AS ta UNION ALL SELECT tb.plain_int_0 as int_0 FROM bob.tbl_0 AS tb) as u`, `digraph G { 0 [label="runsql:{in:[],out:[Out:{t_0,},],attr:[sql:select ta.plain_int_0 from alice.tbl_0 as ta;,table_refs:[alice.tbl_0],],party:[alice,]}"] 1 [label="runsql:{in:[],out:[Out:{t_1,},],attr:[sql:select tb.plain_int_0 from bob.tbl_0 as tb;,table_refs:[bob.tbl_0],],party:[bob,]}"] diff --git a/pkg/interpreter/translator/translator_ccl_test.go b/pkg/interpreter/translator/translator_ccl_test.go index 068d03a8..532e0fe6 100644 --- a/pkg/interpreter/translator/translator_ccl_test.go +++ b/pkg/interpreter/translator/translator_ccl_test.go @@ -21,6 +21,7 @@ import ( . "github.com/pingcap/check" + "github.com/secretflow/scql/pkg/interpreter/graph" "github.com/secretflow/scql/pkg/planner/core" "github.com/secretflow/scql/pkg/proto-gen/scql" "github.com/secretflow/scql/pkg/util/mock" @@ -60,6 +61,11 @@ func (s *testTranslatorSuite) TestTranslateWithCCL(c *C) { c.Assert(err, IsNil) ep, err := t.Translate(lp) c.Assert(err, IsNil, Commentf("for %s", sql)) + + graphOptimizer := graph.NewGraphOptimizer() + err = graphOptimizer.Optimize(ep) + c.Assert(err, IsNil) + graphStr := ep.DumpGraphviz() // if you want to copy the graph created by DumpGraphviz, uncomment this line c.Log(graphStr) diff --git a/pkg/util/stringutil/string_util.go b/pkg/util/stringutil/string_util.go index 95488c6a..88c3d898 100644 --- a/pkg/util/stringutil/string_util.go +++ b/pkg/util/stringutil/string_util.go @@ -334,3 +334,29 @@ func RandString(n int) string { } return string(b) } + +func IsDateString(s string) bool { + matched, _ := regexp.MatchString(`^\d{4}-\d{2}-\d{2}(\s\d{2}:\d{2}:\d{2})?$`, s) + return matched +} + +func StringToUnixMilli(s string) (int64, error) { + var t time.Time + var err error + + if len(s) == len("2006-01-02 15:04:05") { + t, err = time.Parse("2006-01-02 15:04:05", s) + if err == nil { + return t.UnixMilli(), nil + } + } + + if len(s) == len("2006-01-02") { + t, err = time.Parse("2006-01-02", s) + if err == nil { + return t.UnixMilli(), nil + } + } + + return 0, fmt.Errorf("StringToUnixMilli: unsupported date/time format: %s", s) +} diff --git a/pkg/util/stringutil/string_util_test.go b/pkg/util/stringutil/string_util_test.go index c0627324..d20b9b7e 100644 --- a/pkg/util/stringutil/string_util_test.go +++ b/pkg/util/stringutil/string_util_test.go @@ -228,3 +228,48 @@ func TestRemoveSensitiveInfo(t *testing.T) { r.Equal(expected, actual) } } + +func TestIsDateString(t *testing.T) { + r := require.New(t) + type pair struct { + in string + expect bool + } + testCases := []pair{ + {"2024-05-01", true}, + {"2024-5-1", false}, + {"2024-05-01 11:12:13", true}, + {"2024-05-01T11:12:13Z", false}, + {"05/01/2024", false}, + {"", false}, + {"2024-05-01 1:12:13", false}, + } + for _, ca := range testCases { + r.Equal(ca.expect, IsDateString(ca.in), ca.in) + } +} + +func TestStringToUnixMilli(t *testing.T) { + r := require.New(t) + type pair struct { + in string + expectUnix int64 + expectErr bool + } + testCases := []pair{ + {"2024-05-01", 1714521600000, false}, + {"2024-05-01 11:12:13", 1714561933000, false}, + {"2024-05-01T11:12:13Z", 0, true}, + {"", 0, true}, + {"2024-05-01 1:12:13", 0, true}, + } + for _, ca := range testCases { + unixMilli, err := StringToUnixMilli(ca.in) + if ca.expectErr { + r.Error(err, ca.in) + } else { + r.NoError(err, ca.in) + r.Equal(ca.expectUnix, unixMilli, ca.in) + } + } +}