mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Move OpenPostgres
This commit is contained in:
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
if databaseURL != "" {
|
||||
pool, dbErr := db.OpenPostgres(databaseURL)
|
||||
pool, dbErr := pg.Open(databaseURL)
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
88
db/db.go
88
db/db.go
@@ -2,96 +2,8 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
const (
|
||||
paramMaxOpenConns = "pool_max_conns"
|
||||
paramMaxIdleConns = "pool_max_idle_conns"
|
||||
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||
|
||||
defaultMaxOpenConns = 10
|
||||
)
|
||||
|
||||
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||
// the DSN before passing it to the driver.
|
||||
func OpenPostgres(dsn string) (*sql.DB, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
db, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
if maxIdleConns > 0 {
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
if connMaxLifetime > 0 {
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
}
|
||||
if connMaxIdleTime > 0 {
|
||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
|
||||
93
db/pg/pg.go
Normal file
93
db/pg/pg.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
const (
|
||||
paramMaxOpenConns = "pool_max_conns"
|
||||
paramMaxIdleConns = "pool_max_idle_conns"
|
||||
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||
|
||||
defaultMaxOpenConns = 10
|
||||
)
|
||||
|
||||
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||
// the DSN before passing it to the driver.
|
||||
func Open(dsn string) (*sql.DB, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
db, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
if maxIdleConns > 0 {
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
if connMaxLifetime > 0 {
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
}
|
||||
if connMaxIdleTime > 0 {
|
||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
q.Set("pool_max_conns", testPoolMaxConns)
|
||||
u.RawQuery = q.Encode()
|
||||
dsn = u.String()
|
||||
setupDB, err := db.OpenPostgres(dsn)
|
||||
setupDB, err := pg.Open(dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
u.RawQuery = q.Encode()
|
||||
schemaDSN := u.String()
|
||||
t.Cleanup(func() {
|
||||
cleanDB, err := db.OpenPostgres(dsn)
|
||||
cleanDB, err := pg.Open(dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
func CreateTestPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
schemaDSN := CreateTestPostgresSchema(t)
|
||||
testDB, err := db.OpenPostgres(schemaDSN)
|
||||
testDB, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
testDB.Close()
|
||||
|
||||
@@ -33,7 +33,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
|
||||
if payments.Available && conf.StripeSecretKey != "" {
|
||||
stripe = newStripeAPI()
|
||||
}
|
||||
// OpenPostgres shared PostgreSQL connection pool if configured
|
||||
// Open shared PostgreSQL connection pool if configured
|
||||
var pool *sql.DB
|
||||
if conf.DatabaseURL != "" {
|
||||
var err error
|
||||
pool, err = db.OpenPostgres(conf.DatabaseURL)
|
||||
pool, err = pg.Open(conf.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -79,7 +79,7 @@ func execImport(c *cli.Context) error {
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
pgDB, err := db.OpenPostgres(databaseURL)
|
||||
pgDB, err := pg.Open(databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
dbtest "heckel.io/ntfy/v2/db/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
|
||||
t.Run("postgres", func(t *testing.T) {
|
||||
schemaDSN := dbtest.CreateTestPostgresSchema(t)
|
||||
f(t, func(config *Config) *Manager {
|
||||
pool, err := db.OpenPostgres(schemaDSN)
|
||||
pool, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
a, err := NewPostgresManager(pool, config)
|
||||
require.Nil(t, err)
|
||||
|
||||
Reference in New Issue
Block a user