Refactor tests again: forEachBackend

This commit is contained in:
binwiederhier
2026-02-21 21:02:41 -05:00
parent b1eb90addc
commit 459c80ef9b
9 changed files with 1562 additions and 2096 deletions

View File

@@ -1,93 +0,0 @@
package message_test
import (
"testing"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/message"
"github.com/stretchr/testify/require"
)
func newTestPostgresStore(t *testing.T) message.Store {
testDB := dbtest.CreateTestPostgres(t)
store, err := message.NewPostgresStore(testDB, 0, 0)
require.Nil(t, err)
return store
}
func TestPostgresStore_Messages(t *testing.T) {
testCacheMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newTestPostgresStore(t))
}
func TestPostgresStore_Topics(t *testing.T) {
testCacheTopics(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_Prune(t *testing.T) {
testCachePrune(t, newTestPostgresStore(t))
}
func TestPostgresStore_Attachments(t *testing.T) {
testCacheAttachments(t, newTestPostgresStore(t))
}
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
}
func TestPostgresStore_Sender(t *testing.T) {
testSender(t, newTestPostgresStore(t))
}
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageByID(t *testing.T) {
testMessageByID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newTestPostgresStore(t))
}
func TestPostgresStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
}
func TestPostgresStore_Stats(t *testing.T) {
testStats(t, newTestPostgresStore(t))
}
func TestPostgresStore_AddMessages(t *testing.T) {
testAddMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
}

View File

@@ -13,158 +13,6 @@ import (
"heckel.io/ntfy/v2/model"
)
func TestSqliteStore_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestStore(t))
}
func TestMemStore_Messages(t *testing.T) {
testCacheMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestStore(t))
}
func TestSqliteStore_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestStore(t))
}
func TestMemStore_Topics(t *testing.T) {
testCacheTopics(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestStore(t))
}
func TestSqliteStore_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestStore(t))
}
func TestMemStore_Prune(t *testing.T) {
testCachePrune(t, newMemTestStore(t))
}
func TestSqliteStore_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestStore(t))
}
func TestMemStore_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestStore(t))
}
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
}
func TestMemStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestStore(t))
}
func TestSqliteStore_Sender(t *testing.T) {
testSender(t, newSqliteTestStore(t))
}
func TestMemStore_Sender(t *testing.T) {
testSender(t, newMemTestStore(t))
}
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
}
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
}
func TestSqliteStore_MessageByID(t *testing.T) {
testMessageByID(t, newSqliteTestStore(t))
}
func TestMemStore_MessageByID(t *testing.T) {
testMessageByID(t, newMemTestStore(t))
}
func TestSqliteStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newSqliteTestStore(t))
}
func TestMemStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newMemTestStore(t))
}
func TestSqliteStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newSqliteTestStore(t))
}
func TestMemStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
}
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newMemTestStore(t))
}
func TestSqliteStore_Stats(t *testing.T) {
testStats(t, newSqliteTestStore(t))
}
func TestMemStore_Stats(t *testing.T) {
testStats(t, newMemTestStore(t))
}
func TestSqliteStore_AddMessages(t *testing.T) {
testAddMessages(t, newSqliteTestStore(t))
}
func TestMemStore_AddMessages(t *testing.T) {
testAddMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newMemTestStore(t))
}
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
}
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newMemTestStore(t))
}
func TestSqliteStore_Migration_From0(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
@@ -419,14 +267,6 @@ func TestNopStore(t *testing.T) {
require.Empty(t, topics)
}
func newSqliteTestStore(t *testing.T) message.Store {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newSqliteTestStoreFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
@@ -438,13 +278,6 @@ func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) m
return s
}
func newMemTestStore(t *testing.T) message.Store {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func checkSqliteSchemaVersion(t *testing.T, filename string) {
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)

View File

@@ -2,6 +2,7 @@ package message_test
import (
"net/netip"
"path/filepath"
"sort"
"sync"
"testing"
@@ -9,11 +10,47 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func testCacheMessages(t *testing.T, s message.Store) {
func newSqliteTestStore(t *testing.T) message.Store {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newMemTestStore(t *testing.T) message.Store {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newTestPostgresStore(t *testing.T) message.Store {
testDB := dbtest.CreateTestPostgres(t)
store, err := message.NewPostgresStore(testDB, 0, 0)
require.Nil(t, err)
return store
}
func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) {
t.Run("sqlite", func(t *testing.T) {
f(t, newSqliteTestStore(t))
})
t.Run("mem", func(t *testing.T) {
f(t, newMemTestStore(t))
})
t.Run("postgres", func(t *testing.T) {
f(t, newTestPostgresStore(t))
})
}
func TestStore_Messages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = 1
@@ -82,9 +119,11 @@ func testCacheMessages(t *testing.T, s message.Store) {
// non-existing: since all
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
require.Empty(t, messages)
})
}
func testCacheMessagesLock(t *testing.T, s message.Store) {
func TestStore_MessagesLock(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
@@ -94,9 +133,11 @@ func testCacheMessagesLock(t *testing.T, s message.Store) {
}()
}
wg.Wait()
})
}
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
func TestStore_MessagesScheduled(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
@@ -120,9 +161,11 @@ func testCacheMessagesScheduled(t *testing.T, s message.Store) {
messages, _ = s.MessagesDue()
require.Empty(t, messages)
})
}
func testCacheTopics(t *testing.T, s message.Store) {
func TestStore_Topics(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
@@ -135,9 +178,11 @@ func testCacheTopics(t *testing.T, s message.Store) {
require.Equal(t, 2, len(topics))
require.Contains(t, topics, "topic1")
require.Contains(t, topics, "topic2")
})
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
@@ -148,9 +193,11 @@ func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
})
}
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
func TestStore_MessagesSinceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := model.NewDefaultMessage("mytopic", "message 2")
@@ -210,9 +257,11 @@ func testCacheMessagesSinceID(t *testing.T, s message.Store) {
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
})
}
func testCachePrune(t *testing.T, s message.Store) {
func TestStore_Prune(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
now := time.Now().Unix()
m1 := model.NewDefaultMessage("mytopic", "my message")
@@ -249,9 +298,11 @@ func testCachePrune(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
})
}
func testCacheAttachments(t *testing.T, s message.Store) {
func TestStore_Attachments(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
@@ -326,9 +377,11 @@ func testCacheAttachments(t *testing.T, s message.Store) {
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
})
}
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
func TestStore_AttachmentsExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
@@ -377,9 +430,11 @@ func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
})
}
func testSender(t *testing.T, s message.Store) {
func TestStore_Sender(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, s.AddMessage(m1))
@@ -392,9 +447,11 @@ func testSender(t *testing.T, s message.Store) {
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
})
}
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Create a scheduled (unpublished) message
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
@@ -457,9 +514,11 @@ func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
})
}
func testMessageByID(t *testing.T, s message.Store) {
func TestStore_MessageByID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Add a message
m := model.NewDefaultMessage("mytopic", "some message")
m.Title = "some title"
@@ -480,10 +539,12 @@ func testMessageByID(t *testing.T, s message.Store) {
// Non-existent ID returns ErrMessageNotFound
_, err = s.Message("doesnotexist")
require.Equal(t, model.ErrMessageNotFound, err)
})
}
func testMarkPublished(t *testing.T, s message.Store) {
// Add a scheduled message (future time → unpublished)
func TestStore_MarkPublished(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Add a scheduled message (future time -> unpublished)
m := model.NewDefaultMessage("mytopic", "scheduled message")
m.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
@@ -506,9 +567,11 @@ func testMarkPublished(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "scheduled message", messages[0].Message)
})
}
func testExpireMessages(t *testing.T, s message.Store) {
func TestStore_ExpireMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Add messages to two topics
m1 := model.NewDefaultMessage("topic1", "message 1")
m1.Expires = time.Now().Add(time.Hour).Unix()
@@ -545,9 +608,11 @@ func testExpireMessages(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 3", messages[0].Message)
})
}
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Add a message with an expired attachment (file needs cleanup)
m1 := model.NewDefaultMessage("mytopic", "old file")
m1.ID = "msg1"
@@ -602,9 +667,11 @@ func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
})
}
func testStats(t *testing.T, s message.Store) {
func TestStore_Stats(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Initial stats should be zero
messages, err := s.Stats()
require.Nil(t, err)
@@ -621,9 +688,11 @@ func testStats(t *testing.T, s message.Store) {
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(100), messages)
})
}
func testAddMessages(t *testing.T, s message.Store) {
func TestStore_AddMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Batch add multiple messages
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "batch 1"),
@@ -650,9 +719,11 @@ func testAddMessages(t *testing.T, s message.Store) {
model.NewKeepaliveMessage("mytopic"),
}
require.NotNil(t, s.AddMessages(badMsgs))
})
}
func testMessagesDue(t *testing.T, s message.Store) {
func TestStore_MessagesDue(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Add a message scheduled in the past (i.e. it's due now)
m1 := model.NewDefaultMessage("mytopic", "due message")
m1.Time = time.Now().Add(-time.Second).Unix()
@@ -692,9 +763,11 @@ func testMessagesDue(t *testing.T, s message.Store) {
require.Nil(t, err)
require.Equal(t, 1, len(due))
require.Equal(t, "truly due message", due[0].Message)
})
}
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
func TestStore_MessageFieldRoundTrip(t *testing.T) {
forEachBackend(t, func(t *testing.T, s message.Store) {
// Create a message with all fields populated
m := model.NewDefaultMessage("mytopic", "hello world")
m.SequenceID = "custom_seq_id"
@@ -764,4 +837,5 @@ func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
require.Equal(t, "PUT", retrieved.Actions[1].Method)
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
})
}

View File

@@ -1,180 +0,0 @@
package user_test
import (
"testing"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/user"
)
func newTestPostgresStore(t *testing.T) user.Store {
testDB := dbtest.CreateTestPostgres(t)
store, err := user.NewPostgresStore(testDB)
require.Nil(t, err)
return store
}
func TestPostgresStoreAddUser(t *testing.T) {
testStoreAddUser(t, newTestPostgresStore(t))
}
func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveUser(t *testing.T) {
testStoreRemoveUser(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByID(t *testing.T) {
testStoreUserByID(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByToken(t *testing.T) {
testStoreUserByToken(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
}
func TestPostgresStoreUsers(t *testing.T) {
testStoreUsers(t, newTestPostgresStore(t))
}
func TestPostgresStoreUsersCount(t *testing.T) {
testStoreUsersCount(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangePassword(t *testing.T) {
testStoreChangePassword(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeRole(t *testing.T) {
testStoreChangeRole(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokens(t *testing.T) {
testStoreTokens(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenChangeLabel(t *testing.T) {
testStoreTokenChangeLabel(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemove(t *testing.T) {
testStoreTokenRemove(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllowAccess(t *testing.T) {
testStoreAllowAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetAccess(t *testing.T) {
testStoreResetAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetAccessAll(t *testing.T) {
testStoreResetAccessAll(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservations(t *testing.T) {
testStoreReservations(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservationsCount(t *testing.T) {
testStoreReservationsCount(t, newTestPostgresStore(t))
}
func TestPostgresStoreHasReservation(t *testing.T) {
testStoreHasReservation(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservationOwner(t *testing.T) {
testStoreReservationOwner(t, newTestPostgresStore(t))
}
func TestPostgresStoreTiers(t *testing.T) {
testStoreTiers(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierUpdate(t *testing.T) {
testStoreTierUpdate(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierRemove(t *testing.T) {
testStoreTierRemove(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierByStripePrice(t *testing.T) {
testStoreTierByStripePrice(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeTier(t *testing.T) {
testStoreChangeTier(t, newTestPostgresStore(t))
}
func TestPostgresStorePhoneNumbers(t *testing.T) {
testStorePhoneNumbers(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeSettings(t *testing.T) {
testStoreChangeSettings(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeBilling(t *testing.T) {
testStoreChangeBilling(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpdateStats(t *testing.T) {
testStoreUpdateStats(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetStats(t *testing.T) {
testStoreResetStats(t, newTestPostgresStore(t))
}
func TestPostgresStoreMarkUserRemoved(t *testing.T) {
testStoreMarkUserRemoved(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllGrants(t *testing.T) {
testStoreAllGrants(t, newTestPostgresStore(t))
}
func TestPostgresStoreOtherAccessCount(t *testing.T) {
testStoreOtherAccessCount(t, newTestPostgresStore(t))
}

View File

@@ -1,180 +0,0 @@
package user_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
)
func newTestSQLiteStore(t *testing.T) user.Store {
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
return store
}
func TestSQLiteStoreAddUser(t *testing.T) {
testStoreAddUser(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveUser(t *testing.T) {
testStoreRemoveUser(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByID(t *testing.T) {
testStoreUserByID(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByToken(t *testing.T) {
testStoreUserByToken(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUsers(t *testing.T) {
testStoreUsers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUsersCount(t *testing.T) {
testStoreUsersCount(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangePassword(t *testing.T) {
testStoreChangePassword(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeRole(t *testing.T) {
testStoreChangeRole(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokens(t *testing.T) {
testStoreTokens(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemove(t *testing.T) {
testStoreTokenRemove(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllowAccess(t *testing.T) {
testStoreAllowAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetAccess(t *testing.T) {
testStoreResetAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetAccessAll(t *testing.T) {
testStoreResetAccessAll(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservations(t *testing.T) {
testStoreReservations(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservationsCount(t *testing.T) {
testStoreReservationsCount(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreHasReservation(t *testing.T) {
testStoreHasReservation(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservationOwner(t *testing.T) {
testStoreReservationOwner(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTiers(t *testing.T) {
testStoreTiers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierUpdate(t *testing.T) {
testStoreTierUpdate(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierRemove(t *testing.T) {
testStoreTierRemove(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierByStripePrice(t *testing.T) {
testStoreTierByStripePrice(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeTier(t *testing.T) {
testStoreChangeTier(t, newTestSQLiteStore(t))
}
func TestSQLiteStorePhoneNumbers(t *testing.T) {
testStorePhoneNumbers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeSettings(t *testing.T) {
testStoreChangeSettings(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeBilling(t *testing.T) {
testStoreChangeBilling(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpdateStats(t *testing.T) {
testStoreUpdateStats(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetStats(t *testing.T) {
testStoreResetStats(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllGrants(t *testing.T) {
testStoreAllGrants(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreOtherAccessCount(t *testing.T) {
testStoreOtherAccessCount(t, newTestSQLiteStore(t))
}

View File

@@ -2,14 +2,32 @@ package user_test
import (
"net/netip"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/user"
)
func testStoreAddUser(t *testing.T, store user.Store) {
func forEachStoreBackend(t *testing.T, f func(t *testing.T, store user.Store)) {
t.Run("sqlite", func(t *testing.T) {
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
f(t, store)
})
t.Run("postgres", func(t *testing.T) {
testDB := dbtest.CreateTestPostgres(t)
store, err := user.NewPostgresStore(testDB)
require.Nil(t, err)
f(t, store)
})
}
func TestStoreAddUser(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -18,14 +36,18 @@ func testStoreAddUser(t *testing.T, store user.Store) {
require.False(t, u.Provisioned)
require.NotEmpty(t, u.ID)
require.NotEmpty(t, u.SyncTopic)
})
}
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
func TestStoreAddUserAlreadyExists(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
})
}
func testStoreRemoveUser(t *testing.T, store user.Store) {
func TestStoreRemoveUser(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -34,9 +56,11 @@ func testStoreRemoveUser(t *testing.T, store user.Store) {
require.Nil(t, store.RemoveUser("phil"))
_, err = store.User("phil")
require.Equal(t, user.ErrUserNotFound, err)
})
}
func testStoreUserByID(t *testing.T, store user.Store) {
func TestStoreUserByID(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -45,9 +69,11 @@ func testStoreUserByID(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Equal(t, u.Name, u2.Name)
require.Equal(t, u.ID, u2.ID)
})
}
func testStoreUserByToken(t *testing.T, store user.Store) {
func TestStoreUserByToken(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -59,9 +85,11 @@ func testStoreUserByToken(t *testing.T, store user.Store) {
u2, err := store.UserByToken(tk.Value)
require.Nil(t, err)
require.Equal(t, "phil", u2.Name)
})
}
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
func TestStoreUserByStripeCustomer(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.ChangeBilling("phil", &user.Billing{
StripeCustomerID: "cus_test123",
@@ -72,18 +100,22 @@ func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Equal(t, "phil", u.Name)
require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
})
}
func testStoreUsers(t *testing.T, store user.Store) {
func TestStoreUsers(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))
users, err := store.Users()
require.Nil(t, err)
require.True(t, len(users) >= 3) // phil, ben, and the everyone user
})
}
func testStoreUsersCount(t *testing.T, store user.Store) {
func TestStoreUsersCount(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
count, err := store.UsersCount()
require.Nil(t, err)
require.True(t, count >= 1) // At least the everyone user
@@ -92,9 +124,11 @@ func testStoreUsersCount(t *testing.T, store user.Store) {
count2, err := store.UsersCount()
require.Nil(t, err)
require.Equal(t, count+1, count2)
})
}
func testStoreChangePassword(t *testing.T, store user.Store) {
func TestStoreChangePassword(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -104,9 +138,11 @@ func testStoreChangePassword(t *testing.T, store user.Store) {
u, err = store.User("phil")
require.Nil(t, err)
require.Equal(t, "newhash", u.Hash)
})
}
func testStoreChangeRole(t *testing.T, store user.Store) {
func TestStoreChangeRole(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -116,9 +152,11 @@ func testStoreChangeRole(t *testing.T, store user.Store) {
u, err = store.User("phil")
require.Nil(t, err)
require.Equal(t, user.RoleAdmin, u.Role)
})
}
func testStoreTokens(t *testing.T, store user.Store) {
func TestStoreTokens(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -148,9 +186,11 @@ func testStoreTokens(t *testing.T, store user.Store) {
count, err := store.TokenCount(u.ID)
require.Nil(t, err)
require.Equal(t, 1, count)
})
}
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
func TestStoreTokenChangeLabel(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -162,9 +202,11 @@ func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
tk, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err)
require.Equal(t, "new label", tk.Label)
})
}
func testStoreTokenRemove(t *testing.T, store user.Store) {
func TestStoreTokenRemove(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -175,9 +217,11 @@ func testStoreTokenRemove(t *testing.T, store user.Store) {
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
_, err = store.Token(u.ID, "tk_abc")
require.Equal(t, user.ErrTokenNotFound, err)
})
}
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
func TestStoreTokenRemoveExpired(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -198,9 +242,11 @@ func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
tk, err := store.Token(u.ID, "tk_active")
require.Nil(t, err)
require.Equal(t, "tk_active", tk.Value)
})
}
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
func TestStoreTokenRemoveExcess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -231,9 +277,11 @@ func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
require.Nil(t, err)
_, err = store.Token(u.ID, "tk_c")
require.Nil(t, err)
})
}
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
func TestStoreTokenUpdateLastAccess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -249,9 +297,11 @@ func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
require.Equal(t, newOrigin, tk.LastOrigin)
})
}
func testStoreAllowAccess(t *testing.T, store user.Store) {
func TestStoreAllowAccess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
@@ -260,9 +310,11 @@ func testStoreAllowAccess(t *testing.T, store user.Store) {
require.Len(t, grants, 1)
require.Equal(t, "mytopic", grants[0].TopicPattern)
require.True(t, grants[0].Permission.IsReadWrite())
})
}
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
func TestStoreAllowAccessReadOnly(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
@@ -271,9 +323,11 @@ func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
require.Len(t, grants, 1)
require.True(t, grants[0].Permission.IsRead())
require.False(t, grants[0].Permission.IsWrite())
})
}
func testStoreResetAccess(t *testing.T, store user.Store) {
func TestStoreResetAccess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
@@ -287,9 +341,11 @@ func testStoreResetAccess(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Len(t, grants, 1)
require.Equal(t, "topic2", grants[0].TopicPattern)
})
}
func testStoreResetAccessAll(t *testing.T, store user.Store) {
func TestStoreResetAccessAll(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
@@ -298,9 +354,11 @@ func testStoreResetAccessAll(t *testing.T, store user.Store) {
grants, err := store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 0)
})
}
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
func TestStoreAuthorizeTopicAccess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
@@ -309,17 +367,21 @@ func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
require.True(t, found)
require.True(t, read)
require.True(t, write)
})
}
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
func TestStoreAuthorizeTopicAccessNotFound(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
require.Nil(t, err)
require.False(t, found)
})
}
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))
@@ -328,9 +390,11 @@ func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
require.True(t, found)
require.False(t, read)
require.False(t, write)
})
}
func testStoreReservations(t *testing.T, store user.Store) {
func TestStoreReservations(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))
@@ -342,9 +406,11 @@ func testStoreReservations(t *testing.T, store user.Store) {
require.True(t, reservations[0].Owner.IsReadWrite())
require.True(t, reservations[0].Everyone.IsRead())
require.False(t, reservations[0].Everyone.IsWrite())
})
}
func testStoreReservationsCount(t *testing.T, store user.Store) {
func TestStoreReservationsCount(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))
@@ -352,9 +418,11 @@ func testStoreReservationsCount(t *testing.T, store user.Store) {
count, err := store.ReservationsCount("phil")
require.Nil(t, err)
require.Equal(t, int64(2), count)
})
}
func testStoreHasReservation(t *testing.T, store user.Store) {
func TestStoreHasReservation(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
@@ -365,9 +433,11 @@ func testStoreHasReservation(t *testing.T, store user.Store) {
has, err = store.HasReservation("phil", "other")
require.Nil(t, err)
require.False(t, has)
})
}
func testStoreReservationOwner(t *testing.T, store user.Store) {
func TestStoreReservationOwner(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
@@ -378,9 +448,11 @@ func testStoreReservationOwner(t *testing.T, store user.Store) {
owner, err = store.ReservationOwner("unowned")
require.Nil(t, err)
require.Empty(t, owner)
})
}
func testStoreTiers(t *testing.T, store user.Store) {
func TestStoreTiers(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
@@ -413,9 +485,11 @@ func testStoreTiers(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Len(t, tiers, 1)
require.Equal(t, "pro", tiers[0].Code)
})
}
func testStoreTierUpdate(t *testing.T, store user.Store) {
func TestStoreTierUpdate(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
@@ -431,9 +505,11 @@ func testStoreTierUpdate(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Equal(t, "Professional", t2.Name)
require.Equal(t, int64(9999), t2.MessageLimit)
})
}
func testStoreTierRemove(t *testing.T, store user.Store) {
func TestStoreTierRemove(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
@@ -448,9 +524,11 @@ func testStoreTierRemove(t *testing.T, store user.Store) {
require.Nil(t, store.RemoveTier("pro"))
_, err = store.Tier("pro")
require.Equal(t, user.ErrTierNotFound, err)
})
}
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
func TestStoreTierByStripePrice(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
@@ -467,9 +545,11 @@ func testStoreTierByStripePrice(t *testing.T, store user.Store) {
t3, err := store.TierByStripePrice("price_yearly")
require.Nil(t, err)
require.Equal(t, "pro", t3.Code)
})
}
func testStoreChangeTier(t *testing.T, store user.Store) {
func TestStoreChangeTier(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
@@ -483,9 +563,11 @@ func testStoreChangeTier(t *testing.T, store user.Store) {
require.Nil(t, err)
require.NotNil(t, u.Tier)
require.Equal(t, "pro", u.Tier.Code)
})
}
func testStorePhoneNumbers(t *testing.T, store user.Store) {
func TestStorePhoneNumbers(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -502,9 +584,11 @@ func testStorePhoneNumbers(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Len(t, numbers, 1)
require.Equal(t, "+0987654321", numbers[0])
})
}
func testStoreChangeSettings(t *testing.T, store user.Store) {
func TestStoreChangeSettings(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -517,9 +601,11 @@ func testStoreChangeSettings(t *testing.T, store user.Store) {
require.Nil(t, err)
require.NotNil(t, u2.Prefs)
require.Equal(t, "de", *u2.Prefs.Language)
})
}
func testStoreChangeBilling(t *testing.T, store user.Store) {
func TestStoreChangeBilling(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
billing := &user.Billing{
@@ -532,9 +618,11 @@ func testStoreChangeBilling(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
})
}
func testStoreUpdateStats(t *testing.T, store user.Store) {
func TestStoreUpdateStats(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -547,9 +635,11 @@ func testStoreUpdateStats(t *testing.T, store user.Store) {
require.Equal(t, int64(42), u2.Stats.Messages)
require.Equal(t, int64(3), u2.Stats.Emails)
require.Equal(t, int64(1), u2.Stats.Calls)
})
}
func testStoreResetStats(t *testing.T, store user.Store) {
func TestStoreResetStats(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -562,9 +652,11 @@ func testStoreResetStats(t *testing.T, store user.Store) {
require.Equal(t, int64(0), u2.Stats.Messages)
require.Equal(t, int64(0), u2.Stats.Emails)
require.Equal(t, int64(0), u2.Stats.Calls)
})
}
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
func TestStoreMarkUserRemoved(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -574,9 +666,11 @@ func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
u2, err := store.User("phil")
require.Nil(t, err)
require.True(t, u2.Deleted)
})
}
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
func TestStoreRemoveDeletedUsers(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
@@ -589,9 +683,11 @@ func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
u2, err := store.User("phil")
require.Nil(t, err)
require.True(t, u2.Deleted)
})
}
func testStoreAllGrants(t *testing.T, store user.Store) {
func TestStoreAllGrants(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
phil, err := store.User("phil")
@@ -606,9 +702,11 @@ func testStoreAllGrants(t *testing.T, store user.Store) {
require.Nil(t, err)
require.Contains(t, grants, phil.ID)
require.Contains(t, grants, ben.ID)
})
}
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
func TestStoreOtherAccessCount(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
@@ -616,4 +714,5 @@ func testStoreOtherAccessCount(t *testing.T, store user.Store) {
count, err := store.OtherAccessCount("phil", "mytopic")
require.Nil(t, err)
require.Equal(t, 1, count)
})
}

View File

@@ -1,63 +0,0 @@
package webpush_test
import (
"testing"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/webpush"
)
func newTestPostgresStore(t *testing.T) webpush.Store {
testDB := dbtest.CreateTestPostgres(t)
store, err := webpush.NewPostgresStore(testDB)
require.Nil(t, err)
return store
}
func TestPostgresStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
testStoreUpsertSubscriptionUpdateTopics(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionUpdateFields(t *testing.T) {
testStoreUpsertSubscriptionUpdateFields(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserIDMultiple(t *testing.T) {
testStoreRemoveByUserIDMultiple(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByEndpoint(t *testing.T) {
testStoreRemoveByEndpoint(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserID(t *testing.T) {
testStoreRemoveByUserID(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserIDEmpty(t *testing.T) {
testStoreRemoveByUserIDEmpty(t, newTestPostgresStore(t))
}
func TestPostgresStoreExpiryWarningSent(t *testing.T) {
store := newTestPostgresStore(t)
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
}
func TestPostgresStoreExpiring(t *testing.T) {
store := newTestPostgresStore(t)
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
}
func TestPostgresStoreRemoveExpired(t *testing.T) {
store := newTestPostgresStore(t)
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
}

View File

@@ -1,63 +0,0 @@
package webpush_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/webpush"
)
func newTestSQLiteStore(t *testing.T) webpush.Store {
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
return store
}
func TestSQLiteStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
testStoreUpsertSubscriptionUpdateTopics(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionUpdateFields(t *testing.T) {
testStoreUpsertSubscriptionUpdateFields(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserIDMultiple(t *testing.T) {
testStoreRemoveByUserIDMultiple(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByEndpoint(t *testing.T) {
testStoreRemoveByEndpoint(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserID(t *testing.T) {
testStoreRemoveByUserID(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserIDEmpty(t *testing.T) {
testStoreRemoveByUserIDEmpty(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreExpiryWarningSent(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
}
func TestSQLiteStoreExpiring(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
}
func TestSQLiteStoreRemoveExpired(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
}

View File

@@ -3,16 +3,34 @@ package webpush_test
import (
"fmt"
"net/netip"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/webpush"
)
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpush.Store) {
func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) {
t.Run("sqlite", func(t *testing.T) {
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
f(t, store)
})
t.Run("postgres", func(t *testing.T) {
testDB := dbtest.CreateTestPostgres(t)
store, err := webpush.NewPostgresStore(testDB)
require.Nil(t, err)
f(t, store)
})
}
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := store.SubscriptionsForTopic("test-topic")
@@ -27,9 +45,11 @@ func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpus
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
})
}
func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store webpush.Store) {
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
@@ -44,9 +64,11 @@ func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store web
// But with a different IP address it should be fine again
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
})
}
func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store) {
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics, and another with one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
@@ -73,9 +95,11 @@ func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store)
subs, err = store.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store) {
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert a subscription
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
@@ -96,9 +120,11 @@ func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store)
require.Equal(t, "new-auth", subs[0].Auth)
require.Equal(t, "new-p256dh", subs[0].P256dh)
require.Equal(t, "u_5678", subs[0].UserID)
})
}
func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert two subscriptions for u_1234 and one for u_5678
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
@@ -117,9 +143,11 @@ func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
require.Equal(t, "u_5678", subs[0].UserID)
})
}
func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
func TestStoreRemoveByEndpoint(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -131,9 +159,11 @@ func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
func TestStoreRemoveByUserID(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -145,18 +175,22 @@ func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreRemoveByUserIDEmpty(t *testing.T, store webpush.Store) {
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
})
}
func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
func TestStoreExpiryWarningSent(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
// Set updated_at to the past so it shows up as expiring
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Verify subscription appears in expiring list (warned_at == 0)
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
@@ -171,9 +205,11 @@ func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
func TestStoreExpiring(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -181,7 +217,7 @@ func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endp
require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Should not be cleaned up yet
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
@@ -191,9 +227,11 @@ func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endp
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
})
}
func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
func TestStoreRemoveExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -201,7 +239,7 @@ func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func
require.Len(t, subs, 1)
// Fake-mark them as expired
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
// Run expiration
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
@@ -210,4 +248,5 @@ func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}