mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
REmove store interface
This commit is contained in:
@@ -20,32 +20,8 @@ const (
|
|||||||
|
|
||||||
var errNoRows = errors.New("no rows found")
|
var errNoRows = errors.New("no rows found")
|
||||||
|
|
||||||
// Store is the interface for a message cache store
|
// queries holds the database-specific SQL queries
|
||||||
type Store interface {
|
type queries struct {
|
||||||
AddMessage(m *model.Message) error
|
|
||||||
AddMessages(ms []*model.Message) error
|
|
||||||
Message(id string) (*model.Message, error)
|
|
||||||
MessagesCount() (int, error)
|
|
||||||
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
|
||||||
MessagesDue() ([]*model.Message, error)
|
|
||||||
MessagesExpired() ([]string, error)
|
|
||||||
MarkPublished(m *model.Message) error
|
|
||||||
UpdateMessageTime(messageID string, timestamp int64) error
|
|
||||||
Topics() ([]string, error)
|
|
||||||
DeleteMessages(ids ...string) error
|
|
||||||
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
|
|
||||||
ExpireMessages(topics ...string) error
|
|
||||||
AttachmentsExpired() ([]string, error)
|
|
||||||
MarkAttachmentsDeleted(ids ...string) error
|
|
||||||
AttachmentBytesUsedBySender(sender string) (int64, error)
|
|
||||||
AttachmentBytesUsedByUser(userID string) (int64, error)
|
|
||||||
UpdateStats(messages int64) error
|
|
||||||
Stats() (int64, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// storeQueries holds the database-specific SQL queries
|
|
||||||
type storeQueries struct {
|
|
||||||
insertMessage string
|
insertMessage string
|
||||||
deleteMessage string
|
deleteMessage string
|
||||||
selectScheduledMessageIDsBySeqID string
|
selectScheduledMessageIDsBySeqID string
|
||||||
@@ -71,21 +47,21 @@ type storeQueries struct {
|
|||||||
updateMessageTime string
|
updateMessageTime string
|
||||||
}
|
}
|
||||||
|
|
||||||
// commonStore implements store operations that are identical across database backends
|
// Cache stores published messages
|
||||||
type commonStore struct {
|
type Cache struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
queue *util.BatchingQueue[*model.Message]
|
queue *util.BatchingQueue[*model.Message]
|
||||||
nop bool
|
nop bool
|
||||||
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
||||||
queries storeQueries
|
queries queries
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCommonStore(db *sql.DB, queries storeQueries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
|
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||||
var queue *util.BatchingQueue[*model.Message]
|
var queue *util.BatchingQueue[*model.Message]
|
||||||
if batchSize > 0 || batchTimeout > 0 {
|
if batchSize > 0 || batchTimeout > 0 {
|
||||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||||
}
|
}
|
||||||
c := &commonStore{
|
c := &Cache{
|
||||||
db: db,
|
db: db,
|
||||||
queue: queue,
|
queue: queue,
|
||||||
nop: nop,
|
nop: nop,
|
||||||
@@ -96,13 +72,13 @@ func newCommonStore(db *sql.DB, queries storeQueries, mu *sync.Mutex, batchSize
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) maybeLock() {
|
func (c *Cache) maybeLock() {
|
||||||
if c.mu != nil {
|
if c.mu != nil {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) maybeUnlock() {
|
func (c *Cache) maybeUnlock() {
|
||||||
if c.mu != nil {
|
if c.mu != nil {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -110,7 +86,7 @@ func (c *commonStore) maybeUnlock() {
|
|||||||
|
|
||||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||||
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
||||||
func (c *commonStore) AddMessage(m *model.Message) error {
|
func (c *Cache) AddMessage(m *model.Message) error {
|
||||||
if c.queue != nil {
|
if c.queue != nil {
|
||||||
c.queue.Enqueue(m)
|
c.queue.Enqueue(m)
|
||||||
return nil
|
return nil
|
||||||
@@ -119,11 +95,11 @@ func (c *commonStore) AddMessage(m *model.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddMessages synchronously stores a batch of messages to the message cache
|
// AddMessages synchronously stores a batch of messages to the message cache
|
||||||
func (c *commonStore) AddMessages(ms []*model.Message) error {
|
func (c *Cache) AddMessages(ms []*model.Message) error {
|
||||||
return c.addMessages(ms)
|
return c.addMessages(ms)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) addMessages(ms []*model.Message) error {
|
func (c *Cache) addMessages(ms []*model.Message) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
if c.nop {
|
if c.nop {
|
||||||
@@ -209,7 +185,8 @@ func (c *commonStore) addMessages(ms []*model.Message) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
|
||||||
|
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
if since.IsNone() {
|
if since.IsNone() {
|
||||||
return make([]*model.Message, 0), nil
|
return make([]*model.Message, 0), nil
|
||||||
} else if since.IsLatest() {
|
} else if since.IsLatest() {
|
||||||
@@ -220,7 +197,7 @@ func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled
|
|||||||
return c.messagesSinceTime(topic, since, scheduled)
|
return c.messagesSinceTime(topic, since, scheduled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if scheduled {
|
if scheduled {
|
||||||
@@ -234,7 +211,7 @@ func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, s
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if scheduled {
|
if scheduled {
|
||||||
@@ -248,7 +225,7 @@ func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, sch
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -256,7 +233,8 @@ func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
// MessagesDue returns all messages that are due for publishing
|
||||||
|
func (c *Cache) MessagesDue() ([]*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -265,7 +243,7 @@ func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||||
func (c *commonStore) MessagesExpired() ([]string, error) {
|
func (c *Cache) MessagesExpired() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -285,7 +263,8 @@ func (c *commonStore) MessagesExpired() ([]string, error) {
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Message(id string) (*model.Message, error) {
|
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||||
|
func (c *Cache) Message(id string) (*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -298,21 +277,23 @@ func (c *commonStore) Message(id string) (*model.Message, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
||||||
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
|
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MarkPublished(m *model.Message) error {
|
// MarkPublished marks a message as published
|
||||||
|
func (c *Cache) MarkPublished(m *model.Message) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MessagesCount() (int, error) {
|
// MessagesCount returns the total number of messages in the cache
|
||||||
|
func (c *Cache) MessagesCount() (int, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -328,7 +309,8 @@ func (c *commonStore) MessagesCount() (int, error) {
|
|||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Topics() ([]string, error) {
|
// Topics returns a list of all topics with messages in the cache
|
||||||
|
func (c *Cache) Topics() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectTopics)
|
rows, err := c.db.Query(c.queries.selectTopics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -348,7 +330,8 @@ func (c *commonStore) Topics() ([]string, error) {
|
|||||||
return topics, nil
|
return topics, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) DeleteMessages(ids ...string) error {
|
// DeleteMessages deletes the messages with the given IDs
|
||||||
|
func (c *Cache) DeleteMessages(ids ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
@@ -366,7 +349,7 @@ func (c *commonStore) DeleteMessages(ids ...string) error {
|
|||||||
|
|
||||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||||
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||||
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
@@ -402,7 +385,8 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) ExpireMessages(topics ...string) error {
|
// ExpireMessages marks messages in the given topics as expired
|
||||||
|
func (c *Cache) ExpireMessages(topics ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
@@ -418,7 +402,8 @@ func (c *commonStore) ExpireMessages(topics ...string) error {
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||||
|
func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -438,7 +423,8 @@ func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
|
||||||
|
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
@@ -454,7 +440,8 @@ func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||||
|
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -462,7 +449,8 @@ func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error)
|
|||||||
return c.readAttachmentBytesUsed(rows)
|
return c.readAttachmentBytesUsed(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
||||||
|
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -470,7 +458,7 @@ func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
|||||||
return c.readAttachmentBytesUsed(rows)
|
return c.readAttachmentBytesUsed(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
var size int64
|
var size int64
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
@@ -484,14 +472,16 @@ func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
|||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) UpdateStats(messages int64) error {
|
// UpdateStats updates the total message count statistic
|
||||||
|
func (c *Cache) UpdateStats(messages int64) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Stats() (messages int64, err error) {
|
// Stats returns the total message count statistic
|
||||||
|
func (c *Cache) Stats() (messages int64, err error) {
|
||||||
rows, err := c.db.Query(c.queries.selectStats)
|
rows, err := c.db.Query(c.queries.selectStats)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -506,11 +496,12 @@ func (c *commonStore) Stats() (messages int64, err error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Close() error {
|
// Close closes the underlying database connection
|
||||||
|
func (c *Cache) Close() error {
|
||||||
return c.db.Close()
|
return c.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) processMessageBatches() {
|
func (c *Cache) processMessageBatches() {
|
||||||
if c.queue == nil {
|
if c.queue == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -75,7 +75,7 @@ const (
|
|||||||
postgresUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
postgresUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||||
)
|
)
|
||||||
|
|
||||||
var pgQueries = storeQueries{
|
var pgQueries = queries{
|
||||||
insertMessage: postgresInsertMessageQuery,
|
insertMessage: postgresInsertMessageQuery,
|
||||||
deleteMessage: postgresDeleteMessageQuery,
|
deleteMessage: postgresDeleteMessageQuery,
|
||||||
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
|
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
@@ -102,9 +102,9 @@ var pgQueries = storeQueries{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||||
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (Store, error) {
|
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||||
if err := setupPostgresDB(db); err != nil {
|
if err := setupPostgresDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newCommonStore(db, pgQueries, nil, batchSize, batchTimeout, false), nil
|
return newCache(db, pgQueries, nil, batchSize, batchTimeout, false), nil
|
||||||
}
|
}
|
||||||
@@ -78,7 +78,7 @@ const (
|
|||||||
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||||
)
|
)
|
||||||
|
|
||||||
var sqliteQueries = storeQueries{
|
var sqliteQueries = queries{
|
||||||
insertMessage: sqliteInsertMessageQuery,
|
insertMessage: sqliteInsertMessageQuery,
|
||||||
deleteMessage: sqliteDeleteMessageQuery,
|
deleteMessage: sqliteDeleteMessageQuery,
|
||||||
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
@@ -105,7 +105,7 @@ var sqliteQueries = storeQueries{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSQLiteStore creates a SQLite file-backed cache
|
// NewSQLiteStore creates a SQLite file-backed cache
|
||||||
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
|
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) {
|
||||||
parentDir := filepath.Dir(filename)
|
parentDir := filepath.Dir(filename)
|
||||||
if !util.FileExists(parentDir) {
|
if !util.FileExists(parentDir) {
|
||||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||||
@@ -117,17 +117,17 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
|
|||||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newCommonStore(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemStore creates an in-memory cache
|
// NewMemStore creates an in-memory cache
|
||||||
func NewMemStore() (Store, error) {
|
func NewMemStore() (*Cache, error) {
|
||||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNopStore creates an in-memory cache that discards all messages;
|
// NewNopStore creates an in-memory cache that discards all messages;
|
||||||
// it is always empty and can be used if caching is entirely disabled
|
// it is always empty and can be used if caching is entirely disabled
|
||||||
func NewNopStore() (Store, error) {
|
func NewNopStore() (*Cache, error) {
|
||||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,7 +271,7 @@ func newSqliteTestStoreFile(t *testing.T) string {
|
|||||||
return filepath.Join(t.TempDir(), "cache.db")
|
return filepath.Join(t.TempDir(), "cache.db")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
|
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache {
|
||||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() { s.Close() })
|
t.Cleanup(func() { s.Close() })
|
||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newSqliteTestStore(t *testing.T) message.Store {
|
func newSqliteTestStore(t *testing.T) *message.Cache {
|
||||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -23,21 +23,21 @@ func newSqliteTestStore(t *testing.T) message.Store {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMemTestStore(t *testing.T) message.Store {
|
func newMemTestStore(t *testing.T) *message.Cache {
|
||||||
s, err := message.NewMemStore()
|
s, err := message.NewMemStore()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() { s.Close() })
|
t.Cleanup(func() { s.Close() })
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) message.Store {
|
func newTestPostgresStore(t *testing.T) *message.Cache {
|
||||||
testDB := dbtest.CreateTestPostgres(t)
|
testDB := dbtest.CreateTestPostgres(t)
|
||||||
store, err := message.NewPostgresStore(testDB, 0, 0)
|
store, err := message.NewPostgresStore(testDB, 0, 0)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) {
|
func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) {
|
||||||
t.Run("sqlite", func(t *testing.T) {
|
t.Run("sqlite", func(t *testing.T) {
|
||||||
f(t, newSqliteTestStore(t))
|
f(t, newSqliteTestStore(t))
|
||||||
})
|
})
|
||||||
@@ -50,7 +50,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Messages(t *testing.T) {
|
func TestStore_Messages(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
m1.Time = 1
|
m1.Time = 1
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ func TestStore_Messages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessagesLock(t *testing.T) {
|
func TestStore_MessagesLock(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 5000; i++ {
|
for i := 0; i < 5000; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -127,7 +127,7 @@ func TestStore_MessagesLock(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessagesScheduled(t *testing.T) {
|
func TestStore_MessagesScheduled(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||||
@@ -155,7 +155,7 @@ func TestStore_MessagesScheduled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Topics(t *testing.T) {
|
func TestStore_Topics(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||||
@@ -172,7 +172,7 @@ func TestStore_Topics(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m := model.NewDefaultMessage("mytopic", "some message")
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
m.Tags = []string{"tag1", "tag2"}
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
m.Priority = 5
|
m.Priority = 5
|
||||||
@@ -187,7 +187,7 @@ func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessagesSinceID(t *testing.T) {
|
func TestStore_MessagesSinceID(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
m1.Time = 100
|
m1.Time = 100
|
||||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
@@ -251,7 +251,7 @@ func TestStore_MessagesSinceID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Prune(t *testing.T) {
|
func TestStore_Prune(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
@@ -290,7 +290,7 @@ func TestStore_Prune(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Attachments(t *testing.T) {
|
func TestStore_Attachments(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
m.ID = "m1"
|
m.ID = "m1"
|
||||||
@@ -369,7 +369,7 @@ func TestStore_Attachments(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_AttachmentsExpired(t *testing.T) {
|
func TestStore_AttachmentsExpired(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
m.ID = "m1"
|
m.ID = "m1"
|
||||||
m.SequenceID = "m1"
|
m.SequenceID = "m1"
|
||||||
@@ -422,7 +422,7 @@ func TestStore_AttachmentsExpired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Sender(t *testing.T) {
|
func TestStore_Sender(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
require.Nil(t, s.AddMessage(m1))
|
require.Nil(t, s.AddMessage(m1))
|
||||||
@@ -439,7 +439,7 @@ func TestStore_Sender(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
|
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Create a scheduled (unpublished) message
|
// Create a scheduled (unpublished) message
|
||||||
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
scheduledMsg.ID = "scheduled1"
|
scheduledMsg.ID = "scheduled1"
|
||||||
@@ -506,7 +506,7 @@ func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessageByID(t *testing.T) {
|
func TestStore_MessageByID(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Add a message
|
// Add a message
|
||||||
m := model.NewDefaultMessage("mytopic", "some message")
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
m.Title = "some title"
|
m.Title = "some title"
|
||||||
@@ -531,7 +531,7 @@ func TestStore_MessageByID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MarkPublished(t *testing.T) {
|
func TestStore_MarkPublished(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Add a scheduled message (future time -> unpublished)
|
// Add a scheduled message (future time -> unpublished)
|
||||||
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
m.Time = time.Now().Add(time.Hour).Unix()
|
m.Time = time.Now().Add(time.Hour).Unix()
|
||||||
@@ -559,7 +559,7 @@ func TestStore_MarkPublished(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_ExpireMessages(t *testing.T) {
|
func TestStore_ExpireMessages(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Add messages to two topics
|
// Add messages to two topics
|
||||||
m1 := model.NewDefaultMessage("topic1", "message 1")
|
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
@@ -600,7 +600,7 @@ func TestStore_ExpireMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
|
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Add a message with an expired attachment (file needs cleanup)
|
// Add a message with an expired attachment (file needs cleanup)
|
||||||
m1 := model.NewDefaultMessage("mytopic", "old file")
|
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||||
m1.ID = "msg1"
|
m1.ID = "msg1"
|
||||||
@@ -659,7 +659,7 @@ func TestStore_MarkAttachmentsDeleted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_Stats(t *testing.T) {
|
func TestStore_Stats(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Initial stats should be zero
|
// Initial stats should be zero
|
||||||
messages, err := s.Stats()
|
messages, err := s.Stats()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -680,7 +680,7 @@ func TestStore_Stats(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_AddMessages(t *testing.T) {
|
func TestStore_AddMessages(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Batch add multiple messages
|
// Batch add multiple messages
|
||||||
msgs := []*model.Message{
|
msgs := []*model.Message{
|
||||||
model.NewDefaultMessage("mytopic", "batch 1"),
|
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||||
@@ -711,7 +711,7 @@ func TestStore_AddMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessagesDue(t *testing.T) {
|
func TestStore_MessagesDue(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Add a message scheduled in the past (i.e. it's due now)
|
// Add a message scheduled in the past (i.e. it's due now)
|
||||||
m1 := model.NewDefaultMessage("mytopic", "due message")
|
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||||
m1.Time = time.Now().Add(-time.Second).Unix()
|
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||||
@@ -755,7 +755,7 @@ func TestStore_MessagesDue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStore_MessageFieldRoundTrip(t *testing.T) {
|
func TestStore_MessageFieldRoundTrip(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, s message.Store) {
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
// Create a message with all fields populated
|
// Create a message with all fields populated
|
||||||
m := model.NewDefaultMessage("mytopic", "hello world")
|
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||||
m.SequenceID = "custom_seq_id"
|
m.SequenceID = "custom_seq_id"
|
||||||
@@ -62,8 +62,8 @@ type Server struct {
|
|||||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||||
userManager *user.Manager // Might be nil!
|
userManager *user.Manager // Might be nil!
|
||||||
messageCache message.Store // Database that stores the messages
|
messageCache *message.Cache // Database that stores the messages
|
||||||
webPush webpush.Store // Database that stores web push subscriptions
|
webPush *webpush.Store // Database that stores web push subscriptions
|
||||||
fileCache *fileCache // File system based cache that stores attachments
|
fileCache *fileCache // File system based cache that stores attachments
|
||||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||||
@@ -191,7 +191,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var wp webpush.Store
|
var wp *webpush.Store
|
||||||
if conf.WebPushPublicKey != "" {
|
if conf.WebPushPublicKey != "" {
|
||||||
if pool != nil {
|
if pool != nil {
|
||||||
wp, err = webpush.NewPostgresStore(pool)
|
wp, err = webpush.NewPostgresStore(pool)
|
||||||
@@ -277,7 +277,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config, pool *sql.DB) (message.Store, error) {
|
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return message.NewNopStore()
|
return message.NewNopStore()
|
||||||
} else if pool != nil {
|
} else if pool != nil {
|
||||||
|
|||||||
@@ -4121,7 +4121,7 @@ func TestServer_DeleteScheduledMessage_WithAttachment(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMemTestCache(t *testing.T) message.Store {
|
func newMemTestCache(t *testing.T) *message.Cache {
|
||||||
c, err := message.NewMemStore()
|
c, err := message.NewMemStore()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return c
|
return c
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ const (
|
|||||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||||
type visitor struct {
|
type visitor struct {
|
||||||
config *Config
|
config *Config
|
||||||
messageCache message.Store
|
messageCache *message.Cache
|
||||||
userManager *user.Manager // May be nil
|
userManager *user.Manager // May be nil
|
||||||
ip netip.Addr // Visitor IP address
|
ip netip.Addr // Visitor IP address
|
||||||
user *user.User // Only set if authenticated user, otherwise nil
|
user *user.User // Only set if authenticated user, otherwise nil
|
||||||
@@ -115,7 +115,7 @@ const (
|
|||||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||||
var messages, emails, calls int64
|
var messages, emails, calls int64
|
||||||
if user != nil {
|
if user != nil {
|
||||||
messages = user.Stats.Messages
|
messages = user.Stats.Messages
|
||||||
|
|||||||
@@ -21,21 +21,14 @@ var (
|
|||||||
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store is the interface for a web push subscription store.
|
// Store holds the database connection and queries for web push subscriptions.
|
||||||
type Store interface {
|
type Store struct {
|
||||||
UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
|
db *sql.DB
|
||||||
SubscriptionsForTopic(topic string) ([]*Subscription, error)
|
queries queries
|
||||||
SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error)
|
|
||||||
MarkExpiryWarningSent(subscriptions []*Subscription) error
|
|
||||||
RemoveSubscriptionsByEndpoint(endpoint string) error
|
|
||||||
RemoveSubscriptionsByUserID(userID string) error
|
|
||||||
RemoveExpiredSubscriptions(expireAfter time.Duration) error
|
|
||||||
SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error
|
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// storeQueries holds the database-specific SQL queries.
|
// queries holds the database-specific SQL queries.
|
||||||
type storeQueries struct {
|
type queries struct {
|
||||||
selectSubscriptionIDByEndpoint string
|
selectSubscriptionIDByEndpoint string
|
||||||
selectSubscriptionCountBySubscriberIP string
|
selectSubscriptionCountBySubscriberIP string
|
||||||
selectSubscriptionsForTopic string
|
selectSubscriptionsForTopic string
|
||||||
@@ -51,14 +44,8 @@ type storeQueries struct {
|
|||||||
deleteSubscriptionTopicWithoutSubscription string
|
deleteSubscriptionTopicWithoutSubscription string
|
||||||
}
|
}
|
||||||
|
|
||||||
// commonStore implements store operations that are identical across database backends.
|
|
||||||
type commonStore struct {
|
|
||||||
db *sql.DB
|
|
||||||
queries storeQueries
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
||||||
func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -97,7 +84,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||||
func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -107,7 +94,7 @@ func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
||||||
func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -117,7 +104,7 @@ func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
||||||
func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -132,13 +119,13 @@ func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
||||||
func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
func (s *Store) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
|
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
|
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
|
||||||
func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
|
func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return ErrWebPushUserIDCannotBeEmpty
|
return ErrWebPushUserIDCannotBeEmpty
|
||||||
}
|
}
|
||||||
@@ -147,7 +134,7 @@ func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
||||||
func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -158,13 +145,13 @@ func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) erro
|
|||||||
|
|
||||||
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
||||||
// exported for testing purposes.
|
// exported for testing purposes.
|
||||||
func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
|
func (s *Store) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
|
||||||
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
|
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying database connection.
|
// Close closes the underlying database connection.
|
||||||
func (s *commonStore) Close() error {
|
func (s *Store) Close() error {
|
||||||
return s.db.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,13 +71,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
||||||
func NewPostgresStore(db *sql.DB) (Store, error) {
|
func NewPostgresStore(db *sql.DB) (*Store, error) {
|
||||||
if err := setupPostgresDB(db); err != nil {
|
if err := setupPostgresDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return &Store{
|
||||||
db: db,
|
db: db,
|
||||||
queries: storeQueries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
||||||
selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery,
|
selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery,
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewSQLiteStore creates a new SQLite-backed web push store.
|
// NewSQLiteStore creates a new SQLite-backed web push store.
|
||||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -87,9 +87,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
|||||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return &Store{
|
||||||
db: db,
|
db: db,
|
||||||
queries: storeQueries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpointQuery,
|
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIPQuery,
|
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIPQuery,
|
||||||
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,
|
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
||||||
|
|
||||||
func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) {
|
func forEachBackend(t *testing.T, f func(t *testing.T, store *webpush.Store)) {
|
||||||
t.Run("sqlite", func(t *testing.T) {
|
t.Run("sqlite", func(t *testing.T) {
|
||||||
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -30,7 +30,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
|
||||||
subs, err := store.SubscriptionsForTopic("test-topic")
|
subs, err := store.SubscriptionsForTopic("test-topic")
|
||||||
@@ -49,7 +49,7 @@ func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert 10 subscriptions with the same IP address
|
// Insert 10 subscriptions with the same IP address
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||||
@@ -68,7 +68,7 @@ func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics, and another with one topic
|
// Insert subscription with two topics, and another with one topic
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||||
@@ -99,7 +99,7 @@ func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert a subscription
|
// Insert a subscription
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
|
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert two subscriptions for u_1234 and one for u_5678
|
// Insert two subscriptions for u_1234 and one for u_5678
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
@@ -147,7 +147,7 @@ func TestStoreRemoveByUserIDMultiple(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreRemoveByEndpoint(t *testing.T) {
|
func TestStoreRemoveByEndpoint(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -163,7 +163,7 @@ func TestStoreRemoveByEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreRemoveByUserID(t *testing.T) {
|
func TestStoreRemoveByUserID(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -179,13 +179,13 @@ func TestStoreRemoveByUserID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
|
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
|
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreExpiryWarningSent(t *testing.T) {
|
func TestStoreExpiryWarningSent(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
|
|
||||||
@@ -209,7 +209,7 @@ func TestStoreExpiryWarningSent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreExpiring(t *testing.T) {
|
func TestStoreExpiring(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -231,7 +231,7 @@ func TestStoreExpiring(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreRemoveExpired(t *testing.T) {
|
func TestStoreRemoveExpired(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T, store webpush.Store) {
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
|
|||||||
Reference in New Issue
Block a user