Better error message for database-url errors

This commit is contained in:
binwiederhier
2026-03-10 21:18:34 -04:00
parent bcd07115c2
commit 997e20fa3f
2 changed files with 69 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
@@ -28,6 +29,12 @@ func Open(dsn string) (*sql.DB, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err) return nil, fmt.Errorf("invalid database URL: %w", err)
} }
switch u.Scheme {
case "postgres", "postgresql":
// OK
default:
return nil, fmt.Errorf("invalid database URL scheme %q, must be \"postgres\" or \"postgresql\" (URL: %s)", u.Scheme, censorPassword(u))
}
q := u.Query() q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns) maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil { if err != nil {
@@ -61,7 +68,7 @@ func Open(dsn string) (*sql.DB, error) {
db.SetConnMaxIdleTime(connMaxIdleTime) db.SetConnMaxIdleTime(connMaxIdleTime)
} }
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err) return nil, fmt.Errorf("database ping failed (URL: %s): %w", censorPassword(u), err)
} }
return db, nil return db, nil
} }
@@ -79,6 +86,14 @@ func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
return v, nil return v, nil
} }
// censorPassword returns a string representation of the URL with the password replaced by "*****".
func censorPassword(u *url.URL) string {
if password, hasPassword := u.User.Password(); hasPassword {
return strings.Replace(u.String(), ":"+password+"@", ":*****@", 1)
}
return u.String()
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) { func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key) s := q.Get(key)
if s == "" { if s == "" {

53
db/pg/pg_test.go Normal file
View File

@@ -0,0 +1,53 @@
package pg
import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestOpen_InvalidScheme(t *testing.T) {
_, err := Open("postgresql+psycopg2://user:pass@localhost/db")
require.Error(t, err)
require.Contains(t, err.Error(), `invalid database URL scheme "postgresql+psycopg2"`)
require.Contains(t, err.Error(), "*****")
require.NotContains(t, err.Error(), "pass")
}
func TestOpen_InvalidURL(t *testing.T) {
_, err := Open("not a valid url\x00")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid database URL")
}
func TestCensorPassword(t *testing.T) {
tests := []struct {
name string
url string
expected string
}{
{
name: "with password",
url: "postgres://user:secret@localhost/db",
expected: "postgres://user:*****@localhost/db",
},
{
name: "without password",
url: "postgres://localhost/db",
expected: "postgres://localhost/db",
},
{
name: "user only",
url: "postgres://user@localhost/db",
expected: "postgres://user@localhost/db",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, err := url.Parse(tt.url)
require.NoError(t, err)
require.Equal(t, tt.expected, censorPassword(u))
})
}
}