diff --git a/cmd/user.go b/cmd/user.go index e0041151..1ffc3e6b 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -11,7 +11,7 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" - "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/server" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" @@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, } if databaseURL != "" { - pool, dbErr := db.OpenPostgres(databaseURL) + pool, dbErr := pg.Open(databaseURL) if dbErr != nil { return nil, dbErr } diff --git a/db/db.go b/db/db.go index 3b6585c1..00ae91f4 100644 --- a/db/db.go +++ b/db/db.go @@ -2,96 +2,8 @@ package db import ( "database/sql" - "fmt" - "net/url" - "strconv" - "time" - - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) -const ( - paramMaxOpenConns = "pool_max_conns" - paramMaxIdleConns = "pool_max_idle_conns" - paramConnMaxLifetime = "pool_conn_max_lifetime" - paramConnMaxIdleTime = "pool_conn_max_idle_time" - - defaultMaxOpenConns = 10 -) - -// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom -// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns, -// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from -// the DSN before passing it to the driver. -func OpenPostgres(dsn string) (*sql.DB, error) { - u, err := url.Parse(dsn) - if err != nil { - return nil, fmt.Errorf("invalid database URL: %w", err) - } - q := u.Query() - maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns) - if err != nil { - return nil, err - } - maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0) - if err != nil { - return nil, err - } - connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0) - if err != nil { - return nil, err - } - connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0) - if err != nil { - return nil, err - } - u.RawQuery = q.Encode() - db, err := sql.Open("pgx", u.String()) - if err != nil { - return nil, err - } - db.SetMaxOpenConns(maxOpenConns) - if maxIdleConns > 0 { - db.SetMaxIdleConns(maxIdleConns) - } - if connMaxLifetime > 0 { - db.SetConnMaxLifetime(connMaxLifetime) - } - if connMaxIdleTime > 0 { - db.SetConnMaxIdleTime(connMaxIdleTime) - } - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("ping failed: %w", err) - } - return db, nil -} - -func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { - s := q.Get(key) - if s == "" { - return defaultValue, nil - } - q.Del(key) - v, err := strconv.Atoi(s) - if err != nil { - return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) - } - return v, nil -} - -func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) { - s := q.Get(key) - if s == "" { - return defaultValue, nil - } - q.Del(key) - d, err := time.ParseDuration(s) - if err != nil { - return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) - } - return d, nil -} - // ExecTx executes a function within a database transaction. If the function returns an error, // the transaction is rolled back. Otherwise, the transaction is committed. func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error { diff --git a/db/pg/pg.go b/db/pg/pg.go new file mode 100644 index 00000000..99f802b3 --- /dev/null +++ b/db/pg/pg.go @@ -0,0 +1,93 @@ +package pg + +import ( + "database/sql" + "fmt" + "net/url" + "strconv" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver +) + +const ( + paramMaxOpenConns = "pool_max_conns" + paramMaxIdleConns = "pool_max_idle_conns" + paramConnMaxLifetime = "pool_conn_max_lifetime" + paramConnMaxIdleTime = "pool_conn_max_idle_time" + + defaultMaxOpenConns = 10 +) + +// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom +// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns, +// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from +// the DSN before passing it to the driver. +func Open(dsn string) (*sql.DB, error) { + u, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf("invalid database URL: %w", err) + } + q := u.Query() + maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns) + if err != nil { + return nil, err + } + maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0) + if err != nil { + return nil, err + } + connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0) + if err != nil { + return nil, err + } + connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0) + if err != nil { + return nil, err + } + u.RawQuery = q.Encode() + db, err := sql.Open("pgx", u.String()) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(maxOpenConns) + if maxIdleConns > 0 { + db.SetMaxIdleConns(maxIdleConns) + } + if connMaxLifetime > 0 { + db.SetConnMaxLifetime(connMaxLifetime) + } + if connMaxIdleTime > 0 { + db.SetConnMaxIdleTime(connMaxIdleTime) + } + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("ping failed: %w", err) + } + return db, nil +} + +func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { + s := q.Get(key) + if s == "" { + return defaultValue, nil + } + q.Del(key) + v, err := strconv.Atoi(s) + if err != nil { + return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) + } + return v, nil +} + +func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) { + s := q.Get(key) + if s == "" { + return defaultValue, nil + } + q.Del(key) + d, err := time.ParseDuration(s) + if err != nil { + return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) + } + return d, nil +} diff --git a/db/test/test.go b/db/test/test.go index d843becb..36c3fc86 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/util" ) @@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string { q.Set("pool_max_conns", testPoolMaxConns) u.RawQuery = q.Encode() dsn = u.String() - setupDB, err := db.OpenPostgres(dsn) + setupDB, err := pg.Open(dsn) require.Nil(t, err) _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) require.Nil(t, err) @@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string { u.RawQuery = q.Encode() schemaDSN := u.String() t.Cleanup(func() { - cleanDB, err := db.OpenPostgres(dsn) + cleanDB, err := pg.Open(dsn) if err == nil { cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanDB.Close() @@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string { func CreateTestPostgres(t *testing.T) *sql.DB { t.Helper() schemaDSN := CreateTestPostgresSchema(t) - testDB, err := db.OpenPostgres(schemaDSN) + testDB, err := pg.Open(schemaDSN) require.Nil(t, err) t.Cleanup(func() { testDB.Close() diff --git a/server/server.go b/server/server.go index 2c09b56a..329b0ab5 100644 --- a/server/server.go +++ b/server/server.go @@ -33,7 +33,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" - "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" @@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) { if payments.Available && conf.StripeSecretKey != "" { stripe = newStripeAPI() } - // OpenPostgres shared PostgreSQL connection pool if configured + // Open shared PostgreSQL connection pool if configured var pool *sql.DB if conf.DatabaseURL != "" { var err error - pool, err = db.OpenPostgres(conf.DatabaseURL) + pool, err = pg.Open(conf.DatabaseURL) if err != nil { return nil, err } diff --git a/tools/pgimport/main.go b/tools/pgimport/main.go index 9637d33d..9af2212d 100644 --- a/tools/pgimport/main.go +++ b/tools/pgimport/main.go @@ -11,7 +11,7 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "gopkg.in/yaml.v2" - "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db/pg" ) const ( @@ -79,7 +79,7 @@ func execImport(c *cli.Context) error { } fmt.Println() - pgDB, err := db.OpenPostgres(databaseURL) + pgDB, err := pg.Open(databaseURL) if err != nil { return fmt.Errorf("cannot connect to PostgreSQL: %w", err) } diff --git a/user/manager_test.go b/user/manager_test.go index fd1e59fc..c353acb8 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" - "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db/pg" dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/util" ) @@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc t.Run("postgres", func(t *testing.T) { schemaDSN := dbtest.CreateTestPostgresSchema(t) f(t, func(config *Config) *Manager { - pool, err := db.OpenPostgres(schemaDSN) + pool, err := pg.Open(schemaDSN) require.Nil(t, err) a, err := NewPostgresManager(pool, config) require.Nil(t, err)