From f59df0f40ada3221a1857c665c9c82e0fdf63d2e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 21 Jul 2025 17:44:00 +0200 Subject: [PATCH] Works --- Makefile | 2 +- cmd/access.go | 30 ++++-- cmd/serve.go | 109 ++++++++++++++----- cmd/user.go | 3 +- go.sum | 31 ------ server/message_cache.go | 7 ++ server/server.go | 5 +- server/server.yml | 6 ++ server/server_admin.go | 2 +- user/manager.go | 230 +++++++++++++++++++++++++++++++--------- user/manager_test.go | 10 +- user/types.go | 26 ++--- util/util.go | 12 +++ 13 files changed, 333 insertions(+), 140 deletions(-) diff --git a/Makefile b/Makefile index 575bb788..df131c7a 100644 --- a/Makefile +++ b/Makefile @@ -232,7 +232,7 @@ cli-deps-update: go get -u go install honnef.co/go/tools/cmd/staticcheck@latest go install golang.org/x/lint/golint@latest - go install github.com/goreleaser/goreleaser@latest + go install github.com/goreleaser/goreleaser/v2@latest cli-build-results: cat dist/config.yaml diff --git a/cmd/access.go b/cmd/access.go index c6be94b5..10247b5f 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -105,8 +105,10 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic return err } u, err := manager.User(username) - if err == user.ErrUserNotFound { + if errors.Is(err, user.ErrUserNotFound) { return fmt.Errorf("user %s does not exist", username) + } else if err != nil { + return err } else if u.Role == user.RoleAdmin { return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) } @@ -175,7 +177,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error { func showUserAccess(c *cli.Context, manager *user.Manager, username string) error { users, err := manager.User(username) - if err == user.ErrUserNotFound { + if errors.Is(err, user.ErrUserNotFound) { return fmt.Errorf("user %s does not exist", username) } else if err != nil { return err @@ -193,19 +195,27 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error if u.Tier != nil { tier = u.Tier.Name } - fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s)\n", u.Name, u.Role, tier) + provisioned := "" + if u.Provisioned { + provisioned = ", provisioned user" + } + fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s%s)\n", u.Name, u.Role, tier, provisioned) if u.Role == user.RoleAdmin { fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") } else if len(grants) > 0 { for _, grant := range grants { - if grant.Allow.IsReadWrite() { - fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern) - } else if grant.Allow.IsRead() { - fmt.Fprintf(c.App.ErrWriter, "- read-only access to topic %s\n", grant.TopicPattern) - } else if grant.Allow.IsWrite() { - fmt.Fprintf(c.App.ErrWriter, "- write-only access to topic %s\n", grant.TopicPattern) + grantProvisioned := "" + if grant.Provisioned { + grantProvisioned = ", provisioned access entry" + } + if grant.Permission.IsReadWrite() { + fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s%s\n", grant.TopicPattern, grantProvisioned) + } else if grant.Permission.IsRead() { + fmt.Fprintf(c.App.ErrWriter, "- read-only access to topic %s%s\n", grant.TopicPattern, grantProvisioned) + } else if grant.Permission.IsWrite() { + fmt.Fprintf(c.App.ErrWriter, "- write-only access to topic %s%s\n", grant.TopicPattern, grantProvisioned) } else { - fmt.Fprintf(c.App.ErrWriter, "- no access to topic %s\n", grant.TopicPattern) + fmt.Fprintf(c.App.ErrWriter, "- no access to topic %s%s\n", grant.TopicPattern, grantProvisioned) } } } else { diff --git a/cmd/serve.go b/cmd/serve.go index 50314b88..ef37ee6f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -48,7 +48,8 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-startup-queries", Aliases: []string{"auth_startup_queries"}, EnvVars: []string{"NTFY_AUTH_STARTUP_QUERIES"}, Usage: "queries run when the auth database is initialized"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}), - altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provisioned-users", Aliases: []string{"auth_provisioned_users"}, EnvVars: []string{"NTFY_AUTH_PROVISIONED_USERS"}, Usage: "pre-provisioned declarative users"}), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provision-users", Aliases: []string{"auth_provision_users"}, EnvVars: []string{"NTFY_AUTH_PROVISION_USERS"}, Usage: "pre-provisioned declarative users"}), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provision-access", Aliases: []string{"auth_provision_access"}, EnvVars: []string{"NTFY_AUTH_PROVISION_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}), @@ -155,8 +156,8 @@ func execServe(c *cli.Context) error { authFile := c.String("auth-file") authStartupQueries := c.String("auth-startup-queries") authDefaultAccess := c.String("auth-default-access") - authProvisionedUsersRaw := c.StringSlice("auth-provisioned-users") - //authProvisionedAccessRaw := c.StringSlice("auth-provisioned-access") + authProvisionUsersRaw := c.StringSlice("auth-provision-users") + authProvisionAccessRaw := c.StringSlice("auth-provision-access") attachmentCacheDir := c.String("attachment-cache-dir") attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit") attachmentFileSizeLimitStr := c.String("attachment-file-size-limit") @@ -352,27 +353,13 @@ func execServe(c *cli.Context) error { if err != nil { return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") } - authProvisionedUsers := make([]*user.User, 0) - for _, userLine := range authProvisionedUsersRaw { - parts := strings.Split(userLine, ":") - if len(parts) != 3 { - return fmt.Errorf("invalid provisioned user %s, expected format: 'name:hash:role'", userLine) - } - username := strings.TrimSpace(parts[0]) - passwordHash := strings.TrimSpace(parts[1]) - role := user.Role(strings.TrimSpace(parts[2])) - if !user.AllowedUsername(username) { - return fmt.Errorf("invalid provisioned user %s, username invalid", userLine) - } else if passwordHash == "" { - return fmt.Errorf("invalid provisioned user %s, password hash cannot be empty", userLine) - } else if !user.AllowedRole(role) { - return fmt.Errorf("invalid provisioned user %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role) - } - authProvisionedUsers = append(authProvisionedUsers, &user.User{ - Name: username, - Hash: passwordHash, - Role: role, - }) + authProvisionUsers, err := parseProvisionUsers(authProvisionUsersRaw) + if err != nil { + return err + } + authProvisionAccess, err := parseProvisionAccess(authProvisionUsers, authProvisionAccessRaw) + if err != nil { + return err } // Special case: Unset default @@ -429,8 +416,8 @@ func execServe(c *cli.Context) error { conf.AuthFile = authFile conf.AuthStartupQueries = authStartupQueries conf.AuthDefault = authDefault - conf.AuthProvisionedUsers = authProvisionedUsers - conf.AuthProvisionedAccess = nil // FIXME + conf.AuthProvisionedUsers = authProvisionUsers + conf.AuthProvisionedAccess = authProvisionAccess conf.AttachmentCacheDir = attachmentCacheDir conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit conf.AttachmentFileSizeLimit = attachmentFileSizeLimit @@ -544,6 +531,76 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) { return } +func parseProvisionUsers(usersRaw []string) ([]*user.User, error) { + provisionUsers := make([]*user.User, 0) + for _, userLine := range usersRaw { + parts := strings.Split(userLine, ":") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid auth-provision-users: %s, expected format: 'name:hash:role'", userLine) + } + username := strings.TrimSpace(parts[0]) + passwordHash := strings.TrimSpace(parts[1]) + role := user.Role(strings.TrimSpace(parts[2])) + if !user.AllowedUsername(username) { + return nil, fmt.Errorf("invalid auth-provision-users: %s, username invalid", userLine) + } else if passwordHash == "" { + return nil, fmt.Errorf("invalid auth-provision-users: %s, password hash cannot be empty", userLine) + } else if !user.AllowedRole(role) { + return nil, fmt.Errorf("invalid auth-provision-users: %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role) + } + provisionUsers = append(provisionUsers, &user.User{ + Name: username, + Hash: passwordHash, + Role: role, + Provisioned: true, + }) + } + return provisionUsers, nil +} + +func parseProvisionAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[string][]*user.Grant, error) { + access := make(map[string][]*user.Grant) + for _, accessLine := range provisionAccessRaw { + parts := strings.Split(accessLine, ":") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid auth-provision-access: %s, expected format: 'user:topic:permission'", accessLine) + } + username := strings.TrimSpace(parts[0]) + if username == userEveryone { + username = user.Everyone + } + provisionUser, exists := util.Find(provisionUsers, func(u *user.User) bool { + return u.Name == username + }) + if username != user.Everyone { + if !exists { + return nil, fmt.Errorf("invalid auth-provision-access: %s, user %s is not provisioned", accessLine, username) + } else if !user.AllowedUsername(username) { + return nil, fmt.Errorf("invalid auth-provision-access: %s, username %s invalid", accessLine, username) + } else if provisionUser.Role != user.RoleUser { + return nil, fmt.Errorf("invalid auth-provision-access: %s, user %s is not a regular user, only regular users can have ACL entries", accessLine, username) + } + } + topic := strings.TrimSpace(parts[1]) + if !user.AllowedTopicPattern(topic) { + return nil, fmt.Errorf("invalid auth-provision-access: %s, topic pattern %s invalid", accessLine, topic) + } + permission, err := user.ParsePermission(strings.TrimSpace(parts[2])) + if err != nil { + return nil, fmt.Errorf("invalid auth-provision-access: %s, permission %s invalid, %s", accessLine, parts[2], err.Error()) + } + if _, exists := access[username]; !exists { + access[username] = make([]*user.Grant, 0) + } + access[username] = append(access[username], &user.Grant{ + TopicPattern: topic, + Permission: permission, + Provisioned: true, + }) + } + return access, nil +} + func reloadLogLevel(inputSource altsrc.InputSourceContext) error { newLevelStr, err := inputSource.String("log-level") if err != nil { diff --git a/cmd/user.go b/cmd/user.go index 31f4c31b..0a6e24a1 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -349,8 +349,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { Filename: authFile, StartupQueries: authStartupQueries, DefaultAccess: authDefault, - ProvisionedUsers: nil, //FIXME - ProvisionedAccess: nil, //FIXME + ProvisionEnabled: false, // Do not re-provision users on manager initialization BcryptCost: user.DefaultUserPasswordBcryptCost, QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, } diff --git a/go.sum b/go.sum index 1f98da35..575b5c22 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,7 @@ cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= -cloud.google.com/go v0.121.3 h1:84RD+hQXNdY5Sw/MWVAx5O9Aui/rd5VQ9HEcdN19afo= -cloud.google.com/go v0.121.3/go.mod h1:6vWF3nJWRrEUv26mMB3FEIU/o1MQNVPG1iHdisa2SJc= cloud.google.com/go v0.121.4 h1:cVvUiY0sX0xwyxPwdSU2KsF9knOVmtRyAMt8xou0iTs= cloud.google.com/go v0.121.4/go.mod h1:XEBchUiHFJbz4lKBZwYBDHV/rSyfFktk737TLDU089s= -cloud.google.com/go/auth v0.16.2 h1:QvBAGFPLrDeoiNjyfVunhQ10HKNYuOwZ5noee0M5df4= -cloud.google.com/go/auth v0.16.2/go.mod h1:sRBas2Y1fB1vZTdurouM0AzuYQBMZinrUYL8EufhtEA= cloud.google.com/go/auth v0.16.3 h1:kabzoQ9/bobUmnseYnBO6qQG7q4a/CffFRlJSxv2wCc= cloud.google.com/go/auth v0.16.3/go.mod h1:NucRGjaXfzP1ltpcQ7On/VTZ0H4kWB5Jy+Y9Dnm76fA= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= @@ -26,8 +22,6 @@ cloud.google.com/go/storage v1.55.0 h1:NESjdAToN9u1tmhVqhXCaCwYBuvEhZLLv0gBr+2zn cloud.google.com/go/storage v1.55.0/go.mod h1:ztSmTTwzsdXe5syLVS0YsbFxXuvEmEyZj7v7zChEmuY= cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4= cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI= -firebase.google.com/go/v4 v4.16.1 h1:Kl5cgXmM0VOWDGT1UAx6b0T2UFWa14ak0CvYqeI7Py4= -firebase.google.com/go/v4 v4.16.1/go.mod h1:aAPJq/bOyb23tBlc1K6GR+2E8sOGAeJSc8wIJVgl9SM= firebase.google.com/go/v4 v4.17.0 h1:Bih69QV/k0YKPA1qUX04ln0aPT9IERrAo2ezibcngzE= firebase.google.com/go/v4 v4.17.0/go.mod h1:aAPJq/bOyb23tBlc1K6GR+2E8sOGAeJSc8wIJVgl9SM= github.com/AlekSi/pointer v1.2.0 h1:glcy/gc4h8HnG2Z3ZECSzZ1IX1x2JxRVuDzaJwQE0+w= @@ -87,8 +81,6 @@ github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0= github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -106,8 +98,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.14.2 h1:eBLnkZ9635krYIPD+ag1USrOAI0Nr0QYF3+/3GqO0k0= -github.com/googleapis/gax-go/v2 v2.14.2/go.mod h1:ON64QhlJkhVtSqp4v1uaK92VyZ2gmvDQsweuyLV+8+w= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= @@ -194,8 +184,6 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -212,8 +200,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= -golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= @@ -225,8 +211,6 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -241,8 +225,6 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= @@ -254,8 +236,6 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -269,8 +249,6 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= @@ -283,23 +261,14 @@ golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.240.0 h1:PxG3AA2UIqT1ofIzWV2COM3j3JagKTKSwy7L6RHNXNU= -google.golang.org/api v0.240.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50= google.golang.org/api v0.242.0 h1:7Lnb1nfnpvbkCiZek6IXKdJ0MFuAZNAJKQfA1ws62xg= google.golang.org/api v0.242.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79 h1:Nt6z9UHqSlIdIGJdz6KhTIs2VRx/iOsA5iE8bmQNcxs= google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79/go.mod h1:kTmlBHMPqR5uCZPBvwa2B18mvubkjyY3CRLI0c6fj0s= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79 h1:iOye66xuaAK0WnkPuhQPUFy8eJcmwUXqGGP3om6IxX8= google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79/go.mod h1:HKJDgKsFUnv5VAGeQjz8kxcgDP0HoE0iZNp0OdZNlhE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79 h1:1ZwqphdOdWYXsUHgMpU/101nCtf/kSp9hOrcvFsnl10= google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= diff --git a/server/message_cache.go b/server/message_cache.go index e314ace3..03cb4969 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/netip" + "path/filepath" "strings" "time" @@ -286,6 +287,12 @@ type messageCache struct { // newSqliteCache creates a SQLite file-backed cache func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { + // Check the parent directory of the database file (makes for friendly error messages) + parentDir := filepath.Dir(filename) + if !util.FileExists(parentDir) { + return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) + } + // Open database db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err diff --git a/server/server.go b/server/server.go index d585faa0..d3ef9cbb 100644 --- a/server/server.go +++ b/server/server.go @@ -200,8 +200,9 @@ func New(conf *Config) (*Server, error) { Filename: conf.AuthFile, StartupQueries: conf.AuthStartupQueries, DefaultAccess: conf.AuthDefault, - ProvisionedUsers: conf.AuthProvisionedUsers, - ProvisionedAccess: conf.AuthProvisionedAccess, + ProvisionEnabled: true, // Enable provisioning of users and access + ProvisionUsers: conf.AuthProvisionedUsers, + ProvisionAccess: conf.AuthProvisionedAccess, BcryptCost: conf.AuthBcryptCost, QueueWriterInterval: conf.AuthStatsQueueWriterInterval, } diff --git a/server/server.yml b/server/server.yml index db968498..02af7383 100644 --- a/server/server.yml +++ b/server/server.yml @@ -82,6 +82,10 @@ # set to "read-write" (default), "read-only", "write-only" or "deny-all". # - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable # WAL mode. This is similar to cache-startup-queries. See above for details. +# - auth-provision-users is a list of users that are automatically created when the server starts. +# Each entry is in the format "::", e.g. "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:user" +# - auth-provision-access is a list of access control entries that are automatically created when the server starts. +# Each entry is in the format "::", e.g. "phil:mytopic:rw" or "phil:phil-*:rw". # # Debian/RPM package users: # Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package @@ -94,6 +98,8 @@ # auth-file: # auth-default-access: "read-write" # auth-startup-queries: +# auth-provision-users: +# auth-provision-access: # If set, the X-Forwarded-For header (or whatever is configured in proxy-forwarded-header) is used to determine # the visitor IP address instead of the remote address of the connection. diff --git a/server/server_admin.go b/server/server_admin.go index eb362956..b724d4b7 100644 --- a/server/server_admin.go +++ b/server/server_admin.go @@ -25,7 +25,7 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit for i, g := range grants[u.ID] { userGrants[i] = &apiUserGrantResponse{ Topic: g.TopicPattern, - Permission: g.Allow.String(), + Permission: g.Permission.String(), } } usersResponse[i] = &apiUserResponse{ diff --git a/user/manager.go b/user/manager.go index 8932f34a..f2f4875d 100644 --- a/user/manager.go +++ b/user/manager.go @@ -12,6 +12,7 @@ import ( "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/util" "net/netip" + "path/filepath" "strings" "sync" "time" @@ -75,6 +76,7 @@ const ( role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, prefs JSON NOT NULL DEFAULT '{}', sync_topic TEXT NOT NULL, + provisioned INT NOT NULL, stats_messages INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0), stats_calls INT NOT NULL DEFAULT (0), @@ -97,6 +99,7 @@ const ( read INT NOT NULL, write INT NOT NULL, owner_user_id INT, + provisioned INT NOT NULL, PRIMARY KEY (user_id, topic), FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE @@ -121,8 +124,8 @@ const ( id INT PRIMARY KEY, version INT NOT NULL ); - INSERT INTO user (id, user, pass, role, sync_topic, created) - VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH()) + INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; COMMIT; ` @@ -132,26 +135,26 @@ const ( ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u JOIN user_token tk on u.id = tk.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -165,8 +168,8 @@ const ( ` insertUserQuery = ` - INSERT INTO user (id, user, pass, role, sync_topic, created) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) + VALUES (?, ?, ?, ?, ?, ?, ?) ` selectUsernamesQuery = ` SELECT user @@ -189,18 +192,18 @@ const ( deleteUserQuery = `DELETE FROM user WHERE user = ?` upsertUserAccessQuery = ` - INSERT INTO user_access (user_id, topic, read, write, owner_user_id) - VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?)))) + INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) + VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?) ON CONFLICT (user_id, topic) - DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id + DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned ` selectUserAllAccessQuery = ` - SELECT user_id, topic, read, write + SELECT user_id, topic, read, write, provisioned FROM user_access ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic ` selectUserAccessQuery = ` - SELECT topic, read, write + SELECT topic, read, write, provisioned FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic @@ -244,7 +247,8 @@ const ( WHERE user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?) ` - deleteTopicAccessQuery = ` + deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1` + deleteTopicAccessQuery = ` DELETE FROM user_access WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?)) AND topic = ? @@ -427,6 +431,15 @@ const ( migrate4To5UpdateQueries = ` UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); ` + + // 5 -> 6 + migrate5To6UpdateQueries = ` + ALTER TABLE user ADD COLUMN provisioned INT NOT NULL DEFAULT (0); + ALTER TABLE user ALTER COLUMN provisioned DROP DEFAULT; + + ALTER TABLE user_access ADD COLUMN provisioned INT NOT NULL DEFAULT (0); + ALTER TABLE user_access ALTER COLUMN provisioned DROP DEFAULT; + ` ) var ( @@ -435,6 +448,7 @@ var ( 2: migrateFrom2, 3: migrateFrom3, 4: migrateFrom4, + 5: migrateFrom5, } ) @@ -452,8 +466,9 @@ type Config struct { Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers DefaultAccess Permission // Default permission if no ACL matches - ProvisionedUsers []*User // Predefined users to create on startup - ProvisionedAccess map[string][]*Grant // Predefined access grants to create on startup + ProvisionEnabled bool // Enable auto-provisioning of users and access grants + ProvisionUsers []*User // Predefined users to create on startup + ProvisionAccess map[string][]*Grant // Predefined access grants to create on startup QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database BcryptCost int // Cost of generated passwords; lowering makes testing faster } @@ -469,6 +484,11 @@ func NewManager(config *Config) (*Manager, error) { if config.QueueWriterInterval.Seconds() <= 0 { config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval } + // Check the parent directory of the database file (makes for friendly error messages) + parentDir := filepath.Dir(config.Filename) + if !util.FileExists(parentDir) { + return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir) + } // Open DB and run setup queries db, err := sql.Open("sqlite3", config.Filename) if err != nil { @@ -486,7 +506,7 @@ func NewManager(config *Config) (*Manager, error) { statsQueue: make(map[string]*Stats), tokenQueue: make(map[string]*TokenUpdate), } - if err := manager.provisionUsers(); err != nil { + if err := manager.maybeProvisionUsersAndAccess(); err != nil { return nil, err } go manager.asyncQueueWriter(config.QueueWriterInterval) @@ -586,7 +606,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) { tokens := make([]*Token, 0) for { token, err := a.readToken(rows) - if err == ErrTokenNotFound { + if errors.Is(err, ErrTokenNotFound) { break } else if err != nil { return nil, err @@ -884,6 +904,13 @@ func (a *Manager) resolvePerms(base, perm Permission) error { // AddUser adds a user with the given username, password and role func (a *Manager) AddUser(username, password string, role Role, hashed bool) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.addUserTx(tx, username, password, role, hashed, false) + }) +} + +// AddUser adds a user with the given username, password and role +func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } @@ -899,8 +926,8 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err } userID := util.RandomStringPrefix(userIDPrefix, userIDLength) syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix() - if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil { - if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique { + if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil { + if errors.Is(err, sqlite3.ErrConstraintUnique) { return ErrUserExists } return err @@ -911,11 +938,17 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err // RemoveUser deletes the user with the given username. The function returns nil on success, even // if the user did not exist in the first place. func (a *Manager) RemoveUser(username string) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.removeUserTx(tx, username) + }) +} + +func (a *Manager) removeUserTx(tx *sql.Tx, username string) error { if !AllowedUsername(username) { return ErrInvalidArgument } // Rows in user_access, user_token, etc. are deleted via foreign keys - if _, err := a.db.Exec(deleteUserQuery, username); err != nil { + if _, err := tx.Exec(deleteUserQuery, username); err != nil { return err } return nil @@ -1029,24 +1062,26 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string + var provisioned bool var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString var messages, emails, calls int64 var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } user := &User{ - ID: id, - Name: username, - Hash: hash, - Role: Role(role), - Prefs: &Prefs{}, - SyncTopic: syncTopic, + ID: id, + Name: username, + Hash: hash, + Role: Role(role), + Prefs: &Prefs{}, + SyncTopic: syncTopic, + Provisioned: provisioned, Stats: &Stats{ Messages: messages, Emails: emails, @@ -1097,8 +1132,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) { grants := make(map[string][]Grant, 0) for rows.Next() { var userID, topic string - var read, write bool - if err := rows.Scan(&userID, &topic, &read, &write); err != nil { + var read, write, provisioned bool + if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -1108,7 +1143,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) { } grants[userID] = append(grants[userID], Grant{ TopicPattern: fromSQLWildcard(topic), - Allow: NewPermission(read, write), + Permission: NewPermission(read, write), + Provisioned: provisioned, }) } return grants, nil @@ -1124,15 +1160,16 @@ func (a *Manager) Grants(username string) ([]Grant, error) { grants := make([]Grant, 0) for rows.Next() { var topic string - var read, write bool - if err := rows.Scan(&topic, &read, &write); err != nil { + var read, write, provisioned bool + if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } grants = append(grants, Grant{ TopicPattern: fromSQLWildcard(topic), - Allow: NewPermission(read, write), + Permission: NewPermission(read, write), + Provisioned: provisioned, }) } return grants, nil @@ -1218,9 +1255,14 @@ func (a *Manager) ReservationOwner(topic string) (string, error) { // ChangePassword changes a user's password func (a *Manager) ChangePassword(username, password string, hashed bool) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.changePasswordTx(tx, username, password, hashed) + }) +} + +func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error { var hash []byte var err error - if hashed { hash = []byte(password) } else { @@ -1229,7 +1271,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error { return err } } - if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil { + if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil { return err } return nil @@ -1238,14 +1280,20 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error { // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin, // all existing access control entries (Grant) are removed, since they are no longer needed. func (a *Manager) ChangeRole(username string, role Role) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.changeRoleTx(tx, username, role) + }) +} + +func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } - if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil { + if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil { return err } if role == RoleAdmin { - if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil { + if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil { return err } } @@ -1325,13 +1373,19 @@ func (a *Manager) AllowReservation(username string, topic string) error { // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // owner may either be a user (username), or the system (empty). func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.allowAccessTx(tx, username, topicPattern, permission, false) + }) +} + +func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } owner := "" - if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil { + if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil { return err } return nil @@ -1524,20 +1578,65 @@ func (a *Manager) Close() error { return a.db.Close() } -func (a *Manager) provisionUsers() error { - for _, user := range a.config.ProvisionedUsers { - if err := a.AddUser(user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) { - return err - } +func (a *Manager) maybeProvisionUsersAndAccess() error { + if !a.config.ProvisionEnabled { + return nil } - for username, grants := range a.config.ProvisionedAccess { - for _, grant := range grants { - if err := a.AllowAccess(username, grant.TopicPattern, grant.Allow); err != nil { - return err + users, err := a.Users() + if err != nil { + return err + } + provisionUsernames := util.Map(a.config.ProvisionUsers, func(u *User) string { + return u.Name + }) + return execTx(a.db, func(tx *sql.Tx) error { + // Remove users that are provisioned, but not in the config anymore + for _, user := range users { + if user.Name == Everyone { + continue + } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { + log.Tag(tag).Info("Removing previously provisioned user %s", user.Name) + if err := a.removeUserTx(tx, user.Name); err != nil { + return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) + } } } - } - return nil + // Add or update provisioned users + for _, user := range a.config.ProvisionUsers { + if user.Name == Everyone { + continue + } + existingUser, exists := util.Find(users, func(u *User) bool { + return u.Name == user.Name + }) + if !exists { + log.Tag(tag).Info("Adding provisioned user %s", user.Name) + if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { + return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) + } + } else if existingUser.Hash != user.Hash || existingUser.Role != user.Role { + log.Tag(tag).Info("Updating provisioned user %s", user.Name) + if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil { + return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) + } + if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil { + return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) + } + } + } + // Remove and (re-)add provisioned grants + if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil { + return err + } + for username, grants := range a.config.ProvisionAccess { + for _, grant := range grants { + if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil { + return err + } + } + } + return nil + }) } // toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, @@ -1711,6 +1810,22 @@ func migrateFrom4(db *sql.DB) error { return tx.Commit() } +func migrateFrom5(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 5 to 6") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { + return err + } + return tx.Commit() +} + func nullString(s string) sql.NullString { if s == "" { return sql.NullString{} @@ -1724,3 +1839,18 @@ func nullInt64(v int64) sql.NullInt64 { } return sql.NullInt64{Int64: v, Valid: true} } + +// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. +func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return err + } + if err := f(tx); err != nil { + if e := tx.Rollback(); e != nil { + return err + } + return err + } + return tx.Commit() +} diff --git a/user/manager_test.go b/user/manager_test.go index b57c762c..42def63f 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -489,12 +489,12 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { benGrants, err := a.Grants("ben") require.Nil(t, err) require.Equal(t, 1, len(benGrants)) - require.Equal(t, PermissionReadWrite, benGrants[0].Allow) + require.Equal(t, PermissionReadWrite, benGrants[0].Permission) everyoneGrants, err := a.Grants(Everyone) require.Nil(t, err) require.Equal(t, 1, len(everyoneGrants)) - require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow) + require.Equal(t, PermissionDenyAll, everyoneGrants[0].Permission) benReservations, err := a.Reservations("ben") require.Nil(t, err) @@ -1201,16 +1201,16 @@ func TestMigrationFrom1(t *testing.T) { require.NotEqual(t, ben.SyncTopic, phil.SyncTopic) require.Equal(t, 2, len(benGrants)) require.Equal(t, "secret", benGrants[0].TopicPattern) - require.Equal(t, PermissionRead, benGrants[0].Allow) + require.Equal(t, PermissionRead, benGrants[0].Permission) require.Equal(t, "stats", benGrants[1].TopicPattern) - require.Equal(t, PermissionReadWrite, benGrants[1].Allow) + require.Equal(t, PermissionReadWrite, benGrants[1].Permission) require.Equal(t, "u_everyone", everyone.ID) require.Equal(t, Everyone, everyone.Name) require.Equal(t, RoleAnonymous, everyone.Role) require.Equal(t, 1, len(everyoneGrants)) require.Equal(t, "stats", everyoneGrants[0].TopicPattern) - require.Equal(t, PermissionRead, everyoneGrants[0].Allow) + require.Equal(t, PermissionRead, everyoneGrants[0].Permission) } func TestMigrationFrom4(t *testing.T) { diff --git a/user/types.go b/user/types.go index 6f6b1f69..90eeefce 100644 --- a/user/types.go +++ b/user/types.go @@ -12,17 +12,18 @@ import ( // User is a struct that represents a user type User struct { - ID string - Name string - Hash string // password hash (bcrypt) - Token string // Only set if token was used to log in - Role Role - Prefs *Prefs - Tier *Tier - Stats *Stats - Billing *Billing - SyncTopic string - Deleted bool + ID string + Name string + Hash string // Password hash (bcrypt) + Token string // Only set if token was used to log in + Role Role + Prefs *Prefs + Tier *Tier + Stats *Stats + Billing *Billing + SyncTopic string + Provisioned bool // Whether the user was provisioned by the config file + Deleted bool // Whether the user was soft-deleted } // TierID returns the ID of the User.Tier, or an empty string if the user has no tier, @@ -148,7 +149,8 @@ type Billing struct { // Grant is a struct that represents an access control entry to a topic by a user type Grant struct { TopicPattern string // May include wildcard (*) - Allow Permission + Permission Permission + Provisioned bool // Whether the grant was provisioned by the config file } // Reservation is a struct that represents the ownership over a topic by a user diff --git a/util/util.go b/util/util.go index 73b227af..3648e3a4 100644 --- a/util/util.go +++ b/util/util.go @@ -120,6 +120,18 @@ func Filter[T any](slice []T, f func(T) bool) []T { return result } +// Find returns the first element in the slice that satisfies the given function, and a boolean indicating +// whether such an element was found. If no element is found, it returns the zero value of T and false. +func Find[T any](slice []T, f func(T) bool) (T, bool) { + for _, v := range slice { + if f(v) { + return v, true + } + } + var zero T + return zero, false +} + // RandomString returns a random string with a given length func RandomString(length int) string { return RandomStringPrefix("", length)