Move OpenPostgres

This commit is contained in:
binwiederhier
2026-03-02 19:52:36 -05:00
parent ea4739f79b
commit 8afeb813d9
7 changed files with 106 additions and 101 deletions

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/server" "heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
} }
if databaseURL != "" { if databaseURL != "" {
pool, dbErr := db.OpenPostgres(databaseURL) pool, dbErr := pg.Open(databaseURL)
if dbErr != nil { if dbErr != nil {
return nil, dbErr return nil, dbErr
} }

View File

@@ -2,96 +2,8 @@ package db
import ( import (
"database/sql" "database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
) )
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func OpenPostgres(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}
// ExecTx executes a function within a database transaction. If the function returns an error, // ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed. // the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error { func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {

93
db/pg/pg.go Normal file
View File

@@ -0,0 +1,93 @@
package pg
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func Open(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
) )
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
q.Set("pool_max_conns", testPoolMaxConns) q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
dsn = u.String() dsn = u.String()
setupDB, err := db.OpenPostgres(dsn) setupDB, err := pg.Open(dsn)
require.Nil(t, err) require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err) require.Nil(t, err)
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
schemaDSN := u.String() schemaDSN := u.String()
t.Cleanup(func() { t.Cleanup(func() {
cleanDB, err := db.OpenPostgres(dsn) cleanDB, err := pg.Open(dsn)
if err == nil { if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close() cleanDB.Close()
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
func CreateTestPostgres(t *testing.T) *sql.DB { func CreateTestPostgres(t *testing.T) *sql.DB {
t.Helper() t.Helper()
schemaDSN := CreateTestPostgresSchema(t) schemaDSN := CreateTestPostgresSchema(t)
testDB, err := db.OpenPostgres(schemaDSN) testDB, err := pg.Open(schemaDSN)
require.Nil(t, err) require.Nil(t, err)
t.Cleanup(func() { t.Cleanup(func() {
testDB.Close() testDB.Close()

View File

@@ -33,7 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/model"
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
if payments.Available && conf.StripeSecretKey != "" { if payments.Available && conf.StripeSecretKey != "" {
stripe = newStripeAPI() stripe = newStripeAPI()
} }
// OpenPostgres shared PostgreSQL connection pool if configured // Open shared PostgreSQL connection pool if configured
var pool *sql.DB var pool *sql.DB
if conf.DatabaseURL != "" { if conf.DatabaseURL != "" {
var err error var err error
pool, err = db.OpenPostgres(conf.DatabaseURL) pool, err = pg.Open(conf.DatabaseURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg"
) )
const ( const (
@@ -79,7 +79,7 @@ func execImport(c *cli.Context) error {
} }
fmt.Println() fmt.Println()
pgDB, err := db.OpenPostgres(databaseURL) pgDB, err := pg.Open(databaseURL)
if err != nil { if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err) return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
} }

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg"
dbtest "heckel.io/ntfy/v2/db/test" dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
) )
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
t.Run("postgres", func(t *testing.T) { t.Run("postgres", func(t *testing.T) {
schemaDSN := dbtest.CreateTestPostgresSchema(t) schemaDSN := dbtest.CreateTestPostgresSchema(t)
f(t, func(config *Config) *Manager { f(t, func(config *Config) *Manager {
pool, err := db.OpenPostgres(schemaDSN) pool, err := pg.Open(schemaDSN)
require.Nil(t, err) require.Nil(t, err)
a, err := NewPostgresManager(pool, config) a, err := NewPostgresManager(pool, config)
require.Nil(t, err) require.Nil(t, err)