mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-01-18 16:17:26 +01:00
Kill existing subscribers when topic is reserved
This commit is contained in:
@@ -10,10 +10,16 @@ import (
|
||||
// can publish a message
|
||||
type topic struct {
|
||||
ID string
|
||||
subscribers map[int]subscriber
|
||||
subscribers map[int]*topicSubscriber
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type topicSubscriber struct {
|
||||
userID string // User ID associated with this subscription, may be empty
|
||||
subscriber subscriber
|
||||
cancel func()
|
||||
}
|
||||
|
||||
// subscriber is a function that is called for every new message on a topic
|
||||
type subscriber func(v *visitor, msg *message) error
|
||||
|
||||
@@ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error
|
||||
func newTopic(id string) *topic {
|
||||
return &topic{
|
||||
ID: id,
|
||||
subscribers: make(map[int]subscriber),
|
||||
subscribers: make(map[int]*topicSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe subscribes to this topic
|
||||
func (t *topic) Subscribe(s subscriber) int {
|
||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscriberID := rand.Int()
|
||||
t.subscribers[subscriberID] = s
|
||||
t.subscribers[subscriberID] = &topicSubscriber{
|
||||
userID: userID, // May be empty
|
||||
subscriber: s,
|
||||
cancel: cancel,
|
||||
}
|
||||
return subscriberID
|
||||
}
|
||||
|
||||
@@ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error {
|
||||
if err := s(v, m); err != nil {
|
||||
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
|
||||
}
|
||||
}(s)
|
||||
}(s.subscriber)
|
||||
}
|
||||
} else {
|
||||
log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
|
||||
@@ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int {
|
||||
return len(t.subscribers)
|
||||
}
|
||||
|
||||
// subscribersCopy returns a shallow copy of the subscribers map
|
||||
func (t *topic) subscribersCopy() map[int]subscriber {
|
||||
// CancelSubscribers calls the cancel function for all subscribers, forcing
|
||||
func (t *topic) CancelSubscribers(exceptUserID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscribers := make(map[int]subscriber)
|
||||
for k, v := range t.subscribers {
|
||||
subscribers[k] = v
|
||||
for _, s := range t.subscribers {
|
||||
if s.userID != exceptUserID {
|
||||
log.Trace("Canceling subscriber %s", s.userID)
|
||||
s.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// subscribersCopy returns a shallow copy of the subscribers map
|
||||
func (t *topic) subscribersCopy() map[int]*topicSubscriber {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscribers := make(map[int]*topicSubscriber)
|
||||
for k, sub := range t.subscribers {
|
||||
subscribers[k] = &topicSubscriber{
|
||||
userID: sub.userID,
|
||||
subscriber: sub.subscriber,
|
||||
cancel: sub.cancel,
|
||||
}
|
||||
}
|
||||
return subscribers
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user