From 7fbde5e627bfe61657d636058863f88838dfc0ba Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Sun, 8 Dec 2024 20:14:44 -0500 Subject: [PATCH] always close db connection on error in Open This commit fixes a bug where the db connection would not always be closed when an error occurred in Open. Additionally, it makes sure that we always unregister any callbacks associated with the connection, which previously did not always happen. This change consolidates the error handling logic (previously, it had to be done for each return statement) which should make it more difficult to introduce this type of bug in the future. --- sqlite3.go | 500 ++++++++++++++++++++++++------------------------ sqlite3_test.go | 16 ++ 2 files changed, 269 insertions(+), 247 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 3025a500..d886ee9d 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1487,296 +1487,302 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil } - // Busy timeout - if err := exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeout)); err != nil { - C.sqlite3_close_v2(db) - return nil, err - } - - // USER AUTHENTICATION - // - // User Authentication is always performed even when - // sqlite_userauth is not compiled in, because without user authentication - // the authentication is a no-op. - // - // Workflow - // - Authenticate - // ON::SUCCESS => Continue - // ON::SQLITE_AUTH => Return error and exit Open(...) - // - // - Activate User Authentication - // Check if the user wants to activate User Authentication. - // If so then first create a temporary AuthConn to the database - // This is possible because we are already successfully authenticated. - // - // - Check if `sqlite_user`` table exists - // YES => Add the provided user from DSN as Admin User and - // activate user authentication. - // NO => Continue - // - // Create connection to SQLite conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} - // Password Cipher has to be registered before authentication - if len(authCrypt) > 0 { - switch strings.ToUpper(authCrypt) { - case "SHA1": - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA1, true); err != nil { - return nil, fmt.Errorf("CryptEncoderSHA1: %s", err) - } - case "SSHA1": - if len(authSalt) == 0 { - return nil, fmt.Errorf("_auth_crypt=ssha1, requires _auth_salt") - } - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA1(authSalt), true); err != nil { - return nil, fmt.Errorf("CryptEncoderSSHA1: %s", err) - } - case "SHA256": - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA256, true); err != nil { - return nil, fmt.Errorf("CryptEncoderSHA256: %s", err) - } - case "SSHA256": - if len(authSalt) == 0 { - return nil, fmt.Errorf("_auth_crypt=ssha256, requires _auth_salt") - } - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA256(authSalt), true); err != nil { - return nil, fmt.Errorf("CryptEncoderSSHA256: %s", err) - } - case "SHA384": - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA384, true); err != nil { - return nil, fmt.Errorf("CryptEncoderSHA384: %s", err) - } - case "SSHA384": - if len(authSalt) == 0 { - return nil, fmt.Errorf("_auth_crypt=ssha384, requires _auth_salt") - } - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA384(authSalt), true); err != nil { - return nil, fmt.Errorf("CryptEncoderSSHA384: %s", err) - } - case "SHA512": - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA512, true); err != nil { - return nil, fmt.Errorf("CryptEncoderSHA512: %s", err) - } - case "SSHA512": - if len(authSalt) == 0 { - return nil, fmt.Errorf("_auth_crypt=ssha512, requires _auth_salt") - } - if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA512(authSalt), true); err != nil { - return nil, fmt.Errorf("CryptEncoderSSHA512: %s", err) + // All further database configuration *must* occur within this function, as + // it simplifies closing the database connection and unregistering associated + // functions or callbacks. + // + // Previously, this was handled at each return statement, which was error-prone + // and led to bugs because the database connection was not always closed. + err := func() error { + // Remove the externally managed DB conn in the event of a panic. + defer func() { + if e := recover(); e != nil { + _ = conn.Close() + panic(e) } + }() + + // Busy timeout + if err := exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeout)); err != nil { + return err } - } - // Preform Authentication - if err := conn.Authenticate(authUser, authPass); err != nil { - return nil, err - } + // USER AUTHENTICATION + // + // User Authentication is always performed even when + // sqlite_userauth is not compiled in, because without user authentication + // the authentication is a no-op. + // + // Workflow + // - Authenticate + // ON::SUCCESS => Continue + // ON::SQLITE_AUTH => Return error and exit Open(...) + // + // - Activate User Authentication + // Check if the user wants to activate User Authentication. + // If so then first create a temporary AuthConn to the database + // This is possible because we are already successfully authenticated. + // + // - Check if `sqlite_user`` table exists + // YES => Add the provided user from DSN as Admin User and + // activate user authentication. + // NO => Continue + // - // Register: authenticate - // Authenticate will perform an authentication of the provided username - // and password against the database. - // - // If a database contains the SQLITE_USER table, then the - // call to Authenticate must be invoked with an - // appropriate username and password prior to enable read and write - //access to the database. - // - // Return SQLITE_OK on success or SQLITE_ERROR if the username/password - // combination is incorrect or unknown. - // - // If the SQLITE_USER table is not present in the database file, then - // this interface is a harmless no-op returnning SQLITE_OK. - if err := conn.RegisterFunc("authenticate", conn.authenticate, true); err != nil { - return nil, err - } - // - // Register: auth_user_add - // auth_user_add can be used (by an admin user only) - // to create a new user. When called on a no-authentication-required - // database, this routine converts the database into an authentication- - // required database, automatically makes the added user an - // administrator, and logs in the current connection as that user. - // The AuthUserAdd only works for the "main" database, not - // for any ATTACH-ed databases. Any call to AuthUserAdd by a - // non-admin user results in an error. - if err := conn.RegisterFunc("auth_user_add", conn.authUserAdd, true); err != nil { - return nil, err - } - // - // Register: auth_user_change - // auth_user_change can be used to change a users - // login credentials or admin privilege. Any user can change their own - // login credentials. Only an admin user can change another users login - // credentials or admin privilege setting. No user may change their own - // admin privilege setting. - if err := conn.RegisterFunc("auth_user_change", conn.authUserChange, true); err != nil { - return nil, err - } - // - // Register: auth_user_delete - // auth_user_delete can be used (by an admin user only) - // to delete a user. The currently logged-in user cannot be deleted, - // which guarantees that there is always an admin user and hence that - // the database cannot be converted into a no-authentication-required - // database. - if err := conn.RegisterFunc("auth_user_delete", conn.authUserDelete, true); err != nil { - return nil, err - } + // Password Cipher has to be registered before authentication + if len(authCrypt) > 0 { + switch strings.ToUpper(authCrypt) { + case "SHA1": + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA1, true); err != nil { + return fmt.Errorf("CryptEncoderSHA1: %s", err) + } + case "SSHA1": + if len(authSalt) == 0 { + return fmt.Errorf("_auth_crypt=ssha1, requires _auth_salt") + } + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA1(authSalt), true); err != nil { + return fmt.Errorf("CryptEncoderSSHA1: %s", err) + } + case "SHA256": + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA256, true); err != nil { + return fmt.Errorf("CryptEncoderSHA256: %s", err) + } + case "SSHA256": + if len(authSalt) == 0 { + return fmt.Errorf("_auth_crypt=ssha256, requires _auth_salt") + } + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA256(authSalt), true); err != nil { + return fmt.Errorf("CryptEncoderSSHA256: %s", err) + } + case "SHA384": + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA384, true); err != nil { + return fmt.Errorf("CryptEncoderSHA384: %s", err) + } + case "SSHA384": + if len(authSalt) == 0 { + return fmt.Errorf("_auth_crypt=ssha384, requires _auth_salt") + } + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA384(authSalt), true); err != nil { + return fmt.Errorf("CryptEncoderSSHA384: %s", err) + } + case "SHA512": + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSHA512, true); err != nil { + return fmt.Errorf("CryptEncoderSHA512: %s", err) + } + case "SSHA512": + if len(authSalt) == 0 { + return fmt.Errorf("_auth_crypt=ssha512, requires _auth_salt") + } + if err := conn.RegisterFunc("sqlite_crypt", CryptEncoderSSHA512(authSalt), true); err != nil { + return fmt.Errorf("CryptEncoderSSHA512: %s", err) + } + } + } - // Register: auth_enabled - // auth_enabled can be used to check if user authentication is enabled - if err := conn.RegisterFunc("auth_enabled", conn.authEnabled, true); err != nil { - return nil, err - } + // Preform Authentication + if err := conn.Authenticate(authUser, authPass); err != nil { + return err + } - // Auto Vacuum - // Moved auto_vacuum command, the user preference for auto_vacuum needs to be implemented directly after - // the authentication and before the sqlite_user table gets created if the user - // decides to activate User Authentication because - // auto_vacuum needs to be set before any tables are created - // and activating user authentication creates the internal table `sqlite_user`. - if autoVacuum > -1 { - if err := exec(fmt.Sprintf("PRAGMA auto_vacuum = %d;", autoVacuum)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Register: authenticate + // Authenticate will perform an authentication of the provided username + // and password against the database. + // + // If a database contains the SQLITE_USER table, then the + // call to Authenticate must be invoked with an + // appropriate username and password prior to enable read and write + //access to the database. + // + // Return SQLITE_OK on success or SQLITE_ERROR if the username/password + // combination is incorrect or unknown. + // + // If the SQLITE_USER table is not present in the database file, then + // this interface is a harmless no-op returnning SQLITE_OK. + if err := conn.RegisterFunc("authenticate", conn.authenticate, true); err != nil { + return err + } + // + // Register: auth_user_add + // auth_user_add can be used (by an admin user only) + // to create a new user. When called on a no-authentication-required + // database, this routine converts the database into an authentication- + // required database, automatically makes the added user an + // administrator, and logs in the current connection as that user. + // The AuthUserAdd only works for the "main" database, not + // for any ATTACH-ed databases. Any call to AuthUserAdd by a + // non-admin user results in an error. + if err := conn.RegisterFunc("auth_user_add", conn.authUserAdd, true); err != nil { + return err + } + // + // Register: auth_user_change + // auth_user_change can be used to change a users + // login credentials or admin privilege. Any user can change their own + // login credentials. Only an admin user can change another users login + // credentials or admin privilege setting. No user may change their own + // admin privilege setting. + if err := conn.RegisterFunc("auth_user_change", conn.authUserChange, true); err != nil { + return err + } + // + // Register: auth_user_delete + // auth_user_delete can be used (by an admin user only) + // to delete a user. The currently logged-in user cannot be deleted, + // which guarantees that there is always an admin user and hence that + // the database cannot be converted into a no-authentication-required + // database. + if err := conn.RegisterFunc("auth_user_delete", conn.authUserDelete, true); err != nil { + return err } - } - // Check if user wants to activate User Authentication - if authCreate { - // Before going any further, we need to check that the user - // has provided an username and password within the DSN. - // We are not allowed to continue. - if len(authUser) == 0 { - return nil, fmt.Errorf("Missing '_auth_user' while user authentication was requested with '_auth'") + // Register: auth_enabled + // auth_enabled can be used to check if user authentication is enabled + if err := conn.RegisterFunc("auth_enabled", conn.authEnabled, true); err != nil { + return err } - if len(authPass) == 0 { - return nil, fmt.Errorf("Missing '_auth_pass' while user authentication was requested with '_auth'") + + // Auto Vacuum + // Moved auto_vacuum command, the user preference for auto_vacuum needs to be implemented directly after + // the authentication and before the sqlite_user table gets created if the user + // decides to activate User Authentication because + // auto_vacuum needs to be set before any tables are created + // and activating user authentication creates the internal table `sqlite_user`. + if autoVacuum > -1 { + if err := exec(fmt.Sprintf("PRAGMA auto_vacuum = %d;", autoVacuum)); err != nil { + return err + } } - // Check if User Authentication is Enabled - authExists := conn.AuthEnabled() - if !authExists { - if err := conn.AuthUserAdd(authUser, authPass, true); err != nil { - return nil, err + // Check if user wants to activate User Authentication + if authCreate { + // Before going any further, we need to check that the user + // has provided an username and password within the DSN. + // We are not allowed to continue. + if len(authUser) == 0 { + return fmt.Errorf("Missing '_auth_user' while user authentication was requested with '_auth'") + } + if len(authPass) == 0 { + return fmt.Errorf("Missing '_auth_pass' while user authentication was requested with '_auth'") + } + + // Check if User Authentication is Enabled + authExists := conn.AuthEnabled() + if !authExists { + if err := conn.AuthUserAdd(authUser, authPass, true); err != nil { + return err + } } } - } - // Case Sensitive LIKE - if caseSensitiveLike > -1 { - if err := exec(fmt.Sprintf("PRAGMA case_sensitive_like = %d;", caseSensitiveLike)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Case Sensitive LIKE + if caseSensitiveLike > -1 { + if err := exec(fmt.Sprintf("PRAGMA case_sensitive_like = %d;", caseSensitiveLike)); err != nil { + return err + } } - } - // Defer Foreign Keys - if deferForeignKeys > -1 { - if err := exec(fmt.Sprintf("PRAGMA defer_foreign_keys = %d;", deferForeignKeys)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Defer Foreign Keys + if deferForeignKeys > -1 { + if err := exec(fmt.Sprintf("PRAGMA defer_foreign_keys = %d;", deferForeignKeys)); err != nil { + return err + } } - } - // Foreign Keys - if foreignKeys > -1 { - if err := exec(fmt.Sprintf("PRAGMA foreign_keys = %d;", foreignKeys)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Foreign Keys + if foreignKeys > -1 { + if err := exec(fmt.Sprintf("PRAGMA foreign_keys = %d;", foreignKeys)); err != nil { + return err + } } - } - // Ignore CHECK Constraints - if ignoreCheckConstraints > -1 { - if err := exec(fmt.Sprintf("PRAGMA ignore_check_constraints = %d;", ignoreCheckConstraints)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Ignore CHECK Constraints + if ignoreCheckConstraints > -1 { + if err := exec(fmt.Sprintf("PRAGMA ignore_check_constraints = %d;", ignoreCheckConstraints)); err != nil { + return err + } } - } - // Journal Mode - if journalMode != "" { - if err := exec(fmt.Sprintf("PRAGMA journal_mode = %s;", journalMode)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Journal Mode + if journalMode != "" { + if err := exec(fmt.Sprintf("PRAGMA journal_mode = %s;", journalMode)); err != nil { + return err + } } - } - // Locking Mode - // Because the default is NORMAL and this is not changed in this package - // by using the compile time SQLITE_DEFAULT_LOCKING_MODE this PRAGMA can always be executed - if err := exec(fmt.Sprintf("PRAGMA locking_mode = %s;", lockingMode)); err != nil { - C.sqlite3_close_v2(db) - return nil, err - } + // Locking Mode + // Because the default is NORMAL and this is not changed in this package + // by using the compile time SQLITE_DEFAULT_LOCKING_MODE this PRAGMA can always be executed + if err := exec(fmt.Sprintf("PRAGMA locking_mode = %s;", lockingMode)); err != nil { + return err + } - // Query Only - if queryOnly > -1 { - if err := exec(fmt.Sprintf("PRAGMA query_only = %d;", queryOnly)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Query Only + if queryOnly > -1 { + if err := exec(fmt.Sprintf("PRAGMA query_only = %d;", queryOnly)); err != nil { + return err + } } - } - // Recursive Triggers - if recursiveTriggers > -1 { - if err := exec(fmt.Sprintf("PRAGMA recursive_triggers = %d;", recursiveTriggers)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Recursive Triggers + if recursiveTriggers > -1 { + if err := exec(fmt.Sprintf("PRAGMA recursive_triggers = %d;", recursiveTriggers)); err != nil { + return err + } } - } - // Secure Delete - // - // Because this package can set the compile time flag SQLITE_SECURE_DELETE with a build tag - // the default value for secureDelete var is 'DEFAULT' this way - // you can compile with secure_delete 'ON' and disable it for a specific database connection. - if secureDelete != "DEFAULT" { - if err := exec(fmt.Sprintf("PRAGMA secure_delete = %s;", secureDelete)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Secure Delete + // + // Because this package can set the compile time flag SQLITE_SECURE_DELETE with a build tag + // the default value for secureDelete var is 'DEFAULT' this way + // you can compile with secure_delete 'ON' and disable it for a specific database connection. + if secureDelete != "DEFAULT" { + if err := exec(fmt.Sprintf("PRAGMA secure_delete = %s;", secureDelete)); err != nil { + return err + } } - } - // Synchronous Mode - // - // Because default is NORMAL this statement is always executed - if err := exec(fmt.Sprintf("PRAGMA synchronous = %s;", synchronousMode)); err != nil { - conn.Close() - return nil, err - } + // Synchronous Mode + // + // Because default is NORMAL this statement is always executed + if err := exec(fmt.Sprintf("PRAGMA synchronous = %s;", synchronousMode)); err != nil { + return err + } - // Writable Schema - if writableSchema > -1 { - if err := exec(fmt.Sprintf("PRAGMA writable_schema = %d;", writableSchema)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Writable Schema + if writableSchema > -1 { + if err := exec(fmt.Sprintf("PRAGMA writable_schema = %d;", writableSchema)); err != nil { + return err + } } - } - // Cache Size - if cacheSize != nil { - if err := exec(fmt.Sprintf("PRAGMA cache_size = %d;", *cacheSize)); err != nil { - C.sqlite3_close_v2(db) - return nil, err + // Cache Size + if cacheSize != nil { + if err := exec(fmt.Sprintf("PRAGMA cache_size = %d;", *cacheSize)); err != nil { + return err + } } - } - if len(d.Extensions) > 0 { - if err := conn.loadExtensions(d.Extensions); err != nil { - conn.Close() - return nil, err + if len(d.Extensions) > 0 { + if err := conn.loadExtensions(d.Extensions); err != nil { + return err + } } - } - if d.ConnectHook != nil { - if err := d.ConnectHook(conn); err != nil { - conn.Close() - return nil, err + if d.ConnectHook != nil { + if err := d.ConnectHook(conn); err != nil { + return err + } } + return nil + }() + if err != nil { + _ = conn.Close() + return nil, err } + runtime.SetFinalizer(conn, (*SQLiteConn).Close) return conn, nil } diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..f38c92af 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -2112,6 +2112,7 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, {Name: "BenchmarkQueryParallel", F: benchmarkQueryParallel}, + {Name: "BenchmarkOpen", F: benchmarkOpen}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2586,3 +2587,18 @@ func benchmarkQueryParallel(b *testing.B) { } }) } + +func benchmarkOpen(b *testing.B) { + for i := 0; i < b.N; i++ { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + b.Fatal(err) + } + if err := db.Ping(); err != nil { + b.Fatal(err) + } + if err := db.Close(); err != nil { + b.Fatal(err) + } + } +}