mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Move OpenPostgres
This commit is contained in:
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
"heckel.io/ntfy/v2/db"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/server"
|
"heckel.io/ntfy/v2/server"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
if databaseURL != "" {
|
if databaseURL != "" {
|
||||||
pool, dbErr := db.OpenPostgres(databaseURL)
|
pool, dbErr := pg.Open(databaseURL)
|
||||||
if dbErr != nil {
|
if dbErr != nil {
|
||||||
return nil, dbErr
|
return nil, dbErr
|
||||||
}
|
}
|
||||||
|
|||||||
88
db/db.go
88
db/db.go
@@ -2,96 +2,8 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
paramMaxOpenConns = "pool_max_conns"
|
|
||||||
paramMaxIdleConns = "pool_max_idle_conns"
|
|
||||||
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
|
||||||
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
|
||||||
|
|
||||||
defaultMaxOpenConns = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
|
||||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
|
||||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
|
||||||
// the DSN before passing it to the driver.
|
|
||||||
func OpenPostgres(dsn string) (*sql.DB, error) {
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
|
||||||
}
|
|
||||||
q := u.Query()
|
|
||||||
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
db, err := sql.Open("pgx", u.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
db.SetMaxOpenConns(maxOpenConns)
|
|
||||||
if maxIdleConns > 0 {
|
|
||||||
db.SetMaxIdleConns(maxIdleConns)
|
|
||||||
}
|
|
||||||
if connMaxLifetime > 0 {
|
|
||||||
db.SetConnMaxLifetime(connMaxLifetime)
|
|
||||||
}
|
|
||||||
if connMaxIdleTime > 0 {
|
|
||||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
|
||||||
}
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, fmt.Errorf("ping failed: %w", err)
|
|
||||||
}
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
|
||||||
s := q.Get(key)
|
|
||||||
if s == "" {
|
|
||||||
return defaultValue, nil
|
|
||||||
}
|
|
||||||
q.Del(key)
|
|
||||||
v, err := strconv.Atoi(s)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
|
||||||
}
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
|
||||||
s := q.Get(key)
|
|
||||||
if s == "" {
|
|
||||||
return defaultValue, nil
|
|
||||||
}
|
|
||||||
q.Del(key)
|
|
||||||
d, err := time.ParseDuration(s)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||||
|
|||||||
93
db/pg/pg.go
Normal file
93
db/pg/pg.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package pg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
paramMaxOpenConns = "pool_max_conns"
|
||||||
|
paramMaxIdleConns = "pool_max_idle_conns"
|
||||||
|
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||||
|
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||||
|
|
||||||
|
defaultMaxOpenConns = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||||
|
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||||
|
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||||
|
// the DSN before passing it to the driver.
|
||||||
|
func Open(dsn string) (*sql.DB, error) {
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
db, err := sql.Open("pgx", u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.SetMaxOpenConns(maxOpenConns)
|
||||||
|
if maxIdleConns > 0 {
|
||||||
|
db.SetMaxIdleConns(maxIdleConns)
|
||||||
|
}
|
||||||
|
if connMaxLifetime > 0 {
|
||||||
|
db.SetConnMaxLifetime(connMaxLifetime)
|
||||||
|
}
|
||||||
|
if connMaxIdleTime > 0 {
|
||||||
|
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("ping failed: %w", err)
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
v, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/v2/db"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
|||||||
q.Set("pool_max_conns", testPoolMaxConns)
|
q.Set("pool_max_conns", testPoolMaxConns)
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
dsn = u.String()
|
dsn = u.String()
|
||||||
setupDB, err := db.OpenPostgres(dsn)
|
setupDB, err := pg.Open(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
|||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
schemaDSN := u.String()
|
schemaDSN := u.String()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
cleanDB, err := db.OpenPostgres(dsn)
|
cleanDB, err := pg.Open(dsn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
cleanDB.Close()
|
cleanDB.Close()
|
||||||
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
|||||||
func CreateTestPostgres(t *testing.T) *sql.DB {
|
func CreateTestPostgres(t *testing.T) *sql.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
schemaDSN := CreateTestPostgresSchema(t)
|
schemaDSN := CreateTestPostgresSchema(t)
|
||||||
testDB, err := db.OpenPostgres(schemaDSN)
|
testDB, err := pg.Open(schemaDSN)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
testDB.Close()
|
testDB.Close()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
"heckel.io/ntfy/v2/db"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
|
|||||||
if payments.Available && conf.StripeSecretKey != "" {
|
if payments.Available && conf.StripeSecretKey != "" {
|
||||||
stripe = newStripeAPI()
|
stripe = newStripeAPI()
|
||||||
}
|
}
|
||||||
// OpenPostgres shared PostgreSQL connection pool if configured
|
// Open shared PostgreSQL connection pool if configured
|
||||||
var pool *sql.DB
|
var pool *sql.DB
|
||||||
if conf.DatabaseURL != "" {
|
if conf.DatabaseURL != "" {
|
||||||
var err error
|
var err error
|
||||||
pool, err = db.OpenPostgres(conf.DatabaseURL)
|
pool, err = pg.Open(conf.DatabaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
"heckel.io/ntfy/v2/db"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -79,7 +79,7 @@ func execImport(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
pgDB, err := db.OpenPostgres(databaseURL)
|
pgDB, err := pg.Open(databaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/v2/db"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
dbtest "heckel.io/ntfy/v2/db/test"
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
|
|||||||
t.Run("postgres", func(t *testing.T) {
|
t.Run("postgres", func(t *testing.T) {
|
||||||
schemaDSN := dbtest.CreateTestPostgresSchema(t)
|
schemaDSN := dbtest.CreateTestPostgresSchema(t)
|
||||||
f(t, func(config *Config) *Manager {
|
f(t, func(config *Config) *Manager {
|
||||||
pool, err := db.OpenPostgres(schemaDSN)
|
pool, err := pg.Open(schemaDSN)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
a, err := NewPostgresManager(pool, config)
|
a, err := NewPostgresManager(pool, config)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user