Move OpenPostgres

This commit is contained in:
binwiederhier
2026-03-02 19:52:36 -05:00
parent ea4739f79b
commit 8afeb813d9
7 changed files with 106 additions and 101 deletions

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
}
if databaseURL != "" {
pool, dbErr := db.OpenPostgres(databaseURL)
pool, dbErr := pg.Open(databaseURL)
if dbErr != nil {
return nil, dbErr
}

View File

@@ -2,96 +2,8 @@ package db
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func OpenPostgres(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {

93
db/pg/pg.go Normal file
View File

@@ -0,0 +1,93 @@
package pg
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func Open(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/util"
)
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode()
dsn = u.String()
setupDB, err := db.OpenPostgres(dsn)
setupDB, err := pg.Open(dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanDB, err := db.OpenPostgres(dsn)
cleanDB, err := pg.Open(dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
func CreateTestPostgres(t *testing.T) *sql.DB {
t.Helper()
schemaDSN := CreateTestPostgresSchema(t)
testDB, err := db.OpenPostgres(schemaDSN)
testDB, err := pg.Open(schemaDSN)
require.Nil(t, err)
t.Cleanup(func() {
testDB.Close()

View File

@@ -33,7 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
if payments.Available && conf.StripeSecretKey != "" {
stripe = newStripeAPI()
}
// OpenPostgres shared PostgreSQL connection pool if configured
// Open shared PostgreSQL connection pool if configured
var pool *sql.DB
if conf.DatabaseURL != "" {
var err error
pool, err = db.OpenPostgres(conf.DatabaseURL)
pool, err = pg.Open(conf.DatabaseURL)
if err != nil {
return nil, err
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
)
const (
@@ -79,7 +79,7 @@ func execImport(c *cli.Context) error {
}
fmt.Println()
pgDB, err := db.OpenPostgres(databaseURL)
pgDB, err := pg.Open(databaseURL)
if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/util"
)
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
t.Run("postgres", func(t *testing.T) {
schemaDSN := dbtest.CreateTestPostgresSchema(t)
f(t, func(config *Config) *Manager {
pool, err := db.OpenPostgres(schemaDSN)
pool, err := pg.Open(schemaDSN)
require.Nil(t, err)
a, err := NewPostgresManager(pool, config)
require.Nil(t, err)