From a46d25abf24ba4992a85d96c2b8989f89fe370f1 Mon Sep 17 00:00:00 2001 From: Frankie Date: Tue, 16 Jul 2024 02:09:10 +0800 Subject: [PATCH 1/3] 1. Add AWS RDS connection and RDS certificate. 2. Update the unit test configuration regarding db connections. 3. Update the env and config files. --- .gitignore | 7 +- Makefile | 12 +-- README.md | 27 +++++- app.env.example | 2 + cmd/go-todolist-grpc/main.go | 1 + internal/config/config.go | 1 + internal/config/env.go.example | 16 ++-- internal/model/mod_category_test.go | 5 +- internal/model/mod_task_test.go | 5 +- internal/model/mod_user_test.go | 5 +- internal/pkg/db/db.go | 81 ++++++++++++++++-- internal/pkg/db/db_test.go | 126 +++++++++++++++++++++++++--- internal/service/s_category_test.go | 6 +- internal/service/s_task_test.go | 6 +- internal/service/s_user_test.go | 6 +- 15 files changed, 259 insertions(+), 47 deletions(-) diff --git a/.gitignore b/.gitignore index 598d9af..645e749 100644 --- a/.gitignore +++ b/.gitignore @@ -17,10 +17,11 @@ # Dependency directories (remove the comment below to include it) # vendor/ -internal/config/env.go -app.env -target/ *.DS_Store +target/ +app.env +internal/config/env.go +internal/config/certs/*.pem # Go workspace file go.work diff --git a/Makefile b/Makefile index e289cfc..7a99b22 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ include app.env DOCKER = docker compose exec server +DOCKER_HOST=db +TEST_DB=test_db YMD = _$$(date +'%Y%m%d') number?=3 table := @@ -21,19 +23,19 @@ migrate-create: # make migrate-up number=x migrate-up: - $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=disable" up $(number) + $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=verify-full&sslrootcert=./internal/config/certs/rds-ca-2019-root.pem" up $(number) # make migrate-down number=x migrate-down: - $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=disable" down $(number) + $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/$(DB_NAME)?sslmode=verify-full&sslrootcert=./internal/config/certs/rds-ca-2019-root.pem" down $(number) # make migrate-test-up number=x migrate-test-up: - $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/test_db?sslmode=disable" up $(number) + $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://root:root@$(DOCKER_HOST):$(DB_PORT)/$(TEST_DB)?sslmode=disable" up $(number) # make migrate-test-down number=x migrate-test-down: - $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/test_db?sslmode=disable" down $(number) + $(DOCKER) migrate -path ./internal/migrations -database "$(DB)://root:root@$(DOCKER_HOST):$(DB_PORT)/$(TEST_DB)?sslmode=disable" down $(number) clean-logs: @echo "Cleaning log directories..." @@ -75,7 +77,7 @@ go-test-single: go-test-ci: @set -e; \ start_time=$$(date +%s); \ - migrate -path ./internal/migrations -database "$(DB)://$(DB_USER):$(DB_PASS)@$(DB_HOST):$(DB_PORT)/test_db?sslmode=disable" up; \ + migrate -path ./internal/migrations -database "$(DB)://root:root@127.0.0.1:$(DB_PORT)/$(TEST_DB)?sslmode=disable" up; \ go test -v -short ./...; \ end_time=$$(date +%s); \ total_duration=$$((end_time - start_time)); \ diff --git a/README.md b/README.md index 2d3b987..2b20338 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ This is a To-Do List project implemented using Go language and gRPC. git clone https://github.com/Frankie0702111/go-todolist-grpc.git ``` -## 2.Set up Docker information, such as database, Redis, Server +## 2.Set up Docker information, such as database, gRPC server, gateway server... ```bash cd go-todolist-grpc cp app.env.example app.env @@ -63,13 +63,22 @@ docker compose up -d docker compose stop ``` -## 4.Set up basic information, such as database, Redis, AWS, JWT +## 4.Set up the testing information, such as database, AWS, JWT... ```bash cp ./internal/config/env.go.example ./internal/config/env.go vim ./internal/config/env.go ``` -## 5.Generate db migrations +## 5.Download the AWS-RDS certificates +- **Documents** + - [Using SSL/TLS to encrypt a connection to a DB instance or cluster](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.SSL.html) + +- **1.Choosing the correct AWS region** + - [Download certificate bundles for Amazon RDS](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.SSL.html#UsingWithRDS.SSL.CertificatesAllRegions) +- **2.Changing the .pem file name (name: rds-ca-2019-root.pem)** +- **3.Move the rds-ca-2019-root.pem file to internal/config/certs** + +## 6.Generate db migrations ```bash # Up all migration make migrate-up @@ -82,6 +91,18 @@ make migrate-up number=1 make migrate-down number=1 ``` +## 7.Create the DB Extensions(Options) +[Reference: Determining the SSL connection status](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/PostgreSQL.Concepts.General.SSL.html#PostgreSQL.Concepts.General.SSL.Status) +```sql +CREATE EXTENSION sslinfo; + +SELECT ssl_is_used(); +ssl_is_used +--------- +t +(1 row) +``` + # Folder structure ``` ├── Dockerfile diff --git a/app.env.example b/app.env.example index 5bd9e96..ad67882 100644 --- a/app.env.example +++ b/app.env.example @@ -10,6 +10,8 @@ DB_PORT=5432 DB_USER=root DB_PASS=root DB_NAME=go-todolist-grpc-db +# SSL_MODE=verify-full +SSL_MODE=disable DB_CONN_MAX_LT_SEC=1800 DB_MAX_CONN=50 DB_MAX_IDLE=5 diff --git a/cmd/go-todolist-grpc/main.go b/cmd/go-todolist-grpc/main.go index f28ac7f..bd816c3 100644 --- a/cmd/go-todolist-grpc/main.go +++ b/cmd/go-todolist-grpc/main.go @@ -50,6 +50,7 @@ func main() { Username: cnf.DBUser, Password: cnf.DBPassword, DBName: cnf.DBName, + SSLMode: cnf.SSLMode, ConnectionMaxLifeTimeSec: cnf.DBConnectionMaxLifeTimeSec, MaxConn: cnf.DBMaxConnection, MaxIdle: cnf.DBMaxIdle, diff --git a/internal/config/config.go b/internal/config/config.go index 0249e6b..75e0f24 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { DBUser string `mapstructure:"DB_USER"` DBPassword string `mapstructure:"DB_PASS"` DBName string `mapstructure:"DB_NAME"` + SSLMode string `mapstructure:"SSL_MODE"` DBConnectionMaxLifeTimeSec *int `mapstructure:"DB_CONN_MAX_LT_SEC"` DBMaxConnection *int `mapstructure:"DB_MAX_CONN"` DBMaxIdle *int `mapstructure:"DB_MAX_IDLE"` diff --git a/internal/config/env.go.example b/internal/config/env.go.example index 0783af8..3b17792 100644 --- a/internal/config/env.go.example +++ b/internal/config/env.go.example @@ -1,27 +1,31 @@ package config +// The config parameters for testing. + // Server const ( HttpPort = "8642" GrpcPort = "7531" ) -// Gorm +// GORM const ( Source = "postgres" SourceUser = "root" SourcePassword = "root" SourceHost = "127.0.0.1" SourcePort = "5432" - SourceDataBase = "go-todolist-grpc-db" + SourceDataBase = "test_db" + SourceSSLMode = "disable" SourceDBConnMaxLTSec = 1800 SourceMaxConn = 50 SourceMaxIdle = 5 -) -// For testing -const ( - TestSourceDataBase = "test_db" + AWSSourceUser = "" + AWSSourcePassword = "" + AWSSourceHost = "" + AWSSourceDataBase = "go_todolist_grpc" + AWSSourceSSLMode = "verify-full" ) // JWT diff --git a/internal/model/mod_category_test.go b/internal/model/mod_category_test.go index 0feb1bc..32d69b8 100644 --- a/internal/model/mod_category_test.go +++ b/internal/model/mod_category_test.go @@ -19,12 +19,13 @@ var sqlDBCategory *sql.DB var sqlTxCategory *sql.Tx func setUpModCategory() { - psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s application_name=otter sslmode=disable timezone=UTC", + psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s timezone=UTC", config.SourceHost, config.SourcePort, config.SourceUser, config.SourcePassword, - config.TestSourceDataBase, + config.SourceDataBase, + config.SourceSSLMode, ) db, err := sql.Open("postgres", psqlInfo) diff --git a/internal/model/mod_task_test.go b/internal/model/mod_task_test.go index c5abad4..3d252d8 100644 --- a/internal/model/mod_task_test.go +++ b/internal/model/mod_task_test.go @@ -18,12 +18,13 @@ var sqlDBTask *sql.DB var sqlTxTask *sql.Tx func setUpModTask() { - psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s application_name=otter sslmode=disable timezone=UTC", + psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s timezone=UTC", config.SourceHost, config.SourcePort, config.SourceUser, config.SourcePassword, - config.TestSourceDataBase, + config.SourceDataBase, + config.SourceSSLMode, ) db, err := sql.Open("postgres", psqlInfo) diff --git a/internal/model/mod_user_test.go b/internal/model/mod_user_test.go index 8d8e4d4..3da63c7 100644 --- a/internal/model/mod_user_test.go +++ b/internal/model/mod_user_test.go @@ -17,12 +17,13 @@ var sqlDBUser *sql.DB var sqlTxUser *sql.Tx func setUpModUser() { - psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s application_name=otter sslmode=disable timezone=UTC", + psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s timezone=UTC", config.SourceHost, config.SourcePort, config.SourceUser, config.SourcePassword, - config.TestSourceDataBase, + config.SourceDataBase, + config.SourceSSLMode, ) db, err := sql.Open("postgres", psqlInfo) diff --git a/internal/pkg/db/db.go b/internal/pkg/db/db.go index 0004971..2ff80aa 100644 --- a/internal/pkg/db/db.go +++ b/internal/pkg/db/db.go @@ -1,9 +1,14 @@ package db import ( + "crypto/x509" "database/sql" + "encoding/pem" "fmt" "go-todolist-grpc/internal/pkg/log" + "os" + "path/filepath" + "runtime" "time" _ "github.com/lib/pq" @@ -20,19 +25,40 @@ type Option struct { Username string Password string DBName string + SSLMode string ConnectionMaxLifeTimeSec *int MaxConn *int MaxIdle *int } func Init(opt *Option) error { - psqlInfo := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable timezone=UTC", - opt.Host, - opt.Port, - opt.Username, - opt.Password, - opt.DBName, - ) + var psqlInfo string + if opt.SSLMode == "verify-full" { + sslRootCert := getRootCertPath() + err := verifyCertificate(sslRootCert) + if err != nil { + return fmt.Errorf("certificate verification failed: %w", err) + } + + psqlInfo = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s sslrootcert=%s timezone=UTC", + opt.Host, + opt.Port, + opt.Username, + opt.Password, + opt.DBName, + opt.SSLMode, + sslRootCert, + ) + } else { + psqlInfo = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s timezone=UTC", + opt.Host, + opt.Port, + opt.Username, + opt.Password, + opt.DBName, + opt.SSLMode, + ) + } log.Info.Printf("initial gRPC DB: %s:%s", opt.Host, opt.Port) @@ -89,3 +115,44 @@ func GormDriver(db gorm.ConnPool) *gorm.DB { func ResetConn() { gRPCDB = nil } + +func getRootCertPath() string { + _, b, _, _ := runtime.Caller(0) + basepath := filepath.Dir(b) + return filepath.Join(basepath, "..", "..", "config", "certs", "rds-ca-2019-root.pem") +} + +func GetRootCertPath() string { + return getRootCertPath() +} + +func verifyCertificate(certPath string) error { + certData, err := os.ReadFile(certPath) + if err != nil { + return fmt.Errorf("failed to read certificate file: %w", err) + } + + block, _ := pem.Decode(certData) + if block == nil { + return fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + + // Verify the certificate issuer + expectedIssuer := "CN=Amazon RDS Root 2019 CA,OU=Amazon RDS,O=Amazon Web Services\\, Inc.,L=Seattle,ST=Washington,C=US" + if cert.Issuer.String() != expectedIssuer { + return fmt.Errorf("unexpected certificate issuer: got %s, want %s", cert.Issuer.String(), expectedIssuer) + } + + // Verify the certificate's validity period + now := time.Now() + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return fmt.Errorf("certificate is not valid at the current time") + } + + return nil +} diff --git a/internal/pkg/db/db_test.go b/internal/pkg/db/db_test.go index a8861da..a34c221 100644 --- a/internal/pkg/db/db_test.go +++ b/internal/pkg/db/db_test.go @@ -1,19 +1,67 @@ package db_test import ( + "crypto/x509" + "encoding/pem" "go-todolist-grpc/internal/config" mydb "go-todolist-grpc/internal/pkg/db" "go-todolist-grpc/internal/pkg/log" "os" "strconv" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" ) func TestInitSuccess(t *testing.T) { - t.Run("Success", func(t *testing.T) { + t.Run("Success_AWS", func(t *testing.T) { + connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec + maxConn := config.SourceMaxConn + maxIdle := config.SourceMaxIdle + log.Init(config.LogLevel, config.LogFolderPath, strconv.Itoa(os.Getpid()), config.EnableConsoleOutput, config.EnableFileOutput) + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + + // Simulate the SSL query + rows := sqlmock.NewRows([]string{"ssl_is_used"}).AddRow(true) + mock.ExpectQuery("SELECT ssl_is_used()").WillReturnRows(rows) + + opt := &mydb.Option{ + Host: config.AWSSourceHost, + Port: config.SourcePort, + Username: config.AWSSourceUser, + Password: config.AWSSourcePassword, + DBName: config.AWSSourceDataBase, + SSLMode: config.AWSSourceSSLMode, + ConnectionMaxLifeTimeSec: &connectionMaxLifeTimeSec, + MaxConn: &maxConn, + MaxIdle: &maxIdle, + } + + err = mydb.Init(opt) + assert.NoError(t, err) + + // Checking the SSL status + var sslUsed bool + err = db.QueryRow("SELECT ssl_is_used()").Scan(&sslUsed) + assert.NoError(t, err) + assert.True(t, sslUsed, "SSL should be used") + + // We make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + }) + + t.Run("Success_General", func(t *testing.T) { connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec maxConn := config.SourceMaxConn maxIdle := config.SourceMaxIdle @@ -34,7 +82,8 @@ func TestInitSuccess(t *testing.T) { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: &connectionMaxLifeTimeSec, MaxConn: &maxConn, MaxIdle: &maxIdle, @@ -43,7 +92,7 @@ func TestInitSuccess(t *testing.T) { err = mydb.Init(opt) assert.NoError(t, err) - // we make sure that all expectations were met + // We make sure that all expectations were met if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } @@ -65,14 +114,15 @@ func TestInitSuccess(t *testing.T) { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: nil, MaxConn: nil, MaxIdle: nil, } err = mydb.Init(opt) - assert.Error(t, err, "dial tcp: lookup invalidhost: no such host") + assert.EqualError(t, err, "dial tcp: lookup invalidhost: no such host") }) t.Run("Failure_InvalidPort", func(t *testing.T) { @@ -91,14 +141,15 @@ func TestInitSuccess(t *testing.T) { Port: "invalidport", Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: nil, MaxConn: nil, MaxIdle: nil, } err = mydb.Init(opt) - assert.Error(t, err, "dial tcp: lookup tcp/invalidport: unknown port") + assert.EqualError(t, err, "dial tcp: lookup tcp/invalidport: unknown port") }) t.Run("Failure_InvalidUser", func(t *testing.T) { @@ -115,14 +166,15 @@ func TestInitSuccess(t *testing.T) { Port: config.SourcePort, Username: "invalid", Password: "invalid", - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: nil, MaxConn: nil, MaxIdle: nil, } err = mydb.Init(opt) - assert.Error(t, err, "pq: password authentication failed for user \"invalid\"") + assert.EqualError(t, err, "pq: password authentication failed for user \"invalid\"") }) t.Run("Failure_InvalidDBName", func(t *testing.T) { @@ -140,13 +192,39 @@ func TestInitSuccess(t *testing.T) { Username: config.SourceUser, Password: config.SourcePassword, DBName: "invaliddbname", + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: nil, MaxConn: nil, MaxIdle: nil, } err = mydb.Init(opt) - assert.Error(t, err, "pq: database \"invaliddbname\" does not exist") + assert.EqualError(t, err, "pq: database \"invaliddbname\" does not exist") + }) + + t.Run("Failure_InvalidSSLMode", func(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectPing() + + opt := &mydb.Option{ + Host: config.SourceHost, + Port: config.SourcePort, + Username: config.SourceUser, + Password: config.SourcePassword, + DBName: config.SourceDataBase, + SSLMode: "invalidsslmode", + ConnectionMaxLifeTimeSec: nil, + MaxConn: nil, + MaxIdle: nil, + } + + err = mydb.Init(opt) + assert.EqualError(t, err, "pq: unsupported sslmode \"invalidsslmode\"; only \"require\" (default), \"verify-full\", \"verify-ca\", and \"disable\" supported") }) } @@ -167,7 +245,8 @@ func TestGetConn(t *testing.T) { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, ConnectionMaxLifeTimeSec: nil, MaxConn: nil, MaxIdle: nil, @@ -189,3 +268,28 @@ func TestGetConn(t *testing.T) { assert.Nil(t, conn) }) } + +func TestVerifyCertificate(t *testing.T) { + t.Run("Success", func(t *testing.T) { + certPath := mydb.GetRootCertPath() + + // Should be able to read the certificate file + certData, err := os.ReadFile(certPath) + assert.NoError(t, err) + + // Should be able to decode the PEM block + block, _ := pem.Decode(certData) + assert.NotNil(t, block) + + // Should be able to parse the certificate + cert, err := x509.ParseCertificate(block.Bytes) + assert.NoError(t, err) + + // Certificate issuer should match + assert.Equal(t, "CN=Amazon RDS Root 2019 CA,OU=Amazon RDS,O=Amazon Web Services\\, Inc.,L=Seattle,ST=Washington,C=US", cert.Issuer.String()) + + // Certificate should be valid + now := time.Now() + assert.True(t, now.After(cert.NotBefore) && now.Before(cert.NotAfter)) + }) +} diff --git a/internal/service/s_category_test.go b/internal/service/s_category_test.go index a8cb9fa..6dff388 100644 --- a/internal/service/s_category_test.go +++ b/internal/service/s_category_test.go @@ -28,7 +28,8 @@ func setUpCategory() error { mockConfigContent.WriteString("DB_PORT=" + config.SourcePort + "\n") mockConfigContent.WriteString("DB_USER=" + config.SourceUser + "\n") mockConfigContent.WriteString("DB_PASS=" + config.SourcePassword + "\n") - mockConfigContent.WriteString("DB_NAME=" + config.TestSourceDataBase + "\n") + mockConfigContent.WriteString("DB_NAME=" + config.SourceDataBase + "\n") + mockConfigContent.WriteString("SSL_MODE=" + config.SourceSSLMode + "\n") mockConfigContent.WriteString("DB_CONN_MAX_LT_SEC=" + strconv.Itoa(config.SourceDBConnMaxLTSec) + "\n") mockConfigContent.WriteString("DB_MAX_CONN=" + strconv.Itoa(config.SourceMaxConn) + "\n") mockConfigContent.WriteString("DB_MAX_IDLE=" + strconv.Itoa(config.SourceMaxIdle) + "\n") @@ -64,7 +65,8 @@ func setUpCategory() error { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, } err = db.Init(opt) diff --git a/internal/service/s_task_test.go b/internal/service/s_task_test.go index dbf27d3..1d3f2b2 100644 --- a/internal/service/s_task_test.go +++ b/internal/service/s_task_test.go @@ -30,7 +30,8 @@ func setUpTask() error { mockConfigContent.WriteString("DB_PORT=" + config.SourcePort + "\n") mockConfigContent.WriteString("DB_USER=" + config.SourceUser + "\n") mockConfigContent.WriteString("DB_PASS=" + config.SourcePassword + "\n") - mockConfigContent.WriteString("DB_NAME=" + config.TestSourceDataBase + "\n") + mockConfigContent.WriteString("DB_NAME=" + config.SourceDataBase + "\n") + mockConfigContent.WriteString("SSL_MODE=" + config.SourceSSLMode + "\n") mockConfigContent.WriteString("DB_CONN_MAX_LT_SEC=" + strconv.Itoa(config.SourceDBConnMaxLTSec) + "\n") mockConfigContent.WriteString("DB_MAX_CONN=" + strconv.Itoa(config.SourceMaxConn) + "\n") mockConfigContent.WriteString("DB_MAX_IDLE=" + strconv.Itoa(config.SourceMaxIdle) + "\n") @@ -66,7 +67,8 @@ func setUpTask() error { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, } err = db.Init(opt) diff --git a/internal/service/s_user_test.go b/internal/service/s_user_test.go index 8c035e3..c52ad83 100644 --- a/internal/service/s_user_test.go +++ b/internal/service/s_user_test.go @@ -28,7 +28,8 @@ func setUpUser() error { mockConfigContent.WriteString("DB_PORT=" + config.SourcePort + "\n") mockConfigContent.WriteString("DB_USER=" + config.SourceUser + "\n") mockConfigContent.WriteString("DB_PASS=" + config.SourcePassword + "\n") - mockConfigContent.WriteString("DB_NAME=" + config.TestSourceDataBase + "\n") + mockConfigContent.WriteString("DB_NAME=" + config.SourceDataBase + "\n") + mockConfigContent.WriteString("SSL_MODE=" + config.SourceSSLMode + "\n") mockConfigContent.WriteString("DB_CONN_MAX_LT_SEC=" + strconv.Itoa(config.SourceDBConnMaxLTSec) + "\n") mockConfigContent.WriteString("DB_MAX_CONN=" + strconv.Itoa(config.SourceMaxConn) + "\n") mockConfigContent.WriteString("DB_MAX_IDLE=" + strconv.Itoa(config.SourceMaxIdle) + "\n") @@ -64,7 +65,8 @@ func setUpUser() error { Port: config.SourcePort, Username: config.SourceUser, Password: config.SourcePassword, - DBName: config.TestSourceDataBase, + DBName: config.SourceDataBase, + SSLMode: config.SourceSSLMode, } err = db.Init(opt) From 6eeba890878cee41598e0e0a1c282e549bca8072 Mon Sep 17 00:00:00 2001 From: Frankie Date: Tue, 16 Jul 2024 02:48:30 +0800 Subject: [PATCH 2/3] Update the db unit test for CI. --- .github/workflows/test.yml | 7 +++++++ internal/pkg/db/db.go | 17 +++++++++++++++++ internal/pkg/db/db_test.go | 10 +++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 503ee8d..2467a48 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,5 +47,12 @@ jobs: cp ./app.env.example ./app.env cp ./internal/config/env.go.example ./internal/config/env.go + - name: Set up SSL certificate + run: | + mkdir -p internal/config/certs + echo "${{ secrets.RDS_CA_CERT }}" > internal/config/certs/rds-ca-2019-root.pem + - name: Run tests + env: + RDS_CA_CERT: ${{ secrets.RDS_CA_CERT }} run: make go-test-ci diff --git a/internal/pkg/db/db.go b/internal/pkg/db/db.go index 2ff80aa..8f8775f 100644 --- a/internal/pkg/db/db.go +++ b/internal/pkg/db/db.go @@ -117,6 +117,23 @@ func ResetConn() { } func getRootCertPath() string { + // For CI to run the unit tests. + if certContent := os.Getenv("RDS_CA_CERT"); certContent != "" { + tempDir, err := os.MkdirTemp("", "rds-cert") + if err != nil { + log.Error.Printf("Failed to create temp directory: %v", err) + return "" + } + + tempFile := filepath.Join(tempDir, "rds-ca-2019-root.pem") + if err := os.WriteFile(tempFile, []byte(certContent), 0600); err != nil { + log.Error.Printf("Failed to write cert content: %v", err) + return "" + } + + return tempFile + } + _, b, _, _ := runtime.Caller(0) basepath := filepath.Dir(b) return filepath.Join(basepath, "..", "..", "config", "certs", "rds-ca-2019-root.pem") diff --git a/internal/pkg/db/db_test.go b/internal/pkg/db/db_test.go index a34c221..db9d1bf 100644 --- a/internal/pkg/db/db_test.go +++ b/internal/pkg/db/db_test.go @@ -17,6 +17,11 @@ import ( func TestInitSuccess(t *testing.T) { t.Run("Success_AWS", func(t *testing.T) { + certPath := mydb.GetRootCertPath() + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Skip("Skipping AWS test: RDS certificate file not found") + } + connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec maxConn := config.SourceMaxConn maxIdle := config.SourceMaxIdle @@ -122,7 +127,7 @@ func TestInitSuccess(t *testing.T) { } err = mydb.Init(opt) - assert.EqualError(t, err, "dial tcp: lookup invalidhost: no such host") + assert.Contains(t, err.Error(), "dial tcp: lookup invalidhost") }) t.Run("Failure_InvalidPort", func(t *testing.T) { @@ -272,6 +277,9 @@ func TestGetConn(t *testing.T) { func TestVerifyCertificate(t *testing.T) { t.Run("Success", func(t *testing.T) { certPath := mydb.GetRootCertPath() + if _, err := os.Stat(certPath); os.IsNotExist(err) { + t.Skip("Skipping test: RDS certificate file not found") + } // Should be able to read the certificate file certData, err := os.ReadFile(certPath) From 7dea3716731ed37b6a5722ff3ee516ecce5d06ad Mon Sep 17 00:00:00 2001 From: Frankie Date: Tue, 16 Jul 2024 03:14:52 +0800 Subject: [PATCH 3/3] Update the db unit test logic. --- internal/pkg/db/db_test.go | 99 +++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/internal/pkg/db/db_test.go b/internal/pkg/db/db_test.go index db9d1bf..05c768f 100644 --- a/internal/pkg/db/db_test.go +++ b/internal/pkg/db/db_test.go @@ -16,55 +16,56 @@ import ( ) func TestInitSuccess(t *testing.T) { - t.Run("Success_AWS", func(t *testing.T) { - certPath := mydb.GetRootCertPath() - if _, err := os.Stat(certPath); os.IsNotExist(err) { - t.Skip("Skipping AWS test: RDS certificate file not found") - } - - connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec - maxConn := config.SourceMaxConn - maxIdle := config.SourceMaxIdle - log.Init(config.LogLevel, config.LogFolderPath, strconv.Itoa(os.Getpid()), config.EnableConsoleOutput, config.EnableFileOutput) - - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer db.Close() - - mock.ExpectPing() - - // Simulate the SSL query - rows := sqlmock.NewRows([]string{"ssl_is_used"}).AddRow(true) - mock.ExpectQuery("SELECT ssl_is_used()").WillReturnRows(rows) - - opt := &mydb.Option{ - Host: config.AWSSourceHost, - Port: config.SourcePort, - Username: config.AWSSourceUser, - Password: config.AWSSourcePassword, - DBName: config.AWSSourceDataBase, - SSLMode: config.AWSSourceSSLMode, - ConnectionMaxLifeTimeSec: &connectionMaxLifeTimeSec, - MaxConn: &maxConn, - MaxIdle: &maxIdle, - } - - err = mydb.Init(opt) - assert.NoError(t, err) - - // Checking the SSL status - var sslUsed bool - err = db.QueryRow("SELECT ssl_is_used()").Scan(&sslUsed) - assert.NoError(t, err) - assert.True(t, sslUsed, "SSL should be used") - - // We make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } - }) + // Performing this unit test may result in the exposure of AWS user information, which is only open to the local. + // t.Run("Success_AWS", func(t *testing.T) { + // certPath := mydb.GetRootCertPath() + // if _, err := os.Stat(certPath); os.IsNotExist(err) { + // t.Skip("Skipping AWS test: RDS certificate file not found") + // } + + // connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec + // maxConn := config.SourceMaxConn + // maxIdle := config.SourceMaxIdle + // log.Init(config.LogLevel, config.LogFolderPath, strconv.Itoa(os.Getpid()), config.EnableConsoleOutput, config.EnableFileOutput) + + // db, mock, err := sqlmock.New() + // if err != nil { + // t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + // } + // defer db.Close() + + // mock.ExpectPing() + + // // Simulate the SSL query + // rows := sqlmock.NewRows([]string{"ssl_is_used"}).AddRow(true) + // mock.ExpectQuery("SELECT ssl_is_used()").WillReturnRows(rows) + + // opt := &mydb.Option{ + // Host: config.AWSSourceHost, + // Port: config.SourcePort, + // Username: config.AWSSourceUser, + // Password: config.AWSSourcePassword, + // DBName: config.AWSSourceDataBase, + // SSLMode: config.AWSSourceSSLMode, + // ConnectionMaxLifeTimeSec: &connectionMaxLifeTimeSec, + // MaxConn: &maxConn, + // MaxIdle: &maxIdle, + // } + + // err = mydb.Init(opt) + // assert.NoError(t, err) + + // // Checking the SSL status + // var sslUsed bool + // err = db.QueryRow("SELECT ssl_is_used()").Scan(&sslUsed) + // assert.NoError(t, err) + // assert.True(t, sslUsed, "SSL should be used") + + // // We make sure that all expectations were met + // if err := mock.ExpectationsWereMet(); err != nil { + // t.Errorf("there were unfulfilled expectations: %s", err) + // } + // }) t.Run("Success_General", func(t *testing.T) { connectionMaxLifeTimeSec := config.SourceDBConnMaxLTSec