From 2d6ffd4102b63989587cf1f2a0327fc3c2b28750 Mon Sep 17 00:00:00 2001 From: Sanyam Singhal Date: Tue, 4 Mar 2025 11:22:55 +0530 Subject: [PATCH] Tech Debt: Merge MergeConstraints and ApplyShardedRecommendations transformation on exported schema (#2366) - now the parsing of sql file happens once and then one by one all transformations are applied then Deparse the final schema file * Added unit tests * Shifted tests from exportSchema_test.go to transformer_test.go --- yb-voyager/cmd/common.go | 4 + yb-voyager/cmd/constants.go | 1 - yb-voyager/cmd/exportSchema.go | 377 ++++++------------ yb-voyager/cmd/exportSchema_test.go | 104 ----- .../adaptive_parallelism_test.go | 10 + yb-voyager/src/constants/constants.go | 3 +- .../src/query/queryparser/query_parser.go | 4 +- .../src/query/sqltransformer/helpers.go | 67 ++++ .../src/query/sqltransformer/transformer.go | 41 +- .../query/sqltransformer/transformer_test.go | 79 +++- 10 files changed, 314 insertions(+), 376 deletions(-) delete mode 100644 yb-voyager/cmd/exportSchema_test.go create mode 100644 yb-voyager/src/query/sqltransformer/helpers.go diff --git a/yb-voyager/cmd/common.go b/yb-voyager/cmd/common.go index 9f75d5eaba..f4fa498b63 100644 --- a/yb-voyager/cmd/common.go +++ b/yb-voyager/cmd/common.go @@ -1210,6 +1210,10 @@ type TargetSizingRecommendations struct { //====== AssesmentReport struct methods ======// func ParseJSONToAssessmentReport(reportPath string) (*AssessmentReport, error) { + if !utils.FileOrFolderExists(reportPath) { + return nil, fmt.Errorf("report file %q does not exist", reportPath) + } + var report AssessmentReport err := jsonfile.NewJsonFile[AssessmentReport](reportPath).Load(&report) if err != nil { diff --git a/yb-voyager/cmd/constants.go b/yb-voyager/cmd/constants.go index 8b36c85244..f57dcb82b3 100644 --- a/yb-voyager/cmd/constants.go +++ b/yb-voyager/cmd/constants.go @@ -61,7 +61,6 @@ const ( ROW_UPDATE_STATUS_NOT_STARTED = 0 ROW_UPDATE_STATUS_IN_PROGRESS = 1 ROW_UPDATE_STATUS_COMPLETED 
= 3 - COLOCATION_CLAUSE = "colocation" //phase names used in call-home payload ANALYZE_PHASE = "analyze-schema" EXPORT_SCHEMA_PHASE = "export-schema" diff --git a/yb-voyager/cmd/exportSchema.go b/yb-voyager/cmd/exportSchema.go index 87d507ce9b..138fabb262 100644 --- a/yb-voyager/cmd/exportSchema.go +++ b/yb-voyager/cmd/exportSchema.go @@ -55,6 +55,8 @@ var exportSchemaCmd = &cobra.Command{ if err != nil { utils.ErrExit("Error validating export schema flags: %s", err.Error()) } + + validateAssessmentReportPathFlag() markFlagsRequired(cmd) }, @@ -170,16 +172,7 @@ func exportSchema() error { return fmt.Errorf("failed to update indexes info metadata db: %w", err) } - err = applyMigrationAssessmentRecommendations() - if err != nil { - return fmt.Errorf("failed to apply migration assessment recommendation to the schema files: %w", err) - } - - // continue after logging the error; since this transformation is only for performance improvement - err = applyMergeConstraintsTransformations() - if err != nil { - log.Warnf("failed to apply merge constraints transformation to the schema files: %v", err) - } + applySchemaTransformations() utils.PrintAndLog("\nExported schema files created under directory: %s\n\n", filepath.Join(exportDir, "schema")) @@ -246,6 +239,19 @@ func init() { "path to the generated assessment report file(JSON format) to be used for applying recommendation to exported schema") } +func validateAssessmentReportPathFlag() { + if assessmentReportPath == "" { + return + } + + if !utils.FileOrFolderExists(assessmentReportPath) { + utils.ErrExit("assessment report file doesn't exists at path provided in --assessment-report-path flag: %q", assessmentReportPath) + } + if !strings.HasSuffix(assessmentReportPath, ".json") { + utils.ErrExit("assessment report file should be in JSON format, path provided in --assessment-report-path flag: %q", assessmentReportPath) + } +} + func schemaIsExported() bool { if !metaDBIsCreated(exportDir) { return false @@ -295,305 +301,160 
@@ func updateIndexesInfoInMetaDB() error { return nil } -func applyMigrationAssessmentRecommendations() error { - if skipRecommendations { - log.Infof("not apply recommendations due to flag --skip-recommendations=true") - return nil - } else if source.DBType == MYSQL { - return nil - } +/* +applySchemaTransformations applies the following transformations to the exported schema one by one +and saves the transformed schema in the same file. - // TODO: copy the reports to "export-dir/assessment/reports" for further usage - assessmentReportPath := lo.Ternary(assessmentReportPath != "", assessmentReportPath, - filepath.Join(exportDir, "assessment", "reports", fmt.Sprintf("%s.json", ASSESSMENT_FILE_NAME))) - log.Infof("using assessmentReportPath: %s", assessmentReportPath) - if !utils.FileOrFolderExists(assessmentReportPath) { - utils.PrintAndLog("migration assessment report file doesn't exists at %q, skipping apply recommendations step...", assessmentReportPath) - return nil - } +In case of any failure in applying any transformation, it logs the error, keep the original file and continues with the next transformation. +*/ +func applySchemaTransformations() { + // 1. 
Transform table.sql + { + tableFilePath := utils.GetObjectFilePath(schemaDir, TABLE) + transformations := []func([]*pg_query.RawStmt) ([]*pg_query.RawStmt, error){ + applyShardedTableTransformation, // transform #1 + applyMergeConstraintsTransformation, // transform #2 + } - log.Infof("parsing assessment report json file for applying recommendations") - report, err := ParseJSONToAssessmentReport(assessmentReportPath) - if err != nil { - return fmt.Errorf("failed to parse json report file %q: %w", assessmentReportPath, err) + err := transformSchemaFile(tableFilePath, transformations, "table") + if err != nil { + log.Warnf("Error transforming %q: %v", tableFilePath, err) + } } - shardedTables, err := report.GetShardedTablesRecommendation() - if err != nil { - return fmt.Errorf("failed to fetch sharded tables recommendation: %w", err) - } else { - err := applyShardedTablesRecommendation(shardedTables, TABLE) - if err != nil { - return fmt.Errorf("failed to apply colocated vs sharded table recommendation: %w", err) + // 2. 
Transform mview.sql + { + mviewFilePath := utils.GetObjectFilePath(schemaDir, MVIEW) + transformations := []func([]*pg_query.RawStmt) ([]*pg_query.RawStmt, error){ + applyShardedTableTransformation, // only transformation for mview } - err = applyShardedTablesRecommendation(shardedTables, MVIEW) + + err := transformSchemaFile(mviewFilePath, transformations, "mview") if err != nil { - return fmt.Errorf("failed to apply colocated vs sharded table recommendation: %w", err) + log.Warnf("Error transforming %q: %v", mviewFilePath, err) } } - - assessmentRecommendationsApplied = true - SetAssessmentRecommendationsApplied() - - utils.PrintAndLog("Applied assessment recommendations.") - return nil } -// TODO: merge this function with applying sharded/colocated recommendation -func applyMergeConstraintsTransformations() error { - if utils.GetEnvAsBool("YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS", false) { - log.Infof("skipping applying merge constraints transformation due to env var YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS=true") - return nil - } - - utils.PrintAndLog("Applying merge constraints transformation to the exported schema") - transformer := sqltransformer.NewTransformer() - - fileName := utils.GetObjectFilePath(schemaDir, TABLE) - if !utils.FileOrFolderExists(fileName) { // there are no tables in exported schema - log.Infof("table.sql file doesn't exists, skipping applying merge constraints transformation") +// transformSchemaFile applies a sequence of transformations to the given schema file +// and writes the transformed result back. If the file doesn't exist, logs a message and returns nil. 
+func transformSchemaFile(filePath string, transformations []func(raw []*pg_query.RawStmt) ([]*pg_query.RawStmt, error), objectType string) error { + if !utils.FileOrFolderExists(filePath) { + log.Infof("%q file doesn't exist, skipping transformations for %s object type", filePath, objectType) return nil } - rawStmts, err := queryparser.ParseSqlFile(fileName) - if err != nil { - return fmt.Errorf("failed to parse table.sql file: %w", err) - } - - transformedRawStmts, err := transformer.MergeConstraints(rawStmts.Stmts) + rawStmts, err := queryparser.ParseSqlFile(filePath) if err != nil { - return fmt.Errorf("failed to merge constraints: %w", err) + return fmt.Errorf("failed to parse sql statements from %s object type in schema file %q: %w", objectType, filePath, err) } - sqlStmts, err := queryparser.DeparseRawStmts(transformedRawStmts) + beforeSqlStmts, err := queryparser.DeparseRawStmts(rawStmts) if err != nil { - return fmt.Errorf("failed to deparse transformed raw stmts: %w", err) + return fmt.Errorf("failed to deparse raw stmts for %s object type in schema file %q: %w", objectType, filePath, err) } - fileContent := strings.Join(sqlStmts, "\n\n") - - // rename the old file to table_before_merge_constraints.sql - // replace filepath base with new name - renamedFileName := filepath.Join(filepath.Dir(fileName), "table_before_merge_constraints.sql") - err = os.Rename(fileName, renamedFileName) - if err != nil { - return fmt.Errorf("failed to rename table.sql file to table_before_merge_constraints.sql: %w", err) + transformedStmts := rawStmts + // Apply transformations in order + for _, transformFn := range transformations { + newStmts, err := transformFn(transformedStmts) + if err != nil { + // Log and continue using the unmodified statements slice for subsequent transformations in case of error + log.Warnf("failed to apply transformation function %T in schema file %q: %v", transformFn, filePath, err) + continue + } + transformedStmts = newStmts } - err = 
os.WriteFile(fileName, []byte(fileContent), 0644) + // Deparse + sqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) if err != nil { - return fmt.Errorf("failed to write transformed table.sql file: %w", err) + return fmt.Errorf("failed to deparse transformed raw stmts for %s object type in schema file %q: %w", objectType, filePath, err) } - return nil -} - -func applyShardedTablesRecommendation(shardedTables []string, objType string) error { - if shardedTables == nil { - log.Info("list of sharded tables is null hence all the tables are recommended as colocated") - return nil - } - - filePath := utils.GetObjectFilePath(schemaDir, objType) - if !utils.FileOrFolderExists(filePath) { - // Report if the file does not exist for tables. No need to report it for mviews - if objType == TABLE { - utils.PrintAndLog("Required schema file %s does not exists, "+ - "returning without applying colocated/sharded tables recommendation", filePath) - } + // Below Check for if transformations changed anything is WRONG + // here we are dealing with pointers - *pg_query.RawStmt so underlying elements of slices point to same memory + // if slices.Equal(originalStmts, transformedStmts) { + // log.Infof("no change in the schema for object type %s after applying all transformations", objectType) + // return nil + // } + if slices.Equal(beforeSqlStmts, sqlStmts) { + log.Infof("no change in the schema for object type %s after applying all transformations", objectType) return nil } - log.Infof("applying colocated vs sharded tables recommendation") - var newSQLFileContent strings.Builder - sqlInfoArr := parseSqlFileForObjectType(filePath, objType) - - for _, sqlInfo := range sqlInfoArr { - /* - We can rely on pg_query to detect if it is CreateTable and also table name - but due to time constraint this module can't be tested thoroughly so relying on the existing as much as possible - - We can pass the whole .sql file as a string also to pg_query.Parse() all the statements at once. 
- But avoiding that also specially for cases where the SQL syntax can be invalid - */ - modifiedSqlStmt, match, err := applyShardingRecommendationIfMatching(&sqlInfo, shardedTables, objType) - if err != nil { - log.Errorf("failed to apply sharding recommendation for table=%q: %v", sqlInfo.objName, err) - if match { - utils.PrintAndLog("Unable to apply sharding recommendation for table=%q, continuing without applying...\n", sqlInfo.objName) - utils.PrintAndLog("Please manually add the clause \"WITH (colocation = false)\" to the CREATE TABLE DDL of the '%s' table.\n", sqlInfo.objName) - } - } else { - if match { - log.Infof("original ddl - %s", sqlInfo.stmt) - log.Infof("modified ddl - %s", modifiedSqlStmt) - } - } - - _, err = newSQLFileContent.WriteString(modifiedSqlStmt + "\n\n") - if err != nil { - return fmt.Errorf("write SQL string to string builder: %w", err) - } - } - - // rename existing table.sql file to table.sql.orig - backupPath := filePath + ".orig" - log.Infof("renaming existing file '%s' --> '%s.orig'", filePath, backupPath) - err := os.Rename(filePath, filePath+".orig") + // Backup original + backupFile := filePath + ".orig" + err = os.Rename(filePath, backupFile) if err != nil { - return fmt.Errorf("error renaming file %s: %w", filePath, err) + return fmt.Errorf("failed to rename %s file to %s: %w", filePath, backupFile, err) } - // create new table.sql file for modified schema - log.Infof("creating file %q to store the modified recommended schema", filePath) - file, err := os.Create(filePath) + // Write updated file + fileContent := strings.Join(sqlStmts, "\n\n") + err = os.WriteFile(filePath, []byte(fileContent), 0644) if err != nil { - return fmt.Errorf("error creating file '%q' storing the modified recommended schema: %w", filePath, err) - } - if _, err = file.WriteString(newSQLFileContent.String()); err != nil { - return fmt.Errorf("error writing to file '%q' storing the modified recommended schema: %w", filePath, err) - } - if err = 
file.Close(); err != nil { - return fmt.Errorf("error closing file '%q' storing the modified recommended schema: %w", filePath, err) - } - var objTypeName = "" - switch objType { - case MVIEW: - objTypeName = "MATERIALIZED VIEW" - case TABLE: - objTypeName = "TABLE" - default: - panic(fmt.Sprintf("Object type not supported %s", objType)) + return fmt.Errorf("failed to write transformed schema file %q: %w", filePath, err) } - utils.PrintAndLog("Modified CREATE %s statements in %q according to the colocation and sharding recommendations of the assessment report.", - objTypeName, - utils.GetRelativePathFromCwd(filePath)) - utils.PrintAndLog("The original DDLs have been preserved in %q for reference.", utils.GetRelativePathFromCwd(backupPath)) return nil } -/* -applyShardingRecommendationIfMatching uses pg_query module to parse the given SQL stmt -In case of any errors or unexpected behaviour it return the original DDL -so in worse case, only recommendation of that table won't be followed. - -# It can handle cases like multiple options in WITH clause - -returns: -modifiedSqlStmt: original stmt if not sharded else modified stmt with colocation clause -match: true if its a sharded table and should be modified -error: nil/non-nil - -Drawback: pg_query module doesn't have functionality to format the query after parsing -so the CREATE TABLE for sharding recommended tables will be one-liner -*/ -func applyShardingRecommendationIfMatching(sqlInfo *sqlInfo, shardedTables []string, objType string) (string, bool, error) { - - stmt := sqlInfo.stmt - formattedStmt := sqlInfo.formattedStmt - parseTree, err := pg_query.Parse(stmt) +func applyShardedTableTransformation(stmts []*pg_query.RawStmt) ([]*pg_query.RawStmt, error) { + log.Info("applying sharded tables transformation to the exported schema") + assessmentReportPath = lo.Ternary(assessmentReportPath != "", assessmentReportPath, + filepath.Join(exportDir, "assessment", "reports", fmt.Sprintf("%s.json", ASSESSMENT_FILE_NAME))) + 
assessmentReport, err := ParseJSONToAssessmentReport(assessmentReportPath) if err != nil { - return formattedStmt, false, fmt.Errorf("error parsing the stmt-%s: %v", stmt, err) - } - - if len(parseTree.Stmts) == 0 { - log.Warnf("parse tree is empty for stmt=%s for table '%s'", stmt, sqlInfo.objName) - return formattedStmt, false, nil + return stmts, fmt.Errorf("failed to parse json report file %q: %w", assessmentReportPath, err) } - relation := &pg_query.RangeVar{} - switch objType { - case MVIEW: - createMViewNode, ok := parseTree.Stmts[0].Stmt.Node.(*pg_query.Node_CreateTableAsStmt) - if !ok || createMViewNode.CreateTableAsStmt.Objtype != pg_query.ObjectType_OBJECT_MATVIEW { - // return the original sql if it's not a Create Materialized view statement - log.Infof("stmt=%s is not create materialized view as per the parse tree,"+ - " expected tablename=%s", stmt, sqlInfo.objName) - return formattedStmt, false, nil - } - relation = createMViewNode.CreateTableAsStmt.Into.Rel - case TABLE: - createStmtNode, ok := parseTree.Stmts[0].Stmt.Node.(*pg_query.Node_CreateStmt) - if !ok { // return the original sql if it's not a CreateStmt - log.Infof("stmt=%s is not createTable as per the parse tree, expected tablename=%s", stmt, sqlInfo.objName) - return formattedStmt, false, nil - } - relation = createStmtNode.CreateStmt.Relation - default: - panic(fmt.Sprintf("Object type not supported %s", objType)) - } - - // true -> oracle, false -> PG - parsedObjectName := utils.BuildObjectName(relation.Schemaname, relation.Relname) - - match := false - switch source.DBType { - case POSTGRESQL: - match = slices.Contains(shardedTables, parsedObjectName) - case ORACLE: - // TODO: handle case-sensitivity properly - for _, shardedTable := range shardedTables { - // in case of oracle, shardedTable is unqualified. 
- if strings.ToLower(shardedTable) == parsedObjectName { - match = true - break + shardedObjects, err := assessmentReport.GetShardedTablesRecommendation() + if err != nil { + return stmts, fmt.Errorf("failed to fetch sharded tables recommendation: %w", err) + } + + isObjectSharded := func(objectName string) bool { + switch source.DBType { + case POSTGRESQL: + return slices.Contains(shardedObjects, objectName) + case ORACLE: + // TODO: handle case-sensitivity properly + for _, shardedObject := range shardedObjects { + // in case of oracle, shardedTable is unqualified. + if strings.ToLower(shardedObject) == objectName { + return true + } } + default: + panic(fmt.Sprintf("unsupported source db type %s for applying sharded table transformation", source.DBType)) } - default: - panic(fmt.Sprintf("unsupported source db type %s for applying sharding recommendations", source.DBType)) - } - if !match { - log.Infof("%q not present in the sharded table list", parsedObjectName) - return formattedStmt, false, nil - } else { - log.Infof("%q present in the sharded table list", parsedObjectName) - } - - colocationOption := &pg_query.DefElem{ - Defname: COLOCATION_CLAUSE, - Arg: pg_query.MakeStrNode("false"), + return false } - nodeForColocationOption := &pg_query.Node_DefElem{ - DefElem: colocationOption, + transformer := sqltransformer.NewTransformer() + transformedRawStmts, err := transformer.ConvertToShardedTables(stmts, isObjectSharded) + if err != nil { + return stmts, fmt.Errorf("failed to convert to sharded tables: %w", err) } - log.Infof("adding colocation option in the parse tree for table %s", sqlInfo.objName) - switch objType { - case MVIEW: - createMViewNode, _ := parseTree.Stmts[0].Stmt.Node.(*pg_query.Node_CreateTableAsStmt) + return transformedRawStmts, nil +} - if createMViewNode.CreateTableAsStmt.Into.Options == nil { - createMViewNode.CreateTableAsStmt.Into.Options = - []*pg_query.Node{{Node: nodeForColocationOption}} - } else { - 
createMViewNode.CreateTableAsStmt.Into.Options = append( - createMViewNode.CreateTableAsStmt.Into.Options, - &pg_query.Node{Node: nodeForColocationOption}) - } - case TABLE: - createStmtNode, _ := parseTree.Stmts[0].Stmt.Node.(*pg_query.Node_CreateStmt) - if createStmtNode.CreateStmt.Options == nil { - createStmtNode.CreateStmt.Options = - []*pg_query.Node{{Node: nodeForColocationOption}} - } else { - createStmtNode.CreateStmt.Options = append( - createStmtNode.CreateStmt.Options, - &pg_query.Node{Node: nodeForColocationOption}) - } - default: - panic(fmt.Sprintf("Object type not supported %s", objType)) +func applyMergeConstraintsTransformation(rawStmts []*pg_query.RawStmt) ([]*pg_query.RawStmt, error) { + if utils.GetEnvAsBool("YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS", false) { + log.Infof("skipping applying merge constraints transformation due to env var YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS=true") + return rawStmts, nil } - log.Infof("deparsing the updated parse tre into a stmt for table '%s'", parsedObjectName) - modifiedQuery, err := pg_query.Deparse(parseTree) + log.Info("applying merge constraints transformation to the exported schema") + transformer := sqltransformer.NewTransformer() + transformedRawStmts, err := transformer.MergeConstraints(rawStmts) if err != nil { - return formattedStmt, true, fmt.Errorf("error deparsing the parseTree into the query: %w", err) + return rawStmts, fmt.Errorf("failed to merge constraints: %w", err) } - // adding semi-colon at the end - return fmt.Sprintf("%s;", modifiedQuery), true, nil + return transformedRawStmts, nil } func createExportSchemaStartedEvent() cp.ExportSchemaStartedEvent { diff --git a/yb-voyager/cmd/exportSchema_test.go b/yb-voyager/cmd/exportSchema_test.go deleted file mode 100644 index a399b0e70a..0000000000 --- a/yb-voyager/cmd/exportSchema_test.go +++ /dev/null @@ -1,104 +0,0 @@ -//go:build unit - -/* -Copyright (c) YugabyteDB, Inc. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cmd - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestShardingRecommendations(t *testing.T) { - sqlInfo_mview1 := sqlInfo{ - objName: "m1", - stmt: "CREATE MATERIALIZED VIEW m1 AS SELECT * FROM t1 WHERE a = 3", - formattedStmt: "CREATE MATERIALIZED VIEW m1 AS SELECT * FROM t1 WHERE a = 3", - fileName: "", - } - sqlInfo_mview2 := sqlInfo{ - objName: "m1", - stmt: "CREATE MATERIALIZED VIEW m1 AS SELECT * FROM t1 WHERE a = 3 with no data;", - formattedStmt: "CREATE MATERIALIZED VIEW m1 AS SELECT * FROM t1 WHERE a = 3 with no data;", - fileName: "", - } - sqlInfo_mview3 := sqlInfo{ - objName: "m1", - stmt: "CREATE MATERIALIZED VIEW m1 WITH (fillfactor=70) AS SELECT * FROM t1 WHERE a = 3 with no data", - formattedStmt: "CREATE MATERIALIZED VIEW m1 WITH (fillfactor=70) AS SELECT * FROM t1 WHERE a = 3 with no data", - fileName: "", - } - source.DBType = POSTGRESQL - modifiedSqlStmt, match, _ := applyShardingRecommendationIfMatching(&sqlInfo_mview1, []string{"m1"}, MVIEW) - assert.Equal(t, strings.ToLower(modifiedSqlStmt), - strings.ToLower("create materialized view m1 with (colocation=false) as select * from t1 where a = 3;")) - assert.Equal(t, match, true) - - modifiedSqlStmt, match, _ = applyShardingRecommendationIfMatching(&sqlInfo_mview2, []string{"m1"}, MVIEW) - assert.Equal(t, strings.ToLower(modifiedSqlStmt), - strings.ToLower("create materialized view m1 with (colocation=false) as 
select * from t1 where a = 3 with no data;")) - assert.Equal(t, match, true) - - modifiedSqlStmt, match, _ = applyShardingRecommendationIfMatching(&sqlInfo_mview2, []string{"m1_notfound"}, MVIEW) - assert.Equal(t, modifiedSqlStmt, sqlInfo_mview2.stmt) - assert.Equal(t, match, false) - - modifiedSqlStmt, match, _ = applyShardingRecommendationIfMatching(&sqlInfo_mview3, []string{"m1"}, MVIEW) - assert.Equal(t, strings.ToLower(modifiedSqlStmt), - strings.ToLower("create materialized view m1 with (fillfactor=70, colocation=false) "+ - "as select * from t1 where a = 3 with no data;")) - assert.Equal(t, match, true) - - sqlInfo_table1 := sqlInfo{ - objName: "m1", - stmt: "create table a (a int, b int)", - formattedStmt: "create table a (a int, b int)", - fileName: "", - } - sqlInfo_table2 := sqlInfo{ - objName: "m1", - stmt: "create table a (a int, b int) WITH (fillfactor=70);", - formattedStmt: "create table a (a int, b int) WITH (fillfactor=70);", - fileName: "", - } - sqlInfo_table3 := sqlInfo{ - objName: "m1", - stmt: "alter table a add col text;", - formattedStmt: "alter table a add col text;", - fileName: "", - } - modifiedTableStmt, matchTable, _ := applyShardingRecommendationIfMatching(&sqlInfo_table1, []string{"a"}, TABLE) - assert.Equal(t, strings.ToLower(modifiedTableStmt), - strings.ToLower("create table a (a int, b int) WITH (colocation=false);")) - assert.Equal(t, matchTable, true) - - modifiedTableStmt, matchTable, _ = applyShardingRecommendationIfMatching(&sqlInfo_table2, []string{"a"}, TABLE) - assert.Equal(t, strings.ToLower(modifiedTableStmt), - strings.ToLower("create table a (a int, b int) WITH (fillfactor=70, colocation=false);")) - assert.Equal(t, matchTable, true) - - modifiedSqlStmt, matchTable, _ = applyShardingRecommendationIfMatching(&sqlInfo_table2, []string{"m1_notfound"}, TABLE) - assert.Equal(t, modifiedSqlStmt, sqlInfo_table2.stmt) - assert.Equal(t, matchTable, false) - - modifiedTableStmt, matchTable, _ = 
applyShardingRecommendationIfMatching(&sqlInfo_table3, []string{"a"}, TABLE) - assert.Equal(t, strings.ToLower(modifiedTableStmt), - strings.ToLower(sqlInfo_table3.stmt)) - assert.Equal(t, matchTable, false) -} diff --git a/yb-voyager/src/adaptiveparallelism/adaptive_parallelism_test.go b/yb-voyager/src/adaptiveparallelism/adaptive_parallelism_test.go index 139b443f4d..c46c1ebbbe 100644 --- a/yb-voyager/src/adaptiveparallelism/adaptive_parallelism_test.go +++ b/yb-voyager/src/adaptiveparallelism/adaptive_parallelism_test.go @@ -18,9 +18,11 @@ limitations under the License. package adaptiveparallelism import ( + "os" "strconv" "testing" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/yugabyte/yb-voyager/yb-voyager/src/tgtdb" ) @@ -112,6 +114,14 @@ func (d *dummyTargetYugabyteDB) UpdateNumConnectionsInPool(delta int) error { return nil } +func TestMain(m *testing.M) { + // to avoid info level logs flooding the unit test output + log.SetLevel(log.WarnLevel) + + exitCode := m.Run() + os.Exit(exitCode) +} + func TestMaxCpuUsage(t *testing.T) { yb := &dummyTargetYugabyteDB{ size: 3, diff --git a/yb-voyager/src/constants/constants.go b/yb-voyager/src/constants/constants.go index 935232058d..b4158cf112 100644 --- a/yb-voyager/src/constants/constants.go +++ b/yb-voyager/src/constants/constants.go @@ -49,5 +49,6 @@ const ( ) const ( - OBFUSCATE_STRING = "XXXXX" + OBFUSCATE_STRING = "XXXXX" + COLOCATION_CLAUSE = "colocation" ) diff --git a/yb-voyager/src/query/queryparser/query_parser.go b/yb-voyager/src/query/queryparser/query_parser.go index 76073968b0..3c325870dd 100644 --- a/yb-voyager/src/query/queryparser/query_parser.go +++ b/yb-voyager/src/query/queryparser/query_parser.go @@ -54,7 +54,7 @@ func ParsePLPGSQLToJson(query string) (string, error) { return jsonString, err } -func ParseSqlFile(filePath string) (*pg_query.ParseResult, error) { +func ParseSqlFile(filePath string) ([]*pg_query.RawStmt, error) { log.Debugf("parsing the file 
[%s]", filePath) bytes, err := os.ReadFile(filePath) if err != nil { @@ -67,7 +67,7 @@ func ParseSqlFile(filePath string) (*pg_query.ParseResult, error) { } log.Debugf("sql file contents: %s\n", string(bytes)) log.Debugf("parse tree: %v\n", tree) - return tree, nil + return tree.Stmts, nil } func ProcessDDL(parseTree *pg_query.ParseResult) (DDLObject, error) { diff --git a/yb-voyager/src/query/sqltransformer/helpers.go b/yb-voyager/src/query/sqltransformer/helpers.go new file mode 100644 index 0000000000..7c6c5ff1fc --- /dev/null +++ b/yb-voyager/src/query/sqltransformer/helpers.go @@ -0,0 +1,67 @@ +/* +Copyright (c) YugabyteDB, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package sqltransformer + +import ( + pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/yugabyte/yb-voyager/yb-voyager/src/constants" +) + +// addColocationOptionToCreateTable appends the option "colocation = false" to the "Options" array +// of the given CreateStmt, i.e. it adds the clause WITH (colocation = false). 
+func addColocationOptionToCreateTable(createStmt *pg_query.CreateStmt) { + if createStmt == nil { + return + } + // If Options slice is nil, initialize it + if createStmt.Options == nil { + createStmt.Options = []*pg_query.Node{} + } + + // Build DefElem: defname = "colocation", arg = "false" + defElemNode := &pg_query.Node{ + Node: &pg_query.Node_DefElem{ + DefElem: &pg_query.DefElem{ + Defname: constants.COLOCATION_CLAUSE, + Arg: pg_query.MakeStrNode("false"), + }, + }, + } + + createStmt.Options = append(createStmt.Options, defElemNode) +} + +func addColocationOptionToCreateMaterializedView(createMatViewStmt *pg_query.CreateTableAsStmt) { + if createMatViewStmt == nil { + return + } + // If Options slice is nil, initialize it + if createMatViewStmt.Into.Options == nil { + createMatViewStmt.Into.Options = []*pg_query.Node{} + } + + // Build DefElem: defname = "colocation", arg = "false" + defElemNode := &pg_query.Node{ + Node: &pg_query.Node_DefElem{ + DefElem: &pg_query.DefElem{ + Defname: constants.COLOCATION_CLAUSE, + Arg: pg_query.MakeStrNode("false"), + }, + }, + } + + createMatViewStmt.Into.Options = append(createMatViewStmt.Into.Options, defElemNode) +} diff --git a/yb-voyager/src/query/sqltransformer/transformer.go b/yb-voyager/src/query/sqltransformer/transformer.go index 8359f7f72b..b716e97af3 100644 --- a/yb-voyager/src/query/sqltransformer/transformer.go +++ b/yb-voyager/src/query/sqltransformer/transformer.go @@ -20,6 +20,7 @@ import ( "slices" pg_query "github.com/pganalyze/pg_query_go/v6" + log "github.com/sirupsen/logrus" "github.com/yugabyte/yb-voyager/yb-voyager/src/query/queryparser" ) @@ -63,7 +64,9 @@ Note: Need to keep the relative ordering of statements(tables) intact. Because there can be cases like Foreign Key constraints that depend on the order of tables. */ func (t *Transformer) MergeConstraints(stmts []*pg_query.RawStmt) ([]*pg_query.RawStmt, error) { - // TODO: Ensure removing all the ALTER stmts which are merged into CREATE. 
No duplicates. + if len(stmts) == 0 { + return stmts, nil + } createStmtMap := make(map[string]*pg_query.RawStmt) for _, stmt := range stmts { @@ -153,3 +156,39 @@ func (t *Transformer) MergeConstraints(stmts []*pg_query.RawStmt) ([]*pg_query.R return result, nil } + +// ConvertToShardedTables marks the CREATE TABLE and CREATE MATERIALIZED VIEW statements whose object name satisfies isObjectSharded as sharded by adding the clause WITH (colocation = false); all other statements are passed through unchanged. +func (t *Transformer) ConvertToShardedTables(stmts []*pg_query.RawStmt, isObjectSharded func(objectName string) bool) ([]*pg_query.RawStmt, error) { + if len(stmts) == 0 { + return stmts, nil + } + + var result []*pg_query.RawStmt + for _, stmt := range stmts { + stmtType := queryparser.GetStatementType(stmt.Stmt.ProtoReflect()) + + switch stmtType { + case queryparser.PG_QUERY_CREATE_STMT: // CREATE TABLE case + objectName := queryparser.GetObjectNameFromRangeVar(stmt.Stmt.GetCreateStmt().Relation) + if isObjectSharded(objectName) { + log.Infof("adding colocation option to CREATE TABLE for object %v", objectName) + addColocationOptionToCreateTable(stmt.Stmt.GetCreateStmt()) + } + + result = append(result, stmt) + case queryparser.PG_QUERY_CREATE_TABLE_AS_STMT: // CREATE MATERIALIZED VIEW case + objectName := queryparser.GetObjectNameFromRangeVar(stmt.Stmt.GetCreateTableAsStmt().Into.Rel) + if isObjectSharded(objectName) { + log.Infof("adding colocation option to CREATE MATERIALIZED VIEW for object %v", objectName) + addColocationOptionToCreateMaterializedView(stmt.Stmt.GetCreateTableAsStmt()) + } + + result = append(result, stmt) + default: + result = append(result, stmt) + } + + } + + return result, nil +} diff --git a/yb-voyager/src/query/sqltransformer/transformer_test.go b/yb-voyager/src/query/sqltransformer/transformer_test.go index e24c343a3a..c0dc70d6f3 100644 --- a/yb-voyager/src/query/sqltransformer/transformer_test.go +++ b/yb-voyager/src/query/sqltransformer/transformer_test.go @@ -66,7 +66,7 @@ func TestMergeConstraints_Basic(t *testing.T) { 
testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -135,7 +135,7 @@ func TestMergeConstraints_AllSupportedConstraintTypes(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -165,7 +165,7 @@ func TestMergeConstraints_DifferentCasing(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -199,7 +199,7 @@ func TestMergeConstraints_MultipleConstraintsInSingleStmt(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -241,7 +241,7 @@ func TestMergeConstraints_CircularDependencyWithSeparateFK(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -272,7 +272,7 @@ func TestMergeConstraints_QuotedColumnNames(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := 
transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -299,7 +299,7 @@ func TestMergeConstraints_AlterWithoutCreateTableError(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - _, transformErr := transformer.MergeConstraints(stmts.Stmts) + _, transformErr := transformer.MergeConstraints(stmts) if transformErr == nil { t.Fatalf("expected an error because CREATE TABLE is missing, but got no error") } @@ -311,6 +311,67 @@ func TestMergeConstraints_AlterWithoutCreateTableError(t *testing.T) { } } +func TestConvertToShardedTables(t *testing.T) { + sqlFileContent := ` + CREATE TABLE test_table1 ( + id INT PRIMARY KEY, + name VARCHAR(255) + ); + + CREATE TABLE test_table2 ( + id INT PRIMARY KEY, + name VARCHAR(255), + email VARCHAR(255) + ); + + CREATE TABLE test_table3 ( + id INT PRIMARY KEY, + name VARCHAR(255), + email VARCHAR(255) + ) WITH (fillfactor=100); + + CREATE TABLE test_table4 ( + id INT PRIMARY KEY, + name VARCHAR(255), + email VARCHAR(255) + ) WITH (fillfactor=101); + + CREATE MATERIALIZED VIEW test_mview1 AS SELECT * FROM test_table1; + CREATE MATERIALIZED VIEW test_mview2 AS SELECT * FROM test_table2; + CREATE MATERIALIZED VIEW test_mview3 WITH (fillfactor=70) AS SELECT * FROM test_table3 WHERE a = 3 with no data; + + ALTER TABLE test_table1 ADD COLUMN col text; -- this should be ignored + ` + + expectedSqls := []string{ + `CREATE TABLE test_table1 (id int PRIMARY KEY, name varchar(255)) WITH (colocation=false);`, + `CREATE TABLE test_table2 (id int PRIMARY KEY, name varchar(255), email varchar(255));`, + `CREATE TABLE test_table3 (id int PRIMARY KEY, name varchar(255), email varchar(255)) WITH (fillfactor=100, colocation=false);`, + `CREATE TABLE test_table4 (id int PRIMARY KEY, name varchar(255), email varchar(255)) WITH (fillfactor=101);`, + `CREATE MATERIALIZED VIEW 
test_mview1 WITH (colocation=false) AS SELECT * FROM test_table1;`, + `CREATE MATERIALIZED VIEW test_mview2 AS SELECT * FROM test_table2;`, + `CREATE MATERIALIZED VIEW test_mview3 WITH (fillfactor=70, colocation=false) AS SELECT * FROM test_table3 WHERE a = 3 WITH NO DATA;`, + `ALTER TABLE test_table1 ADD COLUMN col text;`, + } + + tempFilePath, err := testutils.CreateTempFile("/tmp", sqlFileContent, "sql") + testutils.FatalIfError(t, err) + + stmts, err := queryparser.ParseSqlFile(tempFilePath) + testutils.FatalIfError(t, err) + + transformer := NewTransformer() + transformedStmts, err := transformer.ConvertToShardedTables(stmts, func(objectName string) bool { + return objectName == "test_table1" || objectName == "test_table3" || objectName == "test_mview1" || objectName == "test_mview3" + }) + testutils.FatalIfError(t, err) + + finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) + testutils.FatalIfError(t, err) + + testutils.AssertEqualStringSlices(t, expectedSqls, finalSqlStmts) +} + /* jFYI: For EXCLUDE constraint, the USING btree is omitted by parser during deparsing. @@ -347,7 +408,7 @@ func TestMergeConstraints_ExcludeConstraintType(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts) @@ -389,7 +450,7 @@ func Test_RemovalOfDefaultValuesByParser(t *testing.T) { testutils.FatalIfError(t, err) transformer := NewTransformer() - transformedStmts, err := transformer.MergeConstraints(stmts.Stmts) + transformedStmts, err := transformer.MergeConstraints(stmts) testutils.FatalIfError(t, err) finalSqlStmts, err := queryparser.DeparseRawStmts(transformedStmts)