From 997e20fa3f70920097b4deb376b7f2f38b6c6dc3 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 10 Mar 2026 21:18:34 -0400 Subject: [PATCH] Better error message for database-url errors --- db/pg/pg.go | 17 +++++++++++++++- db/pg/pg_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 db/pg/pg_test.go diff --git a/db/pg/pg.go b/db/pg/pg.go index 99f802b3..228c167f 100644 --- a/db/pg/pg.go +++ b/db/pg/pg.go @@ -5,6 +5,7 @@ import ( "fmt" "net/url" "strconv" + "strings" "time" _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver @@ -28,6 +29,12 @@ func Open(dsn string) (*sql.DB, error) { if err != nil { return nil, fmt.Errorf("invalid database URL: %w", err) } + switch u.Scheme { + case "postgres", "postgresql": + // OK + default: + return nil, fmt.Errorf("invalid database URL scheme %q, must be \"postgres\" or \"postgresql\" (URL: %s)", u.Scheme, censorPassword(u)) + } q := u.Query() maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns) if err != nil { @@ -61,7 +68,7 @@ func Open(dsn string) (*sql.DB, error) { db.SetConnMaxIdleTime(connMaxIdleTime) } if err := db.Ping(); err != nil { - return nil, fmt.Errorf("ping failed: %w", err) + return nil, fmt.Errorf("database ping failed (URL: %s): %w", censorPassword(u), err) } return db, nil } @@ -79,6 +86,14 @@ func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { return v, nil } +// censorPassword returns a string representation of the URL with the password replaced by "*****". +func censorPassword(u *url.URL) string { + if password, hasPassword := u.User.Password(); hasPassword { + return strings.Replace(u.String(), ":"+password+"@", ":*****@", 1) + } + return u.String() +} + func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) { s := q.Get(key) if s == "" { diff --git a/db/pg/pg_test.go b/db/pg/pg_test.go new file mode 100644 index 00000000..cc66d7e9 --- /dev/null +++ b/db/pg/pg_test.go @@ -0,0 +1,53 @@ +package pg + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpen_InvalidScheme(t *testing.T) { + _, err := Open("postgresql+psycopg2://user:pass@localhost/db") + require.Error(t, err) + require.Contains(t, err.Error(), `invalid database URL scheme "postgresql+psycopg2"`) + require.Contains(t, err.Error(), "*****") + require.NotContains(t, err.Error(), "pass") +} + +func TestOpen_InvalidURL(t *testing.T) { + _, err := Open("not a valid url\x00") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid database URL") +} + +func TestCensorPassword(t *testing.T) { + tests := []struct { + name string + url string + expected string + }{ + { + name: "with password", + url: "postgres://user:secret@localhost/db", + expected: "postgres://user:*****@localhost/db", + }, + { + name: "without password", + url: "postgres://localhost/db", + expected: "postgres://localhost/db", + }, + { + name: "user only", + url: "postgres://user@localhost/db", + expected: "postgres://user@localhost/db", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.url) + require.NoError(t, err) + require.Equal(t, tt.expected, censorPassword(u)) + }) + } +}