From 85bdfc61ceae999b836f2a19ffb1b8474ff2c5a3 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Wed, 11 Mar 2026 21:07:58 -0400 Subject: [PATCH] Refine, log unhealthy replica --- cmd/user.go | 4 +-- db/db.go | 59 ++++++++++++++++++++++------------------- db/pg/pg.go | 43 ++++++++++++++++++++++-------- db/test/test.go | 16 +++++------ message/cache_sqlite.go | 2 +- server/server.go | 11 ++++---- tools/pgimport/main.go | 3 ++- user/manager_sqlite.go | 2 +- user/manager_test.go | 4 +-- webpush/store_sqlite.go | 2 +- 10 files changed, 86 insertions(+), 60 deletions(-) diff --git a/cmd/user.go b/cmd/user.go index f108650b..cd6cf795 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -380,11 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, } if databaseURL != "" { - pool, dbErr := pg.Open(databaseURL) + host, dbErr := pg.Open(databaseURL) if dbErr != nil { return nil, dbErr } - return user.NewPostgresManager(db.NewDB(pool, nil), authConfig) + return user.NewPostgresManager(db.New(host, nil), authConfig) } else if authFile != "" { if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") diff --git a/db/db.go b/db/db.go index 9b45ca7f..3d1a5709 100644 --- a/db/db.go +++ b/db/db.go @@ -10,8 +10,9 @@ import ( ) const ( - replicaHealthCheckInterval = 30 * time.Second - replicaHealthCheckTimeout = 2 * time.Second + replicaHealthCheckInitialDelay = 5 * time.Second + replicaHealthCheckInterval = 30 * time.Second + replicaHealthCheckTimeout = 10 * time.Second ) // Beginner is an interface for types that can begin a database transaction. @@ -24,30 +25,29 @@ type Beginner interface { // delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica // (round-robin), falling back to the primary if no replicas are configured or all are unhealthy. type DB struct { - primary *sql.DB - replicas []*replica + primary *Host + replicas []*Host counter atomic.Uint64 cancel context.CancelFunc } -type replica struct { - db *sql.DB +// Host pairs a *sql.DB with the host:port it was opened against. +type Host struct { + Addr string // "host:port" + DB *sql.DB healthy atomic.Bool } -// NewDB creates a new DB that wraps the given primary and optional replica connections. +// New creates a new DB that wraps the given primary and optional replica connections. // If replicas is nil or empty, ReadOnly() simply returns the primary. // Replicas start unhealthy and are checked immediately by a background goroutine. -func NewDB(primary *sql.DB, replicas []*sql.DB) *DB { +func New(primary *Host, replicas []*Host) *DB { ctx, cancel := context.WithCancel(context.Background()) d := &DB{ primary: primary, - replicas: make([]*replica, len(replicas)), + replicas: replicas, cancel: cancel, } - for i, r := range replicas { - d.replicas[i] = &replica{db: r} // healthy defaults to false - } if len(d.replicas) > 0 { go d.healthCheckLoop(ctx) } @@ -57,63 +57,68 @@ func NewDB(primary *sql.DB, replicas []*sql.DB) *DB { // Primary returns the underlying primary *sql.DB. This is only intended for // one-time schema setup during store initialization, not for regular queries. func (d *DB) Primary() *sql.DB { - return d.primary + return d.primary.DB } // Query delegates to the primary database. func (d *DB) Query(query string, args ...any) (*sql.Rows, error) { - return d.primary.Query(query, args...) + return d.primary.DB.Query(query, args...) } // QueryRow delegates to the primary database. func (d *DB) QueryRow(query string, args ...any) *sql.Row { - return d.primary.QueryRow(query, args...) + return d.primary.DB.QueryRow(query, args...) } // Exec delegates to the primary database. func (d *DB) Exec(query string, args ...any) (sql.Result, error) { - return d.primary.Exec(query, args...) + return d.primary.DB.Exec(query, args...) } // Begin delegates to the primary database. func (d *DB) Begin() (*sql.Tx, error) { - return d.primary.Begin() + return d.primary.DB.Begin() } // Ping delegates to the primary database. func (d *DB) Ping() error { - return d.primary.Ping() + return d.primary.DB.Ping() } // Close closes the primary database and all replicas, and stops the health-check goroutine. func (d *DB) Close() error { d.cancel() for _, r := range d.replicas { - r.db.Close() + r.DB.Close() } - return d.primary.Close() + return d.primary.DB.Close() } // ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy // replicas. If all replicas are unhealthy or none are configured, the primary is returned. func (d *DB) ReadOnly() *sql.DB { if len(d.replicas) == 0 { - return d.primary + return d.primary.DB } n := len(d.replicas) start := int(d.counter.Add(1) - 1) for i := 0; i < n; i++ { r := d.replicas[(start+i)%n] if r.healthy.Load() { - return r.db + return r.DB } } - return d.primary + return d.primary.DB } // healthCheckLoop checks replicas immediately, then periodically on a ticker. func (d *DB) healthCheckLoop(ctx context.Context) { - d.checkReplicas(ctx) + select { + case <-ctx.Done(): + return + case <-time.After(replicaHealthCheckInitialDelay): + d.checkReplicas(ctx) + } for { select { case <-ctx.Done(): @@ -129,17 +134,17 @@ func (d *DB) checkReplicas(ctx context.Context) { for _, r := range d.replicas { wasHealthy := r.healthy.Load() pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout) - err := r.db.PingContext(pingCtx) + err := r.DB.PingContext(pingCtx) cancel() if err != nil { r.healthy.Store(false) if wasHealthy { - log.Error("Database replica is now unhealthy: %s", err) + log.Error("Database replica %s is unhealthy: %s", r.Addr, err) } } else { r.healthy.Store(true) if !wasHealthy { - log.Info("Database replica is now healthy again") + log.Info("Database replica %s is healthy", r.Addr) } } } diff --git a/db/pg/pg.go b/db/pg/pg.go index 228c167f..00da2518 100644 --- a/db/pg/pg.go +++ b/db/pg/pg.go @@ -9,6 +9,8 @@ import ( "time" _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver + + "heckel.io/ntfy/v2/db" ) const ( @@ -20,11 +22,30 @@ const ( defaultMaxOpenConns = 10 ) -// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom +// Open opens a PostgreSQL connection pool for a primary database. It pings the database +// to verify connectivity before returning. +func Open(dsn string) (*db.Host, error) { + d, err := open(dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + if err := d.DB.Ping(); err != nil { + return nil, fmt.Errorf("database ping failed on %v: %w", d.Addr, err) + } + return d, nil +} + +// OpenReplica opens a PostgreSQL connection pool for a read replica. Unlike Open, it does +// not ping the database, since replicas are health-checked in the background by db.DB. +func OpenReplica(dsn string) (*db.Host, error) { + return open(dsn) +} + +// open opens a PostgreSQL database connection pool from a DSN string. It supports custom // query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns, // pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from // the DSN before passing it to the driver. -func Open(dsn string) (*sql.DB, error) { +func open(dsn string) (*db.Host, error) { u, err := url.Parse(dsn) if err != nil { return nil, fmt.Errorf("invalid database URL: %w", err) @@ -53,24 +74,24 @@ func Open(dsn string) (*sql.DB, error) { return nil, err } u.RawQuery = q.Encode() - db, err := sql.Open("pgx", u.String()) + d, err := sql.Open("pgx", u.String()) if err != nil { return nil, err } - db.SetMaxOpenConns(maxOpenConns) + d.SetMaxOpenConns(maxOpenConns) if maxIdleConns > 0 { - db.SetMaxIdleConns(maxIdleConns) + d.SetMaxIdleConns(maxIdleConns) } if connMaxLifetime > 0 { - db.SetConnMaxLifetime(connMaxLifetime) + d.SetConnMaxLifetime(connMaxLifetime) } if connMaxIdleTime > 0 { - db.SetConnMaxIdleTime(connMaxIdleTime) + d.SetConnMaxIdleTime(connMaxIdleTime) } - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("database ping failed (URL: %s): %w", censorPassword(u), err) - } - return db, nil + return &db.Host{ + Addr: u.Host, + DB: d, + }, nil } func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { diff --git a/db/test/test.go b/db/test/test.go index 0bfa8e11..8d3f329b 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -30,19 +30,19 @@ func CreateTestPostgresSchema(t *testing.T) string { q.Set("pool_max_conns", testPoolMaxConns) u.RawQuery = q.Encode() dsn = u.String() - setupDB, err := pg.Open(dsn) + setupHost, err := pg.Open(dsn) require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + _, err = setupHost.DB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) require.Nil(t, err) - require.Nil(t, setupDB.Close()) + require.Nil(t, setupHost.DB.Close()) q.Set("search_path", schema) u.RawQuery = q.Encode() schemaDSN := u.String() t.Cleanup(func() { - cleanDB, err := pg.Open(dsn) + cleanHost, err := pg.Open(dsn) if err == nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() + cleanHost.DB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) + cleanHost.DB.Close() } }) return schemaDSN @@ -54,9 +54,9 @@ func CreateTestPostgresSchema(t *testing.T) string { func CreateTestPostgres(t *testing.T) *db.DB { t.Helper() schemaDSN := CreateTestPostgresSchema(t) - testDB, err := pg.Open(schemaDSN) + testHost, err := pg.Open(schemaDSN) require.Nil(t, err) - d := db.NewDB(testDB, nil) + d := db.New(testHost, nil) t.Cleanup(func() { d.Close() }) diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go index 375902d6..5593ec63 100644 --- a/message/cache_sqlite.go +++ b/message/cache_sqlite.go @@ -118,7 +118,7 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil { return nil, err } - return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil + return newCache(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil } // NewMemStore creates an in-memory cache diff --git a/server/server.go b/server/server.go index 62fba3c2..24c712bd 100644 --- a/server/server.go +++ b/server/server.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/sha256" - "database/sql" "embed" "encoding/base64" "encoding/json" @@ -186,20 +185,20 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - var replicas []*sql.DB + var replicas []*db.Host for _, replicaURL := range conf.DatabaseReplicaURLs { - r, err := pg.Open(replicaURL) + r, err := pg.OpenReplica(replicaURL) if err != nil { // Close already-opened replicas before returning for _, opened := range replicas { - opened.Close() + opened.DB.Close() } - primary.Close() + primary.DB.Close() return nil, fmt.Errorf("failed to open database replica: %w", err) } replicas = append(replicas, r) } - pool = db.NewDB(primary, replicas) + pool = db.New(primary, replicas) } messageCache, err := createMessageCache(conf, pool) if err != nil { diff --git a/tools/pgimport/main.go b/tools/pgimport/main.go index cbc171dd..3ba5273e 100644 --- a/tools/pgimport/main.go +++ b/tools/pgimport/main.go @@ -236,10 +236,11 @@ func execImport(c *cli.Context) error { } fmt.Println() - pgDB, err := pg.Open(databaseURL) + pgHost, err := pg.Open(databaseURL) if err != nil { return fmt.Errorf("cannot connect to PostgreSQL: %w", err) } + pgDB := pgHost.DB defer pgDB.Close() if c.Bool("create-schema") { diff --git a/user/manager_sqlite.go b/user/manager_sqlite.go index 573d1e41..a28a846a 100644 --- a/user/manager_sqlite.go +++ b/user/manager_sqlite.go @@ -291,5 +291,5 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { return nil, err } - return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config) + return newManager(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, config) } diff --git a/user/manager_test.go b/user/manager_test.go index 73df3df1..3e023909 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -37,9 +37,9 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc t.Run("postgres", func(t *testing.T) { schemaDSN := dbtest.CreateTestPostgresSchema(t) f(t, func(config *Config) *Manager { - pool, err := pg.Open(schemaDSN) + host, err := pg.Open(schemaDSN) require.Nil(t, err) - a, err := NewPostgresManager(db.NewDB(pool, nil), config) + a, err := NewPostgresManager(db.New(host, nil), config) require.Nil(t, err) return a }) diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index c2d105ff..81c44704 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -90,7 +90,7 @@ func NewSQLiteStore(filename, startupQueries string) (*Store, error) { return nil, err } return &Store{ - db: db.NewDB(sqlDB, nil), + db: db.New(&db.Host{DB: sqlDB}, nil), queries: queries{ selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,