Compare commits

...

178 Commits

Author SHA1 Message Date
binwiederhier
6b11bc7468 Merge branch 'main' into attachment-s3 2026-03-16 21:11:38 -04:00
binwiederhier
d9efe50848 Email validation 2026-03-16 21:03:33 -04:00
binwiederhier
2ad78edca1 Release notes 2026-03-16 20:13:39 -04:00
binwiederhier
86015e100c Multipart upload 2026-03-16 20:00:19 -04:00
binwiederhier
458fbad770 Merge branch 'main' into attachment-s3 2026-03-16 15:53:21 -04:00
binwiederhier
9b1a32ec56 Refine 2026-03-16 15:52:57 -04:00
binwiederhier
3d9ce69042 PG races 2026-03-16 15:48:36 -04:00
binwiederhier
59ce581ba2 Fix postgres primary/replica races 2026-03-16 11:21:21 -04:00
binwiederhier
df82fdf44c Add HTTP 413 to normal errors to not log 2026-03-16 10:27:23 -04:00
binwiederhier
3a37ea32f7 Webpush: Fix FK issue with Postgres 2026-03-16 10:24:16 -04:00
binwiederhier
790ba243c7 S3 WIP 2026-03-16 09:48:26 -04:00
binwiederhier
4487299a80 Merge branch 'main' into attachment-s3 2026-03-16 05:43:20 -04:00
binwiederhier
6b38acb23a Route authorization query to read-only database replica to reduce primary database load 2026-03-15 22:01:19 -04:00
binwiederhier
f5c255c53c Grr 2026-03-15 21:17:58 -04:00
binwiederhier
fd0a49244e Disable test temporarily 2026-03-15 21:13:12 -04:00
binwiederhier
4699ed3ffd Fix UTF-8 insert failures in Postgres 2026-03-15 21:03:18 -04:00
Philipp C. Heckel
1afb99db67 Merge pull request #1658 from BradStaton/1657-postgresql-urls
Support `postgresql://` and `postgres://` URLs
2026-03-15 20:45:08 -04:00
binwiederhier
66208e6f88 Pre-import 2026-03-15 20:25:22 -04:00
BradStaton
ce24594c32 Update serve.go
Support multiple postgres connection URL formats
2026-03-15 16:22:22 -04:00
binwiederhier
888850d8bc Add blurp 2026-03-15 10:29:07 -04:00
binwiederhier
be09acd411 Bump 2026-03-15 10:26:03 -04:00
binwiederhier
bf19a5be2d Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-03-15 10:12:54 -04:00
binwiederhier
b4ec6fa8df AWS deps.. 2026-03-15 10:12:23 -04:00
binwiederhier
d517ce4a2a WIP: S3 2026-03-14 21:10:46 -04:00
BonifacioCalindoro
fd8f356d1f Translated using Weblate (Spanish)
Currently translated at 100.0% (407 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/es/
2026-03-15 01:09:47 +01:00
binwiederhier
3296d158c5 Release notes 2026-03-14 15:07:11 -04:00
binwiederhier
45f045a5a4 Make config generator work on mobile 2026-03-14 14:39:51 -04:00
binwiederhier
f7b6e9bbe3 Merge branch 'config-generator' 2026-03-14 14:21:18 -04:00
binwiederhier
22868f4742 Derp 2026-03-14 14:21:09 -04:00
Philipp C. Heckel
3801a28958 Merge pull request #1654 from binwiederhier/config-generator
Config generator
2026-03-14 14:16:51 -04:00
binwiederhier
2bf8f6271b Review 2026-03-14 14:15:46 -04:00
binwiederhier
13be9747e4 Security reivew 2026-03-14 13:56:43 -04:00
binwiederhier
26dd017401 Merge branch 'main' of github.com:binwiederhier/ntfy into config-generator 2026-03-14 13:47:42 -04:00
binwiederhier
d00cd64220 Add admin user 2026-03-14 13:03:36 -04:00
binwiederhier
fab08e862d More refining for config generator 2026-03-14 12:56:26 -04:00
binwiederhier
143935b917 More refining 2026-03-14 08:42:07 -04:00
Philipp C. Heckel
a82ede8a14 Merge pull request #1648 from binwiederhier/postgres-replica
Add PostgreSQL read-only replica support
2026-03-12 21:28:38 -04:00
binwiederhier
8a34dfe3f8 Move things, rename things 2026-03-12 21:17:30 -04:00
binwiederhier
270fec51a6 Bump 2026-03-11 22:12:10 -04:00
binwiederhier
9eaadd74cf Log 2026-03-11 22:09:00 -04:00
binwiederhier
1f483dcbd3 Remove consts 2026-03-11 21:54:35 -04:00
binwiederhier
85bdfc61ce Refine, log unhealthy replica 2026-03-11 21:07:58 -04:00
binwiederhier
ac65df1e83 Move auth queries to primary, redo health check loop 2026-03-11 20:26:29 -04:00
binwiederhier
ab33ac7ae5 Refine 2026-03-11 11:58:40 -04:00
binwiederhier
f1865749d7 WIP: Postgres read-only replica 2026-03-10 22:17:40 -04:00
binwiederhier
997e20fa3f Better error message for database-url errors 2026-03-10 21:18:34 -04:00
binwiederhier
3402510b47 More config generator 2026-03-10 21:11:27 -04:00
binwiederhier
19d1618bb8 Continued 2026-03-08 22:00:08 -04:00
binwiederhier
612afb1435 Configurator 2026-03-08 19:32:54 -04:00
binwiederhier
2b36ad9eb9 config generator 2026-03-08 18:59:17 -04:00
binwiederhier
bcd07115c2 Add tooltips for edit/delete buttons 2026-03-08 18:30:11 -04:00
binwiederhier
109271a930 Avoid playing sound more than every 2s 2026-03-08 10:55:59 -04:00
binwiederhier
fcf95dc9b8 Fix release notes 2026-03-07 16:17:00 -05:00
binwiederhier
79c3ab9ecc Bump version 2026-03-07 16:07:38 -05:00
binwiederhier
d51465fb6a Release notes 2026-03-07 08:51:57 -05:00
binwiederhier
0b189f65ff Loadtest username/pass 2026-03-06 15:52:53 -05:00
binwiederhier
0d4b1b00e6 Users() optimization to help with startup time 2026-03-06 14:46:53 -05:00
binwiederhier
28c3fd5cbe BUmp 2026-03-06 12:50:49 -05:00
binwiederhier
62bb335675 Grr 2026-03-04 22:42:45 -05:00
binwiederhier
70fb2732af Merge branch 'main' into postgres-support 2026-03-04 22:35:38 -05:00
binwiederhier
8e91e028a0 Fix coverage issue 2026-03-04 22:34:34 -05:00
binwiederhier
6d22f568f9 Merge branch 'main' into postgres-support 2026-03-04 21:07:40 -05:00
binwiederhier
59e6c16633 Fix cov test 2026-03-04 21:07:21 -05:00
binwiederhier
2e0b934bc2 Merge branch 'main' into postgres-support 2026-03-04 20:51:25 -05:00
binwiederhier
4f4a093f8d Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-03-04 20:50:28 -05:00
binwiederhier
5610b7c56d Bump 2026-03-04 20:49:18 -05:00
binwiederhier
d4038f566c Release notes 2026-03-04 20:48:30 -05:00
binwiederhier
bff2b47eb6 Simplify schema reading 2026-03-03 14:52:29 -05:00
binwiederhier
33b19814c7 Transactions 2026-03-03 14:42:36 -05:00
binwiederhier
fb26e7ef3a BUmp 2026-03-03 09:42:59 -05:00
binwiederhier
66449bd19b pgimport readme 2026-03-02 20:15:35 -05:00
binwiederhier
bedbb121e4 Remove codecov.io 2026-03-02 20:05:29 -05:00
binwiederhier
c4b8cfa756 Manual correction 2026-03-02 20:04:03 -05:00
binwiederhier
c864a9baeb Put more things in tx 2026-03-02 19:58:26 -05:00
binwiederhier
8afeb813d9 Move OpenPostgres 2026-03-02 19:52:36 -05:00
binwiederhier
ea4739f79b Extract ExecTx 2026-03-02 19:45:35 -05:00
binwiederhier
31f0234098 Remove goroutine 2026-03-02 13:28:37 -05:00
binwiederhier
0e6b467cf0 Manual fixes 2026-03-02 12:58:01 -05:00
binwiederhier
11c79a6369 Refine again 2026-03-01 14:24:06 -05:00
binwiederhier
10a6939d8e Rename queries for consistency 2026-03-01 13:47:42 -05:00
binwiederhier
9736973286 REmove store interface 2026-03-01 13:19:53 -05:00
Philipp C. Heckel
941c43c10b Merge pull request #1622 from acress1/patch-1
Update ntfy.tedomum.net URL to ntfy.tedomum.fr
2026-03-01 12:19:44 -05:00
Philipp C. Heckel
af76aa011d Merge pull request #1629 from azrikahar/docs-config-access-token
docs: fix references in access token examples
2026-03-01 12:19:10 -05:00
binwiederhier
039566bcaf REmove unused 2026-03-01 11:59:24 -05:00
binwiederhier
544ce112b5 Manual review 2026-03-01 11:44:22 -05:00
binwiederhier
c19377109e Small refinements 2026-02-28 21:15:51 -05:00
binwiederhier
ccbd02331c Re-add execTx 2026-02-28 19:49:01 -05:00
binwiederhier
542aa403d2 Re-order functions 2026-02-28 19:34:09 -05:00
binwiederhier
ebb48e217d Merge user store and manager 2026-02-28 17:35:35 -05:00
binwiederhier
7710ace184 Self-review 2026-02-28 16:13:58 -05:00
azrikahar
b937b44f2d undo unrelated formatting changes 2026-02-28 15:06:37 +08:00
azrikahar
e618cf1a39 mention token in private instance example 2026-02-28 15:03:09 +08:00
azrikahar
e9cf2b5523 align backup-script user references in private instance section 2026-02-28 14:56:40 +08:00
azrikahar
c49a8179cf align backup-service user references in token via the config section 2026-02-28 14:55:25 +08:00
Bora Atıcı
a1cca7972d Translated using Weblate (Turkish)
Currently translated at 100.0% (407 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/tr/
2026-02-27 09:09:49 +01:00
ezn24
da1c7b1949 Translated using Weblate (Chinese (Traditional Han script))
Currently translated at 100.0% (407 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/zh_Hant/
2026-02-27 09:09:47 +01:00
binwiederhier
5c26e70fe7 Manual review: Make resetUserAccessTx 2026-02-26 17:34:19 -05:00
binwiederhier
c66fa92341 Re-add tx integrity 2026-02-26 16:54:56 -05:00
binwiederhier
a7d5a9c5d8 Sample 2026-02-23 23:08:13 -05:00
binwiederhier
391cd2c920 Migration tool 2026-02-23 22:44:21 -05:00
binwiederhier
9eec72adcc Fix race tests 2026-02-23 21:39:07 -05:00
binwiederhier
28a436c0d2 Optimizations 2026-02-23 14:11:36 -05:00
binwiederhier
b02366b42b Make more consistent 2026-02-23 13:49:54 -05:00
binwiederhier
90d0eca14d Consistency 2026-02-23 11:17:57 -05:00
acress1
0d375d3a08 Update ntfy.tedomum.net URL to ntfy.tedomum.fr 2026-02-23 09:22:38 +01:00
binwiederhier
811c7ae25a TX fixes 2026-02-22 20:39:41 -05:00
binwiederhier
850a9d4cc4 Fix bug with provisioned token removal 2026-02-22 20:28:37 -05:00
binwiederhier
43280fbc0a Performance improvements in pg 2026-02-22 16:21:27 -05:00
binwiederhier
35a54407a8 Manual review 2026-02-22 15:12:54 -05:00
binwiederhier
f726cc768e Manual refactor 2026-02-22 14:25:15 -05:00
binwiederhier
5d301e7dce Merge branch 'main' of github.com:binwiederhier/ntfy into postgres-support 2026-02-22 14:15:42 -05:00
binwiederhier
8b12bdeb3a Release notes 2026-02-22 09:07:48 -05:00
Gergő Mihály
390cff0604 Translated using Weblate (Hungarian)
Currently translated at 55.2% (225 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/hu/
2026-02-22 13:09:46 +00:00
binwiederhier
6375c2ce60 Renaming for consistency 2026-02-21 21:29:29 -05:00
binwiederhier
459c80ef9b Refactor tests again: forEachBackend 2026-02-21 21:02:41 -05:00
binwiederhier
b1eb90addc Move NewPollRequestMessage 2026-02-21 18:01:36 -05:00
binwiederhier
4b6979aa89 Move model constants 2026-02-21 10:42:34 -05:00
binwiederhier
c76e39bb0e Manual updates 2026-02-21 10:36:09 -05:00
binwiederhier
b82e1c3915 Release notes 2026-02-21 10:02:27 -05:00
binwiederhier
07e60ba041 Merge branch 'main' into postgres-support 2026-02-21 09:59:55 -05:00
binwiederhier
4e22e7f4c8 Release notes 2026-02-21 09:58:52 -05:00
binwiederhier
cf3ae187ce Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-02-21 09:56:33 -05:00
Philipp C. Heckel
d19b825192 Merge pull request #1620 from uzkikh/fix/smtp-html-br-linebreaks
fix(smtp): preserve <br> line breaks in HTML emails
2026-02-21 09:56:24 -05:00
Ivan Uzkikh
2e499389fc fix(smtp): preserve <br> line breaks in HTML emails
HTML-only emails (e.g. from Synology DSM 7.3 Task Scheduler) use <br>
tags for line breaks. The existing implementation passed the raw HTML
body to bluemonday with AddSpaceWhenStrippingTag(true), which replaced
every tag including <br> with a space, causing all content to appear
on a single unreadable line.

Fix: convert <br> tags to newlines before stripping HTML, so line break
semantics are preserved in the notification body.

Resolves the gap noted in #690 ("very sub-par" HTML support).
2026-02-21 14:44:06 +01:00
binwiederhier
7c69a76345 Docs 2026-02-20 16:56:47 -05:00
binwiederhier
a28d8e7924 Update docs, manual changes 2026-02-20 16:36:41 -05:00
binwiederhier
13a3062a7f Re-add comments 2026-02-20 16:15:07 -05:00
binwiederhier
eb6e1ac44a Docs 2026-02-20 15:57:47 -05:00
binwiederhier
a4c836b531 Refine 2026-02-20 15:36:12 -05:00
binwiederhier
e818b063f7 Fmt 2026-02-20 08:25:32 -05:00
binwiederhier
039d555689 Fix tests 2026-02-20 08:19:25 -05:00
binwiederhier
209d5a4c62 Update parallelism 2026-02-19 22:42:08 -05:00
binwiederhier
bf265449ac Docs 2026-02-19 22:38:27 -05:00
binwiederhier
4cbd80c68e Use one PG connection, add support for connection params 2026-02-19 22:34:53 -05:00
binwiederhier
305e3fc9af DB conn to 25 2026-02-19 21:36:09 -05:00
binwiederhier
9e4a48b058 Make server tests also run against postgres 2026-02-19 20:48:01 -05:00
Philipp C. Heckel
93cd7f99f8 Merge pull request #1618 from PanderMusubi/doc_zabbix
add zabbix-ntfy to intergration doc
2026-02-19 11:00:21 -05:00
PanderMusubi
28e85df36e add zabbix-ntfy to intergration doc 2026-02-19 15:47:49 +01:00
binwiederhier
939b3d1117 Fix lint, make pipeline use psotgres 2026-02-18 21:07:31 -05:00
binwiederhier
9cc9891f49 Add postgres to pipeline 2026-02-18 20:55:03 -05:00
binwiederhier
0d1f3444f2 fmt 2026-02-18 20:48:41 -05:00
binwiederhier
2716ede6e1 Extract message cache into message/ package with model/ types
Move message cache from server/message_cache.go into a dedicated
message/ package with Store interface, SQLite and PostgreSQL
implementations. Extract shared types into model/ package.
2026-02-18 20:22:44 -05:00
Philipp C. Heckel
2bc7b5217b Merge pull request #1615 from benkenawell/example/zsh-terminal
Add comment for zsh users in the terminal example
2026-02-18 20:12:49 -05:00
Benjamin Kenawell
046c0e8c79 Add comment for zsh users in the terminal example
Bash and Zsh terminals' history works slightly different. In bash, the
current command is in the history already. In zsh, the
history command does not have the currently running command in it.  But
it does set the HISTCMD variable to the entry number for the current
command. So we can load the current command with `history "$HISTCMD"` in
zsh and apply the same sed filter to get the same result as the bash
`history | tail -n1` command.

I think it makes more sense to leave this as a comment in the current
example than to create a new example, since it otherwise works the exact
same.  I've added this to my workflow locally!
2026-02-18 08:07:08 -05:00
Yaron Shahrabani
652b2097ad Translated using Weblate (Hebrew)
Currently translated at 17.4% (71 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/he/
2026-02-18 13:09:47 +01:00
binwiederhier
ae5e1fe8d8 Merge branch 'postgres-webpush' into postgres-webpush+user 2026-02-17 20:15:27 -05:00
binwiederhier
e3a402ed95 Make user tests work for postgres and sqlite 2026-02-17 20:14:45 -05:00
binwiederhier
1abc1005d0 Re-org 2026-02-17 12:04:51 -05:00
binwiederhier
909c3fe17b Add Store-level unit tests for SQLite and PostgreSQL backends
Add shared test functions in store_test.go covering all Store interface
operations (users, tokens, access control, tiers, billing, stats, etc.)
with per-backend wrappers in store_sqlite_test.go and store_postgres_test.go
following the webpush test pattern.

Fix broken isUniqueConstraintError() which used incorrect interface
assertions instead of string matching for SQLite/PostgreSQL errors.
2026-02-16 22:56:31 -05:00
binwiederhier
07c3e280bf Refactor user package to Store interface with PostgreSQL support
Extract database operations from Manager into a Store interface with
SQLite and PostgreSQL implementations using a shared commonStore.
Split SQLite migrations into store_sqlite_migrations.go, use shared
schema_version table for PostgreSQL, rename user_user/user_tier tables
to "user"/tier, and wire up database-url in CLI commands.
2026-02-16 22:39:54 -05:00
binwiederhier
b567b4e904 Merge branch 'main' into postgres-webpush 2026-02-16 21:05:04 -05:00
binwiederhier
60fa50f0d5 Merge branch 'main' into postgres-webpush+user 2026-02-16 21:04:46 -05:00
binwiederhier
ceda5ec3d8 Move things in user package 2026-02-16 21:04:15 -05:00
binwiederhier
3d72845c81 Tests for user package 2026-02-16 20:49:17 -05:00
binwiederhier
82e15d84bd Manual changes 2026-02-16 20:02:19 -05:00
binwiederhier
4e5f95ba0c Refactor webpush store to eliminate code duplication
Consolidate SQLite and Postgres store implementations into a single
commonStore with database-specific SQL queries passed via configuration.
This eliminates ~100 lines of duplicate code while maintaining full
functionality for both backends.

Key changes:
- Move all store methods to commonStore in store.go
- Remove sqliteStore and postgresStore wrapper structs
- Refactor SQLite to use QueryRow() pattern like Postgres
- Pass database-specific queries via storeQueries struct
- Make store types unexported, only expose Store interface

All tests pass for both SQLite and PostgreSQL backends.
2026-02-16 19:53:34 -05:00
binwiederhier
869b972a50 Manual review 2026-02-16 19:12:14 -05:00
binwiederhier
bdd20197b3 Manual refinements 2026-02-16 19:06:45 -05:00
binwiederhier
a8dcecdb6d Refactor webpush store tests and add coverage
- Add SetSubscriptionUpdatedAt to Store interface, remove DB() accessor
- Rename store files to store_sqlite.go and store_postgres.go
- Use camelCase for test function names
- Add tests for upsert field updates and multi-user removal
- Use transaction in setupNewPostgresDB
- Use lowercase "excluded." in PostgreSQL upsert query
2026-02-16 18:53:12 -05:00
binwiederhier
5331437664 Unify webpush store tests across SQLite and PostgreSQL backends
Share test logic in store_test.go with thin per-backend wrappers.
Add SetSubscriptionUpdatedAt to both stores, removing the need for
raw SQL and the DB() accessor in tests.
2026-02-16 12:38:00 -05:00
binwiederhier
e432bf2886 Rename PostgreSQL table prefix from wp_ to webpush_ 2026-02-16 12:13:10 -05:00
binwiederhier
0edad84d86 Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-02-16 10:55:17 -05:00
Philipp C. Heckel
ddf728acd1 Merge pull request #1607 from ZebMcKayhan/main
Added SIA-Server to integration list
2026-02-16 10:49:34 -05:00
Philipp C. Heckel
b1d3671dbb Merge pull request #1610 from luneth/patch-1
docs: Clarify F-Droid flavor excludes Google Services
2026-02-16 09:46:51 -05:00
luneth
3e6b46ec0c Fix: Clarify F-Droid flavor excludes Google Services
Updated documentation to reflect that the F-Droid flavor automatically excludes Google Services dependencies [here](33a36c4b54/app/build.gradle (L82))
2026-02-16 15:43:57 +01:00
ZebMcKayhan
b16d381626 Update integrations.md
Added SIA-Server to project list.
2026-02-15 19:41:22 +01:00
Mazurky
3bd1a1ea03 Translated using Weblate (Slovak)
Currently translated at 100.0% (407 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/sk/
2026-02-13 22:09:46 +00:00
Philipp C. Heckel
7adb37b94b Merge pull request #1603 from cyqsimon/npm-override
Allow overriding `npm` binary in Makefile
2026-02-13 16:22:50 -05:00
cyqsimon
bc08819525 Allow overriding npm binary in Makefile 2026-02-12 17:57:09 +08:00
binwiederhier
a03a37feb1 Merge branch 'main' into release-2.17.x 2026-02-08 22:31:35 -05:00
binwiederhier
4cd556f5aa Copy fix 2026-02-08 22:31:22 -05:00
binwiederhier
90aeb811ff Changelog 2026-02-08 22:18:45 -05:00
binwiederhier
c6ab380ea4 Merge branch '1364-copy-action' 2026-02-08 20:46:54 -05:00
binwiederhier
7860f2142c Constants 2026-02-08 15:11:06 -05:00
binwiederhier
18d5d31bd2 Merge branch 'main' of github.com:binwiederhier/ntfy 2026-02-08 14:29:04 -05:00
binwiederhier
cfdc364e3f Version API endpoint 2026-02-08 14:28:27 -05:00
Philipp C. Heckel
763215ecfa Merge pull request #1595 from epifeny/fix/docker-build-include-payments
fix: add payments dir to docker build
2026-02-08 11:24:54 -08:00
epifeny
49991d5aa7 fix: add payments dir to docker build 2026-02-06 10:12:34 +02:00
118 changed files with 23153 additions and 11068 deletions

View File

@@ -6,6 +6,22 @@ on:
jobs:
release:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: ntfy
POSTGRES_PASSWORD: ntfy
POSTGRES_DB: ntfy_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U ntfy"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
steps:
- name: Checkout code
uses: actions/checkout@v3

View File

@@ -3,6 +3,22 @@ on: [ push, pull_request ]
jobs:
test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: ntfy
POSTGRES_PASSWORD: ntfy
POSTGRES_DB: ntfy_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U ntfy"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
steps:
- name: Checkout code
uses: actions/checkout@v3
@@ -23,8 +39,6 @@ jobs:
- name: Build web app (required for tests)
run: make web
- name: Run tests, formatting, vetting and linting
run: make check
run: make checkv
- name: Run coverage
run: make coverage
- name: Upload coverage to codecov.io
run: make coverage-upload

3
.gitignore vendored
View File

@@ -7,6 +7,9 @@ build/
server/docs/
server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
tools/loadtest/loadtest
tools/s3cli/s3cli
playground/
secrets/
*.iml

View File

@@ -40,6 +40,7 @@ ADD ./log ./log
ADD ./server ./server
ADD ./user ./user
ADD ./util ./util
ADD ./payments ./payments
RUN make VERSION=$VERSION COMMIT=$COMMIT cli-linux-server
FROM alpine

View File

@@ -1,4 +1,5 @@
MAKEFLAGS := --jobs=1
NPM := npm
PYTHON := python3
PIP := pip3
VERSION := $(shell git describe --tag)
@@ -137,7 +138,7 @@ web: web-deps web-build
web-build:
cd web \
&& npm run build \
&& $(NPM) run build \
&& mv build/index.html build/app.html \
&& rm -rf ../server/site \
&& mv build ../server/site \
@@ -145,20 +146,20 @@ web-build:
../server/site/config.js
web-deps:
cd web && npm install
cd web && $(NPM) install
# If this fails for .svg files, optimize them with svgo
web-deps-update:
cd web && npm update
cd web && $(NPM) update
web-fmt:
cd web && npm run format
cd web && $(NPM) run format
web-fmt-check:
cd web && npm run format:check
cd web && $(NPM) run format:check
web-lint:
cd web && npm run lint
cd web && $(NPM) run lint
# Main server/client build
@@ -264,23 +265,25 @@ cli-build-results:
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
test: .PHONY
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
testv: .PHONY
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -v $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
race: .PHONY
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -v -race $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
coverage:
mkdir -p build/coverage
go test -v -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -v -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools|web)')
go tool cover -func build/coverage/coverage.txt
coverage-html:
mkdir -p build/coverage
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
go tool cover -html build/coverage/coverage.txt
coverage-upload:

25
attachment/store.go Normal file
View File

@@ -0,0 +1,25 @@
package attachment
import (
"errors"
"fmt"
"io"
"regexp"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
// Store is an interface for storing and retrieving attachment files
type Store interface {
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
Read(id string) (io.ReadCloser, int64, error)
Remove(ids ...string) error
Size() int64
Remaining() int64
}
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
)

View File

@@ -1,31 +1,29 @@
package server
package attachment
import (
"errors"
"fmt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
"io"
"os"
"path/filepath"
"regexp"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
errInvalidFileID = errors.New("invalid file ID")
errFileExists = errors.New("file exists")
)
const tagFileStore = "file_store"
type fileCache struct {
var errFileExists = errors.New("file exists")
type fileStore struct {
dir string
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
// NewFileStore creates a new file-system backed attachment store
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, err
}
@@ -33,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
if err != nil {
return nil, err
}
return &fileCache{
return &fileStore{
dir: dir,
totalSizeCurrent: size,
totalSizeLimit: totalSizeLimit,
}, nil
}
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment")
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
file := filepath.Join(c.dir, id)
if _, err := os.Stat(file); err == nil {
return 0, errFileExists
@@ -67,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in
}
c.mu.Lock()
c.totalSizeCurrent += size
mset(metricAttachmentsTotalSize, c.totalSizeCurrent)
c.mu.Unlock()
return size, nil
}
func (c *fileCache) Remove(ids ...string) error {
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
file := filepath.Join(c.dir, id)
stat, err := os.Stat(file)
if err != nil {
return nil, 0, err
}
f, err := os.Open(file)
if err != nil {
return nil, 0, err
}
return f, stat.Size(), nil
}
func (c *fileStore) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment")
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
file := filepath.Join(c.dir, id)
if err := os.Remove(file); err != nil {
log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment")
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
}
}
size, err := dirSize(c.dir)
@@ -90,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error {
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
mset(metricAttachmentsTotalSize, size)
return nil
}
func (c *fileCache) Size() int64 {
func (c *fileStore) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *fileCache) Remaining() int64 {
func (c *fileStore) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent

View File

@@ -0,0 +1,99 @@
package attachment
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileStore_Write_Success(t *testing.T) {
dir, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), s.Size())
require.Equal(t, int64(10229), s.Remaining())
}
func TestFileStore_Write_Read_Success(t *testing.T) {
_, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
reader, readSize, err := s.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
defer reader.Close()
data, err := io.ReadAll(reader)
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestFileStore_Write_Remove_Success(t *testing.T) {
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), s.Size())
require.Equal(t, int64(250), s.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), s.Size())
require.Equal(t, int64(2248), s.Remaining())
}
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
dir, s := newTestFileStore(t)
for i := 0; i < 10; i++ {
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
dir, s := newTestFileStore(t)
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileStore_Read_NotFound(t *testing.T) {
_, s := newTestFileStore(t)
_, _, err := s.Read("abcdefghijkl")
require.Error(t, err)
}
func newTestFileStore(t *testing.T) (dir string, store Store) {
dir = t.TempDir()
store, err := NewFileStore(dir, 10*1024)
require.Nil(t, err)
return dir, store
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

150
attachment/store_s3.go Normal file
View File

@@ -0,0 +1,150 @@
package attachment
import (
"context"
"fmt"
"io"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
const (
tagS3Store = "s3_store"
)
type s3Store struct {
client *s3.Client
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
cfg, err := s3.ParseURL(s3URL)
if err != nil {
return nil, err
}
store := &s3Store{
client: s3.New(cfg),
totalSizeLimit: totalSizeLimit,
}
if totalSizeLimit > 0 {
size, err := store.computeSize()
if err != nil {
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
}
store.totalSizeCurrent = size
}
return store, nil
}
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
// uploads, so no temp file or Content-Length is needed.
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
pr, pw := io.Pipe()
lw := util.NewLimitWriter(pw, limiters...)
var size int64
var copyErr error
done := make(chan struct{})
go func() {
defer close(done)
size, copyErr = io.Copy(lw, in)
if copyErr != nil {
pw.CloseWithError(copyErr)
} else {
pw.Close()
}
}()
putErr := c.client.PutObject(context.Background(), id, pr)
pr.Close()
<-done
if copyErr != nil {
return 0, copyErr
}
if putErr != nil {
return 0, putErr
}
c.mu.Lock()
c.totalSizeCurrent += size
c.mu.Unlock()
return size, nil
}
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
return c.client.GetObject(context.Background(), id)
}
func (c *s3Store) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
}
// S3 DeleteObjects supports up to 1000 keys per call
for i := 0; i < len(ids); i += 1000 {
end := i + 1000
if end > len(ids) {
end = len(ids)
}
batch := ids[i:end]
for _, id := range batch {
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
}
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
return err
}
}
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
size, err := c.computeSize()
if err != nil {
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return nil
}
func (c *s3Store) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *s3Store) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent
if remaining < 0 {
return 0
}
return remaining
}
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
func (c *s3Store) computeSize() (int64, error) {
objects, err := c.client.ListAllObjects(context.Background())
if err != nil {
return 0, err
}
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
return totalSize, nil
}

367
attachment/store_s3_test.go Normal file
View File

@@ -0,0 +1,367 @@
package attachment
import (
"bytes"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
// --- Integration tests using a mock S3 server ---
func TestS3Store_WriteReadRemove(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, int64(11), store.Size())
// Read back
reader, readSize, err := store.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
// Remove
require.Nil(t, store.Remove("abcdefghijkl"))
require.Equal(t, int64(0), store.Size())
// Read after remove should fail
_, _, err = store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_WriteNoPrefix(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
require.Nil(t, err)
require.Equal(t, int64(4), size)
reader, _, err := store.Read("abcdefghijkl")
require.Nil(t, err)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "test", string(data))
}
func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
// First write fits
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
require.Nil(t, err)
require.Equal(t, int64(80), store.Size())
require.Equal(t, int64(20), store.Remaining())
// Second write exceeds total limit
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
for i := 0; i < 5; i++ {
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
require.Nil(t, err)
}
require.Equal(t, int64(500), store.Size())
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
require.Equal(t, int64(300), store.Size())
}
func TestS3Store_ReadNotFound(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, _, err := store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_InvalidID(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("bad", strings.NewReader("x"))
require.Equal(t, errInvalidFileID, err)
_, _, err = store.Read("bad")
require.Equal(t, errInvalidFileID, err)
err = store.Remove("bad")
require.Equal(t, errInvalidFileID, err)
}
// --- Helpers ---
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
t.Helper()
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
s := &s3Store{
client: &s3.Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
},
totalSizeLimit: totalSizeLimit,
}
// Compute initial size (should be 0 for fresh mock)
size, err := s.computeSize()
require.Nil(t, err)
s.totalSizeCurrent = size
return s
}
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() *httptest.Server {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m)
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
m.mu.RLock()
var contents []s3ListObject
for key, body := range m.objects {
// key is "bucket/objectkey", strip bucket prefix
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
}
}
m.mu.RUnlock()
resp := s3ListResponse{
Contents: contents,
IsTruncated: false,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
type s3ListResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []s3ListObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
}
type s3ListObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}

View File

@@ -2,9 +2,6 @@ package cmd
import (
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/test"
"heckel.io/ntfy/v2/util"
"net/http"
"net/http/httptest"
"os"
@@ -14,9 +11,14 @@ import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/test"
"heckel.io/ntfy/v2/util"
)
func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
t.Skip("temporarily disabled") // FIXME
testMessage := util.RandomString(10)
app, _, _, _ := newTestApp()
require.Nil(t, app.Run([]string{"ntfy", "publish", "ntfytest", "ntfy unit test " + testMessage}))

View File

@@ -39,6 +39,8 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "key-file", Aliases: []string{"key_file", "K"}, EnvVars: []string{"NTFY_KEY_FILE"}, Usage: "private key file, if listen-https is set"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
@@ -51,6 +53,7 @@ var flagsServe = append(
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}),
@@ -143,6 +146,8 @@ func execServe(c *cli.Context) error {
keyFile := c.String("key-file")
certFile := c.String("cert-file")
firebaseKeyFile := c.String("firebase-key-file")
databaseURL := c.String("database-url")
databaseReplicaURLs := c.StringSlice("database-replica-urls")
webPushPrivateKey := c.String("web-push-private-key")
webPushPublicKey := c.String("web-push-public-key")
webPushFile := c.String("web-push-file")
@@ -162,6 +167,7 @@ func execServe(c *cli.Context) error {
authAccessRaw := c.StringSlice("auth-access")
authTokensRaw := c.StringSlice("auth-tokens")
attachmentCacheDir := c.String("attachment-cache-dir")
attachmentS3URL := c.String("attachment-s3-url")
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
attachmentExpiryDurationStr := c.String("attachment-expiry-duration")
@@ -280,12 +286,18 @@ func execServe(c *cli.Context) error {
}
// Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") && !strings.HasPrefix(databaseURL, "postgresql://") {
return errors.New("if database-url is set, it must start with postgres:// or postgresql://")
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
} else if len(databaseReplicaURLs) > 0 && databaseURL == "" {
return errors.New("database-replica-urls can only be used if database-url is also set")
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
return errors.New("if set, FCM key file must exist")
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || webPushFile == "" || webPushEmailAddress == "" || baseURL == "") {
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file, web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || (webPushFile == "" && databaseURL == "") || webPushEmailAddress == "" || baseURL == "") {
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file (or database-url), web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
} else if keepaliveInterval < 5*time.Second {
return errors.New("keepalive interval cannot be lower than five seconds")
} else if managerInterval < 5*time.Second {
@@ -304,6 +316,10 @@ func execServe(c *cli.Context) error {
return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set")
} else if attachmentCacheDir != "" && baseURL == "" {
return errors.New("if attachment-cache-dir is set, base-url must also be set")
} else if attachmentS3URL != "" && baseURL == "" {
return errors.New("if attachment-s3-url is set, base-url must also be set")
} else if attachmentS3URL != "" && attachmentCacheDir != "" {
return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive")
} else if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
@@ -321,8 +337,8 @@ func execServe(c *cli.Context) error {
return errors.New("if upstream-base-url is set, base-url must also be set")
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
} else if authFile == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
} else if authFile == "" && databaseURL == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file or database-url is not set")
} else if enableSignup && !enableLogin {
return errors.New("cannot set enable-signup without also setting enable-login")
} else if requireLogin && !enableLogin {
@@ -331,8 +347,8 @@ func execServe(c *cli.Context) error {
return errors.New("cannot set stripe-secret-key or stripe-webhook-key, support for payments is not available in this build (nopayments)")
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || authFile == "") {
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file must also be set")
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || (authFile == "" && databaseURL == "")) {
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file (or database-url) must also be set")
} else if messageSizeLimit > server.DefaultMessageSizeLimit {
log.Warn("message-size-limit is greater than 4K, this is not recommended and largely untested, and may lead to issues with some clients")
if messageSizeLimit > 5*1024*1024 {
@@ -412,6 +428,15 @@ func execServe(c *cli.Context) error {
payments.Setup(stripeSecretKey)
}
// Parse Twilio template
var twilioCallFormatTemplate *template.Template
if twilioCallFormat != "" {
twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat)
if err != nil {
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
}
}
// Add default forbidden topics
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
@@ -438,6 +463,7 @@ func execServe(c *cli.Context) error {
conf.AuthAccess = authAccess
conf.AuthTokens = authTokens
conf.AttachmentCacheDir = attachmentCacheDir
conf.AttachmentS3URL = attachmentS3URL
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
conf.AttachmentExpiryDuration = attachmentExpiryDuration
@@ -459,13 +485,7 @@ func execServe(c *cli.Context) error {
conf.TwilioAuthToken = twilioAuthToken
conf.TwilioPhoneNumber = twilioPhoneNumber
conf.TwilioVerifyService = twilioVerifyService
if twilioCallFormat != "" {
tmpl, err := template.New("twiml").Parse(twilioCallFormat)
if err != nil {
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
}
conf.TwilioCallFormat = tmpl
}
conf.TwilioCallFormat = twilioCallFormatTemplate
conf.MessageSizeLimit = int(messageSizeLimit)
conf.MessageDelayMax = messageDelayLimit
conf.TotalTopicLimit = totalTopicLimit
@@ -494,6 +514,8 @@ func execServe(c *cli.Context) error {
conf.EnableMetrics = enableMetrics
conf.MetricsListenHTTP = metricsListenHTTP
conf.ProfileListenHTTP = profileListenHTTP
conf.DatabaseURL = databaseURL
conf.DatabaseReplicaURLs = databaseReplicaURLs
conf.WebPushPrivateKey = webPushPrivateKey
conf.WebPushPublicKey = webPushPublicKey
conf.WebPushFile = webPushFile

View File

@@ -6,13 +6,15 @@ import (
"crypto/subtle"
"errors"
"fmt"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"os"
"strings"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -29,6 +31,7 @@ var flagsUser = append(
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
)
var cmdUser = &cli.Command{
@@ -365,24 +368,30 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
authFile := c.String("auth-file")
authStartupQueries := c.String("auth-startup-queries")
authDefaultAccess := c.String("auth-default-access")
if authFile == "" {
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
} else if !util.FileExists(authFile) {
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
}
databaseURL := c.String("database-url")
authDefault, err := user.ParsePermission(authDefaultAccess)
if err != nil {
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
}
authConfig := &user.Config{
Filename: authFile,
StartupQueries: authStartupQueries,
DefaultAccess: authDefault,
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
BcryptCost: user.DefaultUserPasswordBcryptCost,
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
}
return user.NewManager(authConfig)
if databaseURL != "" {
host, dbErr := pg.Open(databaseURL)
if dbErr != nil {
return nil, dbErr
}
return user.NewPostgresManager(db.New(host, nil), authConfig)
} else if authFile != "" {
if !util.FileExists(authFile) {
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
}
return user.NewSQLiteManager(authFile, authStartupQueries, authConfig)
}
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
}
func readPasswordAndConfirm(c *cli.Context) (string, error) {

137
db/db.go Normal file
View File

@@ -0,0 +1,137 @@
package db
import (
"context"
"database/sql"
"sync/atomic"
"time"
"heckel.io/ntfy/v2/log"
)
const (
tag = "db"
replicaHealthCheckInitialDelay = 5 * time.Second
replicaHealthCheckInterval = 30 * time.Second
replicaHealthCheckTimeout = 10 * time.Second
)
// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
type DB struct {
primary *Host
replicas []*Host
counter atomic.Uint64
cancel context.CancelFunc
}
// New creates a new DB that wraps the given primary and optional replica connections.
// If replicas is nil or empty, ReadOnly() simply returns the primary.
// Replicas start unhealthy and are checked immediately by a background goroutine.
func New(primary *Host, replicas []*Host) *DB {
ctx, cancel := context.WithCancel(context.Background())
d := &DB{
primary: primary,
replicas: replicas,
cancel: cancel,
}
if len(d.replicas) > 0 {
go d.healthCheckLoop(ctx)
}
return d
}
// Query delegates to the primary database.
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
return d.primary.DB.Query(query, args...)
}
// QueryRow delegates to the primary database.
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
return d.primary.DB.QueryRow(query, args...)
}
// Exec delegates to the primary database.
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
return d.primary.DB.Exec(query, args...)
}
// Begin delegates to the primary database.
func (d *DB) Begin() (*sql.Tx, error) {
return d.primary.DB.Begin()
}
// Ping delegates to the primary database.
func (d *DB) Ping() error {
return d.primary.DB.Ping()
}
// Primary returns the underlying primary *sql.DB. This is only intended for
// one-time schema setup during store initialization, not for regular queries.
func (d *DB) Primary() *sql.DB {
return d.primary.DB
}
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
// replicas. If all replicas are unhealthy or none are configured, the primary is returned.
func (d *DB) ReadOnly() *sql.DB {
if len(d.replicas) == 0 {
return d.primary.DB
}
n := len(d.replicas)
start := int(d.counter.Add(1) - 1)
for i := 0; i < n; i++ {
r := d.replicas[(start+i)%n]
if r.healthy.Load() {
return r.DB
}
}
return d.primary.DB
}
// Close closes the primary database and all replicas, and stops the health-check goroutine.
func (d *DB) Close() error {
d.cancel()
for _, r := range d.replicas {
r.DB.Close()
}
return d.primary.DB.Close()
}
// healthCheckLoop checks replicas immediately, then periodically on a ticker.
func (d *DB) healthCheckLoop(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(replicaHealthCheckInitialDelay):
d.checkReplicas(ctx)
}
for {
select {
case <-ctx.Done():
return
case <-time.After(replicaHealthCheckInterval):
d.checkReplicas(ctx)
}
}
}
// checkReplicas pings each replica with a timeout and updates its health status.
func (d *DB) checkReplicas(ctx context.Context) {
for _, r := range d.replicas {
wasHealthy := r.healthy.Load()
pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout)
err := r.DB.PingContext(pingCtx)
cancel()
if err != nil {
r.healthy.Store(false)
log.Tag(tag).Error("Database replica %s is unhealthy: %s", r.Addr, err)
} else {
r.healthy.Store(true)
if !wasHealthy {
log.Tag(tag).Info("Database replica %s is healthy", r.Addr)
}
}
}
}

120
db/pg/pg.go Normal file
View File

@@ -0,0 +1,120 @@
package pg
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/db"
)
// Open opens a PostgreSQL connection pool for a primary database. It pings the database
// to verify connectivity before returning.
func Open(dsn string) (*db.Host, error) {
d, err := open(dsn)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := d.DB.Ping(); err != nil {
return nil, fmt.Errorf("database ping failed on %v: %w", d.Addr, err)
}
return d, nil
}
// OpenReplica opens a PostgreSQL connection pool for a read replica. Unlike Open, it does
// not ping the database, since replicas are health-checked in the background by db.DB.
func OpenReplica(dsn string) (*db.Host, error) {
return open(dsn)
}
// open opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func open(dsn string) (*db.Host, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
switch u.Scheme {
case "postgres", "postgresql":
// OK
default:
return nil, fmt.Errorf("invalid database URL scheme %q, must be \"postgres\" or \"postgresql\" (URL: %s)", u.Scheme, censorPassword(u))
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, "pool_max_conns", 10)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, "pool_max_idle_conns", 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, "pool_conn_max_lifetime", 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, "pool_conn_max_idle_time", 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
d, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
d.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
d.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
d.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
d.SetConnMaxIdleTime(connMaxIdleTime)
}
return &db.Host{
Addr: u.Host,
DB: d,
}, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
// censorPassword returns a string representation of the URL with the password replaced by "*****".
func censorPassword(u *url.URL) string {
if password, hasPassword := u.User.Password(); hasPassword {
return strings.Replace(u.String(), ":"+password+"@", ":*****@", 1)
}
return u.String()
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}

53
db/pg/pg_test.go Normal file
View File

@@ -0,0 +1,53 @@
package pg
import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestOpen_InvalidScheme(t *testing.T) {
_, err := Open("postgresql+psycopg2://user:pass@localhost/db")
require.Error(t, err)
require.Contains(t, err.Error(), `invalid database URL scheme "postgresql+psycopg2"`)
require.Contains(t, err.Error(), "*****")
require.NotContains(t, err.Error(), "pass")
}
func TestOpen_InvalidURL(t *testing.T) {
_, err := Open("not a valid url\x00")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid database URL")
}
func TestCensorPassword(t *testing.T) {
tests := []struct {
name string
url string
expected string
}{
{
name: "with password",
url: "postgres://user:secret@localhost/db",
expected: "postgres://user:*****@localhost/db",
},
{
name: "without password",
url: "postgres://localhost/db",
expected: "postgres://localhost/db",
},
{
name: "user only",
url: "postgres://user@localhost/db",
expected: "postgres://user@localhost/db",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, err := url.Parse(tt.url)
require.NoError(t, err)
require.Equal(t, tt.expected, censorPassword(u))
})
}
}

64
db/test/test.go Normal file
View File

@@ -0,0 +1,64 @@
package dbtest
import (
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/util"
)
const testPoolMaxConns = "2"
// CreateTestPostgresSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it.
// It registers a cleanup function to drop the schema when the test finishes.
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
func CreateTestPostgresSchema(t *testing.T) string {
t.Helper()
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set")
}
schema := fmt.Sprintf("test_%s", util.RandomString(10))
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode()
dsn = u.String()
setupHost, err := pg.Open(dsn)
require.Nil(t, err)
_, err = setupHost.DB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupHost.DB.Close())
q.Set("search_path", schema)
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanHost, err := pg.Open(dsn)
if err == nil {
cleanHost.DB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanHost.DB.Close()
}
})
return schemaDSN
}
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it.
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
func CreateTestPostgres(t *testing.T) *db.DB {
t.Helper()
schemaDSN := CreateTestPostgresSchema(t)
testHost, err := pg.Open(schemaDSN)
require.Nil(t, err)
d := db.New(testHost, nil)
t.Cleanup(func() {
d.Close()
})
return d
}

25
db/types.go Normal file
View File

@@ -0,0 +1,25 @@
package db
import (
"database/sql"
"sync/atomic"
)
// Beginner is an interface for types that can begin a database transaction.
// Both *sql.DB and *DB implement this.
type Beginner interface {
Begin() (*sql.Tx, error)
}
// Querier is an interface for types that can execute SQL queries.
// *sql.DB, *sql.Tx, and *DB all implement this.
type Querier interface {
Query(query string, args ...any) (*sql.Rows, error)
}
// Host pairs a *sql.DB with the host:port it was opened against.
type Host struct {
Addr string // "host:port"
DB *sql.DB
healthy atomic.Bool
}

36
db/util.go Normal file
View File

@@ -0,0 +1,36 @@
package db
import "database/sql"
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db Beginner, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// QueryTx executes a function within a database transaction and returns the result. If the function
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View File

@@ -53,6 +53,16 @@ Here are a few working sample configs using a `/etc/ntfy/server.yml` file:
behind-proxy: true
```
=== "server.yml (PostgreSQL, behind proxy)"
``` yaml
base-url: "https://ntfy.example.com"
listen-http: ":2586"
database-url: "postgres://ntfy:mypassword@db.example.com:5432/ntfy?sslmode=require"
attachment-cache-dir: "/var/cache/ntfy/attachments"
behind-proxy: true
auth-default-access: "deny-all"
```
=== "server.yml (ntfy.sh config)"
``` yaml
# All the things: Behind a proxy, Firebase, cache, attachments,
@@ -125,16 +135,349 @@ using Docker Compose (i.e. `docker-compose.yml`):
command: serve
```
## Config generator
This generator helps you configure your self-hosted ntfy instance. It's not fully featured, but it is a good starting point. Please refer to the relevant sections in the doc for more details.
<div style="text-align: center;">
<button type="button" id="cg-open-btn" class="cg-open-btn">Open config generator</button>
</div>
<figure markdown style="padding-left: 50px; padding-right: 50px; cursor: pointer;" onclick="document.getElementById('cg-open-btn').click();">
<img src="../../static/img/config-generator.png"/>
<figcaption>The config generator helps you create a custom config for your self-hosted ntfy instance. Click to open.</figcaption>
</figure>
<div id="cg-modal" class="cg-modal" style="display:none">
<div class="cg-modal-backdrop"></div>
<div class="cg-modal-dialog">
<div class="cg-modal-header">
<div class="cg-modal-header-left">
<span class="cg-modal-title">Config generator</span><span class="cg-badge-beta">BETA</span>
<span class="cg-modal-desc">This generator helps you configure your self-hosted ntfy instance. It's not fully featured, but it is a good starting point.</span>
</div>
<div class="cg-modal-header-actions">
<button type="button" id="cg-reset-btn" class="cg-modal-reset" title="Reset all values">Reset</button>
<button type="button" id="cg-close-btn" class="cg-modal-close" title="Close">&times;</button>
</div>
</div>
<div class="cg-modal-body">
<div class="cg-mobile-toggle">
<button class="cg-mobile-toggle-btn active" data-show="left">Edit</button>
<button class="cg-mobile-toggle-btn" data-show="right">Preview</button>
</div>
<div id="cg-left">
<div class="cg-nav">
<div class="cg-nav-tab active" data-panel="cg-panel-general">General</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-database" id="cg-nav-database">Database</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-auth" id="cg-nav-auth">Users</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-cache" id="cg-nav-cache">Message Cache</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-attach" id="cg-nav-attach">Attachments</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-webpush" id="cg-nav-webpush">Web Push</div>
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-email" id="cg-nav-email">Email</div>
</div>
<div class="cg-panels">
<div class="cg-panel active" id="cg-panel-general">
<div class="cg-field cg-inline-field">
<label>What URL will ntfy be reachable on?</label>
<input type="text" data-key="base-url" placeholder="https://ntfy.example.com">
</div>
<div class="cg-field cg-inline-field">
<label>Will ntfy run behind a proxy (e.g. nginx, Caddy)? <a href="/config/#behind-a-proxy-tls-etc" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-proxy" value="no" checked><span>No</span></label>
<label><input type="radio" name="cg-proxy" value="yes"><span>Yes</span></label>
</div>
</div>
<div class="cg-field cg-inline-field">
<label>Will this ntfy server be open or private? <a href="/config/#access-control" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-server-type" value="open" checked><span>Open</span></label>
<label><input type="radio" name="cg-server-type" value="private"><span>Private</span></label>
<label><input type="radio" name="cg-server-type" value="custom"><span>Custom</span></label>
</div>
</div>
<div class="cg-field cg-inline-field">
<label>Will iOS/iPhone users use this server? <a href="/config/#ios-instant-notifications" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-ios" value="no" checked><span>No</span></label>
<label><input type="radio" name="cg-ios" value="yes"><span>Yes</span></label>
</div>
</div>
<div class="cg-field cg-inline-field">
<label>Do you want to use ntfy as a UnifiedPush distributor? <a href="/config/#example-unifiedpush" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-unifiedpush" value="no" checked><span>No</span></label>
<label><input type="radio" name="cg-unifiedpush" value="yes"><span>Yes</span></label>
</div>
</div>
<div class="cg-field">
<label>Which features do you want to enable?</label>
<div class="cg-feature-grid">
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-auth"> User management and access control</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-auth">Configure</button></div>
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-cache"> Persistent message cache</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-cache">Configure</button></div>
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-attach"> Attachments</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-attach">Configure</button></div>
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-webpush"> Web push</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-webpush">Configure</button></div>
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-smtp-out"> Email notifications</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-email">Configure</button></div>
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-smtp-in"> Email publishing</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-email">Configure</button></div>
</div>
</div>
<div class="cg-field cg-inline-field cg-hidden" id="cg-wizard-db">
<label>Which database backend would you like to use? <a href="/config/#database-options" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-db-type" value="sqlite" checked><span>SQLite</span></label>
<label><input type="radio" name="cg-db-type" value="postgres"><span>PostgreSQL</span></label>
</div>
</div>
</div>
<div class="cg-panel" id="cg-panel-auth">
<div class="cg-panel-desc">Configure user management, access control, and provisioned users/ACLs. See <a href="/config/#access-control" target="_blank">access control</a> for details.</div>
<div class="cg-field cg-inline-field">
<label>Where should the user database be stored?</label>
<input type="text" data-key="auth-file" placeholder="/var/lib/ntfy/auth.db">
</div>
<div class="cg-field cg-inline-field">
<label>What should the default access policy be? <a href="/config/#access-control" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<select id="cg-default-access-select">
<option value="read-write" selected>Read &amp; Write</option>
<option value="read-only">Read Only</option>
<option value="write-only">Write Only</option>
<option value="deny-all">Deny All</option>
</select>
</div>
<div class="cg-field cg-inline-field">
<label>Should login to the web app be enabled?</label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-login-mode" value="disabled" checked><span>Disabled</span></label>
<label><input type="radio" name="cg-login-mode" value="enabled"><span>Enabled</span></label>
<label><input type="radio" name="cg-login-mode" value="required"><span>Required</span></label>
</div>
</div>
<div class="cg-field cg-inline-field">
<label>Should it be possible to sign up via the web app?</label>
<div class="cg-btn-group">
<label><input type="radio" name="cg-enable-signup" value="no" checked><span>No</span></label>
<label><input type="radio" name="cg-enable-signup" value="yes"><span>Yes</span></label>
</div>
</div>
<input type="hidden" data-key="auth-default-access">
<input type="checkbox" data-key="enable-login" id="cg-enable-login-hidden" style="display:none">
<input type="checkbox" data-key="require-login" id="cg-require-login-hidden" style="display:none">
<input type="checkbox" data-key="enable-signup" id="cg-enable-signup-hidden" style="display:none">
<div class="cg-field">
<label>Provisioned users <a href="/config/#users-and-roles" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-repeatable-container" id="cg-auth-users-container"></div>
<button type="button" class="cg-btn-add" data-add-type="user">+ Add user</button>
</div>
<div class="cg-field">
<label>Provisioned ACLs <a href="/config/#access-control-list-acl" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-repeatable-container" id="cg-auth-acls-container"></div>
<button type="button" class="cg-btn-add" data-add-type="acl">+ Add ACL</button>
</div>
<div class="cg-field">
<label>Provisioned tokens <a href="/config/#access-tokens" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
<div class="cg-repeatable-container" id="cg-auth-tokens-container"></div>
<button type="button" class="cg-btn-add" data-add-type="token">+ Add token</button>
</div>
</div>
<div class="cg-panel" id="cg-panel-cache">
<div class="cg-panel-desc">Configure the message cache to allow devices to retrieve missed notifications. See <a href="/config/#message-cache" target="_blank">message cache</a> for details.</div>
<div class="cg-field cg-inline-field" id="cg-cache-file-field">
<label>Where should the cache be stored?</label>
<input type="text" data-key="cache-file" placeholder="/var/cache/ntfy/cache.db">
</div>
<div class="cg-field cg-inline-field">
<label>How long should messages be cached?</label>
<input type="text" data-key="cache-duration" placeholder="12h">
</div>
</div>
<div class="cg-panel" id="cg-panel-attach">
<div class="cg-panel-desc">Allow users to upload and attach files to notifications. See <a href="/config/#attachments" target="_blank">attachments</a> for details.</div>
<div class="cg-field cg-inline-field">
<label>Where should attachments be stored?</label>
<input type="text" data-key="attachment-cache-dir" placeholder="/var/cache/ntfy/attachments">
</div>
<div class="cg-field cg-inline-field">
<label>Max file size per attachment?</label>
<input type="text" data-key="attachment-file-size-limit" placeholder="15M">
</div>
<div class="cg-field cg-inline-field">
<label>Total attachment storage limit?</label>
<input type="text" data-key="attachment-total-size-limit" placeholder="5G">
</div>
<div class="cg-field cg-inline-field">
<label>How long before attachments expire?</label>
<input type="text" data-key="attachment-expiry-duration" placeholder="3h">
</div>
</div>
<div class="cg-panel" id="cg-panel-webpush">
<div class="cg-panel-desc">Enable browser push notifications via the Web Push API. VAPID keys are generated automatically. See <a href="/config/#web-push" target="_blank">web push</a> for details.</div>
<div class="cg-field cg-inline-field">
<label>Where should web push data be stored?</label>
<input type="text" data-key="web-push-file" placeholder="/var/lib/ntfy/webpush.db">
</div>
<div class="cg-field cg-inline-field">
<label>Contact email address</label>
<input type="text" data-key="web-push-email-address" placeholder="admin@example.com">
</div>
<div class="cg-field cg-inline-field">
<label>Private key</label>
<input type="text" data-key="web-push-private-key" placeholder="Auto-generated" readonly>
</div>
<div class="cg-field cg-inline-field">
<label>Public key</label>
<input type="text" data-key="web-push-public-key" placeholder="Auto-generated" readonly>
</div>
<div class="cg-field cg-inline-field">
<label></label>
<button type="button" id="cg-regen-keys" class="cg-btn-add">Regenerate keys</button>
</div>
</div>
<div class="cg-panel" id="cg-panel-email">
<div class="cg-panel-desc">Configure outgoing email notifications and/or incoming email publishing. See <a href="/config/#e-mail-notifications" target="_blank">email notifications</a> and <a href="/config/#e-mail-publishing" target="_blank">email publishing</a> for details.</div>
<div id="cg-email-out-section" class="cg-hidden">
<div class="cg-field"><label><strong>Outgoing (notifications)</strong></label></div>
<div class="cg-field cg-inline-field">
<label>SMTP server address</label>
<input type="text" data-key="smtp-sender-addr" placeholder="smtp.example.com:587">
</div>
<div class="cg-field cg-inline-field">
<label>Sender email</label>
<input type="text" data-key="smtp-sender-from" placeholder="ntfy@example.com">
</div>
<div class="cg-field cg-inline-field">
<label>SMTP username</label>
<input type="text" data-key="smtp-sender-user" placeholder="Username">
</div>
<div class="cg-field cg-inline-field">
<label>SMTP password</label>
<input type="password" data-key="smtp-sender-pass" placeholder="Password">
</div>
</div>
<div id="cg-email-in-section" class="cg-hidden">
<div class="cg-field"><label><strong>Incoming (publishing)</strong></label></div>
<div class="cg-field cg-inline-field">
<label>Listen address</label>
<input type="text" data-key="smtp-server-listen" placeholder=":25">
</div>
<div class="cg-field cg-inline-field">
<label>Domain</label>
<input type="text" data-key="smtp-server-domain" placeholder="ntfy.example.com">
</div>
<div class="cg-field cg-inline-field">
<label>Address prefix</label>
<input type="text" data-key="smtp-server-addr-prefix" placeholder="ntfy-">
</div>
</div>
</div>
<div class="cg-panel" id="cg-panel-database">
<div class="cg-panel-desc">Configure the PostgreSQL connection. See <a href="/config/#postgresql-experimental" target="_blank">PostgreSQL</a> for details.</div>
<div class="cg-field">
<label>Database URL</label>
<input type="text" data-key="database-url" placeholder="postgres://user:pass@host:5432/ntfy">
</div>
</div>
<input type="hidden" data-key="upstream-base-url">
<input type="checkbox" data-key="behind-proxy" id="cg-behind-proxy" style="display:none">
</div>
</div>
<div id="cg-right">
<div class="cg-output-tabs">
<div class="cg-output-tab active" data-format="server-yml">server.yml</div>
<div class="cg-output-tab" data-format="docker-compose">docker-compose.yml</div>
<div class="cg-output-tab" data-format="env-vars">Env variables</div>
<button type="button" id="cg-copy-btn" class="cg-btn-copy" title="Copy to clipboard"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg></button>
</div>
<div class="cg-output-wrap">
<pre><code id="cg-code"><span class="cg-empty-msg">Configure options on the left to generate your config...</span></code></pre>
<div id="cg-warnings" class="cg-hidden"></div>
</div>
</div>
</div>
</div>
</div>
## Database options
ntfy uses a database for storing messages ([message cache](#message-cache)), users and [access control](#access-control), and [web push](#web-push) subscriptions.
You can choose between **SQLite** and **PostgreSQL** as the database backend.
### SQLite
By default, ntfy uses SQLite with separate database files for each store. This is the simplest setup and requires
no external dependencies:
* `cache-file`: Database file for the [message cache](#message-cache).
* `auth-file`: Database file for authentication and [access control](#access-control). If set, enables auth.
* `web-push-file`: Database file for [web push](#web-push) subscriptions.
### PostgreSQL (EXPERIMENTAL)
As an alternative, you can configure ntfy to use PostgreSQL for **all** database-backed stores by setting the
`database-url` option to a PostgreSQL connection string.
When `database-url` is set, ntfy will use PostgreSQL for the [message cache](#message-cache),
[access control](#access-control), and [web push](#web-push) subscriptions instead of SQLite. The `cache-file`,
`auth-file`, and `web-push-file` options **must not** be set in this case.
Note that setting `database-url` implicitly enables authentication and access control (equivalent to setting
`auth-file` with SQLite). The default access is `read-write`, so anonymous users can still read and write to all
topics. To restrict access, set `auth-default-access` to `deny-all` (see [access control](#access-control)).
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages, checking access permissions, etc)
are distributed across the replicas using round-robin, while all writes and correctness-critical reads continue to go
to the primary. If a replica becomes unhealthy, ntfy automatically falls back to the primary until the replica recovers.
You can also set this via the environment variable `NTFY_DATABASE_REPLICA_URLS` (comma-separated) or the command line
flag `--database-replica-urls`.
Examples:
=== "Simple"
```yaml
database-url: "postgres://user:pass@host:5432/ntfy"
```
=== "With SSL and pool tuning"
```yaml
database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50&pool_conn_max_idle_time=5m"
```
=== "With CA certificate"
```yaml
database-url: "postgres://user:pass@host:25060/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
```
=== "With read replicas"
```yaml
database-url: "postgres://user:pass@primary:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
database-replica-urls:
- "postgres://user:pass@replica1:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
- "postgres://user:pass@replica2:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
```
The database URL supports the standard [PostgreSQL connection parameters](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
as query parameters, such as `sslmode`, `connect_timeout`, `sslcert`, `sslkey`, `sslrootcert`, and `application_name`.
See the [pgx driver documentation](https://pkg.go.dev/github.com/jackc/pgx/v5) for the full list of supported parameters.
In addition, ntfy supports the following custom query parameters to tune the connection pool (these apply to both
the primary and replica URLs):
| Parameter | Default | Description |
|---------------------------|---------|----------------------------------------------------------------------------------|
| `pool_max_conns` | 10 | Maximum number of open connections to the database |
| `pool_max_idle_conns` | - | Maximum number of idle connections in the pool |
| `pool_conn_max_lifetime` | - | Maximum amount of time a connection may be reused (Go duration, e.g. `5m`, `1h`) |
| `pool_conn_max_idle_time` | - | Maximum amount of time a connection may be idle (Go duration, e.g. `30s`, `5m`) |
## Message cache
If desired, ntfy can temporarily keep notifications in an in-memory or an on-disk cache. Caching messages for a short period
of time is important to allow [phones](subscribe/phone.md) and other devices with brittle Internet connections to be able to retrieve
notifications that they may have missed.
By default, ntfy keeps messages **in-memory for 12 hours**, which means that **cached messages do not survive an application
restart**. You can override this behavior using the following config settings:
restart**. You can override this behavior by setting `cache-file` (SQLite) or `database-url` (PostgreSQL).
* `cache-file`: if set, ntfy will store messages in a SQLite based cache (default is empty, which means in-memory cache).
**This is required if you'd like messages to be retained across restarts**.
* `cache-duration`: defines the duration for which messages are stored in the cache (default is `12h`).
You can also entirely disable the cache by setting `cache-duration` to `0`. When the cache is disabled, messages are only
@@ -146,20 +489,23 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri
## Attachments
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
Once these options are set and the directory is writable by the server user, you can upload attachments via PUT.
this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored
either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`).
Once configured, you can upload attachments via PUT.
By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
feature) to download the file. The following config options are relevant to attachments:
* `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs
* `attachment-cache-dir` is the cache directory for attached files
* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G)
* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`)
* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`)
* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G)
* `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M)
* `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h)
Here's an example config using mostly the defaults (except for the cache directory, which is empty by default):
### Filesystem storage
Here's an example config using the local filesystem for attachment storage:
=== "/etc/ntfy/server.yml (minimal)"
``` yaml
@@ -178,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo
visitor-attachment-daily-bandwidth-limit: "500M"
```
### S3 storage
As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
The `attachment-s3-url` option uses the following format:
```
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
```
When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores).
=== "/etc/ntfy/server.yml (AWS S3)"
``` yaml
base-url: "https://ntfy.sh"
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1"
```
=== "/etc/ntfy/server.yml (MinIO/custom endpoint)"
``` yaml
base-url: "https://ntfy.sh"
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com"
```
Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit`
and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse.
@@ -185,14 +555,15 @@ and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is
By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how
ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization.
ntfy's auth is implemented with a simple [SQLite](https://www.sqlite.org/)-based backend. It implements two roles
(`user` and `admin`) and per-topic `read` and `write` permissions using an [access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list).
Access control entries can be applied to users as well as the special everyone user (`*`), which represents anonymous API access.
ntfy's auth implements two roles (`user` and `admin`) and per-topic `read` and `write` permissions using an
[access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list). Access control entries can be applied
to users as well as the special everyone user (`*`), which represents anonymous API access.
To set up auth, **configure the following options**:
* `auth-file` is the user/access database; it is created automatically if it doesn't already exist; suggested
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used)
* `auth-file` is the user/access database (SQLite); it is created automatically if it doesn't already exist; suggested
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used). Alternatively, if `database-url` is set,
auth is automatically enabled using PostgreSQL (see [database options](#database-options)).
* `auth-default-access` defines the default/fallback access if no access control entry is found; it can be
set to `read-write` (default), `read-only`, `write-only` or `deny-all`. **If you are setting up a private instance,
you'll want to set this to `deny-all`** (see [private instance example](#example-private-instance)).
@@ -454,7 +825,7 @@ Here's an example:
```
# Comma-separated list
NTFY_AUTH_FILE='/var/lib/ntfy/user.db'
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,ben:$2a$10$NKbrNb7HPMjtQXWJ0f1pouw03LDLT/WzlO9VAv44x84bRCkh19h6m:user'
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,backup-service:$2a$10$NKbrNb7HPMjtQXWJ0f1pouw03LDLT/WzlO9VAv44x84bRCkh19h6m:user'
NTFY_AUTH_TOKENS='phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76,backup-service:tk_f099we8uzj7xi5qshzajwp6jffvkz:Backup script'
```
@@ -470,7 +841,8 @@ and access tokens in the `auth-tokens` section (see [access tokens via the confi
Here's an example that defines a single admin user `phil` with the password `mypass`, and a regular user `backup-script`
with the password `backup-script`. The admin user has full access to all topics, while regular user can only
access the `backups` topic with read-write permissions. The `auth-default-access` is set to `deny-all`, which means
access the `backups` topic with read-write permissions. `phil` has a token `tk_3gd7d2yftt4b8ixyfe9mnmro88o76`
with the label "My personal token". The `auth-default-access` is set to `deny-all`, which means
that all other users and anonymous access are denied by default.
=== "Config via /etc/ntfy/server.yml"
@@ -481,7 +853,7 @@ that all other users and anonymous access are denied by default.
- "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin"
- "backup-script:$2a$10$/ehiQt.w7lhTmHXq.RNsOOkIwiPPeWFIzWYO3DRxNixnWKLX8.uj.:user"
auth-access:
- "backup-service:backups:rw"
- "backup-script:backups:rw"
auth-tokens:
- "phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76:My personal token"
```
@@ -491,7 +863,7 @@ that all other users and anonymous access are denied by default.
NTFY_AUTH_FILE='/var/lib/ntfy/user.db'
NTFY_AUTH_DEFAULT_ACCESS='deny-all'
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,backup-script:$2a$10$/ehiQt.w7lhTmHXq.RNsOOkIwiPPeWFIzWYO3DRxNixnWKLX8.uj.:user'
NTFY_AUTH_ACCESS='backup-service:backups:rw'
NTFY_AUTH_ACCESS='backup-script:backups:rw'
NTFY_AUTH_TOKENS='phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76:My personal token'
```
@@ -1141,12 +1513,15 @@ a database to keep track of the browser's subscriptions, and an admin email addr
- `web-push-public-key` is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
- `web-push-private-key` is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db` (not required if `database-url` is set)
- `web-push-email-address` is the admin email address send to the push provider, e.g. `sysadmin@example.com`
- `web-push-startup-queries` is an optional list of queries to run on startup`
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
Alternatively, you can use PostgreSQL instead of SQLite by setting `database-url`
(see [PostgreSQL database](#postgresql-experimental)).
Limitations:
- Like foreground browser notifications, background push notifications require the web app to be served over HTTPS. A _valid_
@@ -1172,9 +1547,10 @@ web-push-file: /var/cache/ntfy/webpush.db
web-push-email-address: sysadmin@example.com
```
The `web-push-file` is used to store the push subscriptions. Unused subscriptions will send out a warning after 55 days,
and will automatically expire after 60 days (default). If the gateway returns an error (e.g. 410 Gone when a user has unsubscribed),
subscriptions are also removed automatically.
The `web-push-file` is used to store the push subscriptions in a local SQLite database. Alternatively, if `database-url`
is set, subscriptions are stored in PostgreSQL and `web-push-file` is not required. Unused subscriptions will send out
a warning after 55 days, and will automatically expire after 60 days (default). If the gateway returns an error
(e.g. 410 Gone when a user has unsubscribed), subscriptions are also removed automatically.
The web app refreshes subscriptions on start and regularly on an interval, but this file should be persisted across restarts. If the subscription
file is deleted or lost, any web apps that aren't open will not receive new web push notifications until you open then.
@@ -1755,17 +2131,20 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for all database-backed stores (message cache, user manager, web push) instead of SQLite. See [database options](#database-options). |
| `database-replica-urls` | `NTFY_DATABASE_REPLICA_URLS` | *list of strings (connection URLs)* | - | PostgreSQL read replica connection strings. Non-critical read-only queries are distributed across replicas (round-robin) with automatic fallback to primary. Requires `database-url`. See [read replicas](#read-replicas). |
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) |
| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) |
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control. If set, enables authentication and access control. See [access control](#access-control). |
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). |
| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. |
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header |
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. |
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. |
| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. |
| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. |
| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. |
| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. |
@@ -1868,6 +2247,7 @@ OPTIONS:
--auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES]
--auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS]
--attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR]
--attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL]
--attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
--attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
--attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION]

View File

@@ -340,10 +340,6 @@ Then either follow the steps for building with or without Firebase.
Without Firebase, you may want to still change the default `app_base_url` in [values.xml](https://github.com/binwiederhier/ntfy-android/blob/main/app/src/main/res/values/values.xml)
if you're self-hosting the server. Then run:
```
# Remove Google dependencies (FCM)
sed -i -e '/google-services/d' build.gradle
sed -i -e '/google-services/d' app/build.gradle
# To build an unsigned .apk (app/build/outputs/apk/fdroid/*.apk)
./gradlew assembleFdroidRelease
@@ -351,6 +347,8 @@ sed -i -e '/google-services/d' app/build.gradle
./gradlew bundleFdroidRelease
```
The F-Droid flavor automatically excludes Google Services dependencies.
### Build Play flavor (FCM)
!!! info
I do build the ntfy Android app using IntelliJ IDEA (Android Studio), so I don't know if these Gradle commands will

View File

@@ -661,6 +661,8 @@ Add the following function and alias to your `.bashrc` or `.bash_profile`:
local token=$(< ~/.ntfy_token) # Securely read the token
local status_icon="$([ $exit_status -eq 0 ] && echo magic_wand || echo warning)"
local last_command=$(history | tail -n1 | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
# for zsh users, use the same sed pattern but get the history differently.
# local last_command=$(history "$HISTCMD" | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
curl -s -X POST "https://n.example.dev/alerts" \
-H "Authorization: Bearer $token" \
@@ -692,4 +694,4 @@ To test failure notifications:
false; alert # Always fails (exit 1)
ls --invalid; alert # Invalid option
cat nonexistent_file; alert # File not found
```
```

View File

@@ -30,37 +30,37 @@ deb/rpm packages.
=== "x86_64/amd64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.tar.gz
tar zxvf ntfy_2.16.0_linux_amd64.tar.gz
sudo cp -a ntfy_2.16.0_linux_amd64/ntfy /usr/local/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_amd64/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.tar.gz
tar zxvf ntfy_2.19.2_linux_amd64.tar.gz
sudo cp -a ntfy_2.19.2_linux_amd64/ntfy /usr/local/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_amd64/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "armv6"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.tar.gz
tar zxvf ntfy_2.16.0_linux_armv6.tar.gz
sudo cp -a ntfy_2.16.0_linux_armv6/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv6/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.tar.gz
tar zxvf ntfy_2.19.2_linux_armv6.tar.gz
sudo cp -a ntfy_2.19.2_linux_armv6/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv6/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "armv7/armhf"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.tar.gz
tar zxvf ntfy_2.16.0_linux_armv7.tar.gz
sudo cp -a ntfy_2.16.0_linux_armv7/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv7/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.tar.gz
tar zxvf ntfy_2.19.2_linux_armv7.tar.gz
sudo cp -a ntfy_2.19.2_linux_armv7/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv7/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "arm64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.tar.gz
tar zxvf ntfy_2.16.0_linux_arm64.tar.gz
sudo cp -a ntfy_2.16.0_linux_arm64/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_arm64/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.tar.gz
tar zxvf ntfy_2.19.2_linux_arm64.tar.gz
sudo cp -a ntfy_2.19.2_linux_arm64/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_arm64/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
@@ -116,7 +116,7 @@ Manually installing the .deb file:
=== "x86_64/amd64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -124,7 +124,7 @@ Manually installing the .deb file:
=== "armv6"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -132,7 +132,7 @@ Manually installing the .deb file:
=== "armv7/armhf"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -140,7 +140,7 @@ Manually installing the .deb file:
=== "arm64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -150,28 +150,28 @@ Manually installing the .deb file:
=== "x86_64/amd64"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "armv6"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "armv7/armhf"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "arm64"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
@@ -213,18 +213,18 @@ pkg install go-ntfy
## macOS
The [ntfy CLI](subscribe/cli.md) (`ntfy publish` and `ntfy subscribe` only) is supported on macOS as well.
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz),
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz),
extract it and place it somewhere in your `PATH` (e.g. `/usr/local/bin/ntfy`).
If run as `root`, ntfy will look for its config at `/etc/ntfy/client.yml`. For all other users, it'll look for it at
`~/Library/Application Support/ntfy/client.yml` (sample included in the tarball).
```bash
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz > ntfy_2.16.0_darwin_all.tar.gz
tar zxvf ntfy_2.16.0_darwin_all.tar.gz
sudo cp -a ntfy_2.16.0_darwin_all/ntfy /usr/local/bin/ntfy
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz > ntfy_2.19.2_darwin_all.tar.gz
tar zxvf ntfy_2.19.2_darwin_all.tar.gz
sudo cp -a ntfy_2.19.2_darwin_all/ntfy /usr/local/bin/ntfy
mkdir ~/Library/Application\ Support/ntfy
cp ntfy_2.16.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
cp ntfy_2.19.2_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
ntfy --help
```
@@ -245,7 +245,7 @@ brew install ntfy
The ntfy server and CLI are fully supported on Windows. You can run the ntfy server directly or as a Windows service.
To install, you can either
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_windows_amd64.zip),
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_windows_amd64.zip),
extract it and place the `ntfy.exe` binary somewhere in your `%Path%`.
* Or install ntfy from the [Scoop](https://scoop.sh) main repository via `scoop install ntfy`

View File

@@ -184,6 +184,8 @@ I've added a ⭐ to projects or posts that have a significant following, or had
- [BRun](https://github.com/cbrake/brun) - Native Linux automation platform connecting triggers to actions without containers (Go)
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A light weight, self-hosted notification Server for Honywell Galaxy Flex alarm systems (Python)
- [zabbix-ntfy](https://github.com/torgrimt/zabbix-ntfy) - Zabbix server Mediatype to add support for ntfy.sh services
## Blog + forum posts
@@ -304,7 +306,7 @@ ntfy community. Thanks to everyone running a public server. **You guys rock!**
| URL | Country |
|---------------------------------------------------|--------------------|
| [ntfy.sh](https://ntfy.sh/) (*Official*) | 🇺🇸 United States |
| [ntfy.tedomum.net](https://ntfy.tedomum.net/) | 🇫🇷 France |
| [ntfy.tedomum.fr](https://ntfy.tedomum.fr/) | 🇫🇷 France |
| [ntfy.jae.fi](https://ntfy.jae.fi/) | 🇫🇮 Finland |
| [ntfy.adminforge.de](https://ntfy.adminforge.de/) | 🇩🇪 Germany |
| [ntfy.envs.net](https://ntfy.envs.net) | 🇩🇪 Germany |

View File

@@ -2304,6 +2304,11 @@ _Supported on:_ :material-android: :material-firefox:
The `copy` action **copies a given value to the clipboard when the action button is tapped**. This is useful for
one-time passcodes, tokens, or any other value you want to quickly copy without opening the full notification.
!!! info
The copy action button is only shown in the web app and Android app notification list, **not** in browser desktop
notifications. This is because browsers do not allow clipboard access from notification actions without direct
user interaction with the page.
Here's an example using the [`X-Actions` header](#using-a-header):
=== "Command line (curl)"

View File

@@ -6,12 +6,145 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
| Component | Version | Release date |
|------------------|---------|--------------|
| ntfy server | v2.16.0 | Jan 19, 2026 |
| ntfy Android app | v1.22.2 | Jan 25, 2026 |
| ntfy server | v2.19.2 | Mar 16, 2026 |
| ntfy Android app | v1.24.0 | Mar 5, 2026 |
| ntfy iOS app | v1.3 | Nov 26, 2023 |
Please check out the release notes for [upcoming releases](#not-released-yet) below.
### ntfy server v2.19.2
Released March 16, 2026
This is another small bugfix release for PostgreSQL, avoiding races between primary and read replica, as well as to
further reduce primary load.
**Bug fixes + maintenance:**
* Fix race condition in web push subscription causing FK constraint violation when concurrent requests hit the same endpoint
* Route authorization query to read-only database replica to reduce primary database load
## ntfy server v2.19.1
Released March 15, 2026
This is a bugfix release to avoid PostgreSQL insert failures due to invalid UTF-8 messages. It also fixes `database-url`
validation incorrectly rejecting `postgresql://` connection strings.
**Bug fixes + maintenance:**
* Fix invalid UTF-8 in HTTP headers (e.g. Latin-1 encoded text) causing PostgreSQL insert failures and dropping entire message batches
* Fix `database-url` validation rejecting `postgresql://` connection strings ([#1657](https://github.com/binwiederhier/ntfy/issues/1657)/[#1658](https://github.com/binwiederhier/ntfy/pull/1658))
## ntfy server v2.19.0
Released March 15, 2026
This is a fast-follow release that enables Postgres read replica support.
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages,
checking access permissions, etc) are distributed across the replicas using round-robin, while all writes and
correctness-critical reads continue to go to the primary. If a replica becomes unhealthy, ntfy automatically falls back
to the primary until the replica recovers.
**Features:**
* Support [PostgreSQL read replicas](config.md#postgresql-experimental) for offloading non-critical read queries via `database-replica-urls` config option ([#1648](https://github.com/binwiederhier/ntfy/pull/1648))
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files ([#1654](https://github.com/binwiederhier/ntfy/pull/1654))
**Bug fixes + maintenance:**
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
## ntfy server v2.18.0
Released March 7, 2026
This is the biggest release I've ever done on the server. It's 14,997 added lines of code, and 10,202 lines removed, all from
one [pull request](https://github.com/binwiederhier/ntfy/pull/1619) that adds [PostgreSQL support](config.md#postgresql-experimental).
The code was written by Cursor and Claude, but reviewed and heavily tested over 2-3 weeks by me. I created comparison documents,
went through all queries multiple times and reviewed the logic over and over again. I also did load tests and manual regression tests,
which took lots of evenings.
I'll not instantly switch ntfy.sh over. Instead, I'm kindly asking the community to test the Postgres support and report back to me
if things are working (or not working). There is a [one-off migration tool](https://github.com/binwiederhier/ntfy/tree/main/tools/pgimport) (entirely written by AI) that you can use to migrate.
**Features:**
* Add experimental [PostgreSQL support](config.md#postgresql-experimental) as an alternative database backend (message cache, user manager, web push subscriptions) via `database-url` config option ([#1114](https://github.com/binwiederhier/ntfy/issues/1114)/[#1619](https://github.com/binwiederhier/ntfy/pull/1619), thanks to [@brettinternet](https://github.com/brettinternet) for reporting)
**Bug fixes + maintenance:**
* Preserve `<br>` line breaks in HTML-only emails received via SMTP ([#690](https://github.com/binwiederhier/ntfy/issues/690), [#1620](https://github.com/binwiederhier/ntfy/pull/1620), thanks to [@uzkikh](https://github.com/uzkikh) for the fix and to [@teastrainer](https://github.com/teastrainer) for reporting)
## ntfy Android v1.24.0
Released March 5, 2026
This is a tiny release that will revert the "reconnecting ..." behavior of the foreground notification. Lots of people
have complained about it, so I'm replacing it with a notification that shows up when the server connection has failed
for >15 minutes, hoping that people will be less annoyed by that.
**Features:**
* Show notification when connection to server has been lost for 15+ minutes, with dismiss, snooze and never-show-again actions
**Bug fixes + maintenance:**
* Fix crash in settings when fragment is detached during backup/restore or log operations
## ntfy Android v1.23.0
Released February 22, 2026
This release adds support for search within a topic, and adds [copy action](publish.md#copy-to-clipboard) support
to the Android app.
**Features:**
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
**Bug fixes + maintenance:**
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
## ntfy server v2.17.0
Released February 8, 2026
This release adds support for templating in the priority field, a new "copy" action button to copy values to the clipboard,
a red notification dot on the favicon for unread messages, and an admin-only version endpoint. It also includes several
crash fixes, web app improvements, and documentation updates.
❤️ If you like ntfy, **please consider sponsoring me** via [GitHub Sponsors](https://github.com/sponsors/binwiederhier), [Liberapay](https://en.liberapay.com/ntfy/), Bitcoin (`1626wjrw3uWk9adyjCfYwafw4sQWujyjn8`),
or by buying a [paid plan via the web app](https://ntfy.sh/app). ntfy will always remain open source.
**Features:**
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
* Server: Add admin-only `GET /v1/version` endpoint returning server version, build commit, and date ([#1599](https://github.com/binwiederhier/ntfy/issues/1599), thanks to [@crivchri](https://github.com/crivchri) for reporting)
* Server/Web: [Support "copy" action](publish.md#copy-to-clipboard) button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
**Bug fixes + maintenance:**
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
## ntfy Android app v1.22.2
Released January 20, 2026
@@ -1665,43 +1798,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
## Not released yet
### ntfy Android v1.23.x (UNRELEASED)
### ntfy server v2.20.x (UNRELEASED)
**Features:**
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option
**Bug fixes + maintenance:**
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
### ntfy server v2.17.x (UNRELEASED)
**Features:**
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
* Server/Web: Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
**Bug fixes + maintenance:**
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
* Reject invalid e-mail addresses (e.g. multiple comma-separated recipients) with HTTP 400

853
docs/static/css/config-generator.css vendored Normal file
View File

@@ -0,0 +1,853 @@
/* Config Generator */
/* Hidden utility */
.cg-hidden {
display: none !important;
}
/* Open button */
.cg-open-btn {
display: inline-block;
padding: 8px 20px;
background: var(--md-primary-fg-color);
color: #fff;
border: none;
border-radius: 4px;
font-size: 0.85rem;
font-weight: 500;
cursor: pointer;
font-family: inherit;
transition: opacity 0.15s;
}
.cg-open-btn:hover {
opacity: 0.85;
}
/* Modal overlay */
.cg-modal {
position: fixed;
inset: 0;
z-index: 1000;
}
.cg-modal-backdrop {
position: absolute;
inset: 0;
background: rgba(0, 0, 0, 0.5);
}
.cg-modal-dialog {
position: absolute;
inset: 24px;
display: flex;
flex-direction: column;
background: #fff;
border-radius: 10px;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.25);
overflow: hidden;
font-size: 0.78rem;
}
.cg-modal-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 12px 20px;
border-bottom: 1px solid #ddd;
flex-shrink: 0;
}
.cg-modal-header-left {
display: flex;
align-items: baseline;
gap: 12px;
min-width: 0;
}
.cg-modal-title {
font-weight: 600;
font-size: 0.95rem;
white-space: nowrap;
}
.cg-badge-beta {
display: inline-block;
padding: 1px 8px;
margin-left: 8px;
background: var(--md-primary-fg-color);
color: #fff;
font-size: 0.6rem;
font-weight: 600;
border-radius: 10px;
letter-spacing: 0.5px;
vertical-align: middle;
}
.cg-modal-desc {
font-size: 0.75rem;
color: #888;
}
.cg-modal-header-actions {
display: flex;
align-items: center;
gap: 8px;
}
.cg-modal-reset {
background: none;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 0.75rem;
color: #777;
cursor: pointer;
padding: 4px 12px;
font-family: inherit;
transition: color 0.15s, border-color 0.15s;
}
.cg-modal-reset:hover {
color: #333;
border-color: #999;
}
.cg-modal-close {
background: none;
border: none;
font-size: 1.4rem;
color: #999;
cursor: pointer;
padding: 0 4px;
line-height: 1;
}
.cg-modal-close:hover {
color: #333;
}
/* Modal body: left + right */
.cg-modal-body {
display: flex;
flex: 1;
min-height: 0;
overflow: hidden;
}
/* Left panel */
#cg-left {
flex: 1;
display: flex;
flex-direction: column;
border-right: 1px solid #ddd;
min-width: 0;
}
.cg-nav {
display: flex;
flex-wrap: wrap;
gap: 0;
border-bottom: 1px solid #ddd;
flex-shrink: 0;
padding: 0 16px;
}
.cg-nav-tab {
padding: 9px 14px;
cursor: pointer;
font-size: 0.78rem;
font-weight: 500;
color: #777;
border-bottom: 2px solid transparent;
margin-bottom: -1px;
user-select: none;
transition: color 0.15s, border-color 0.15s;
white-space: nowrap;
}
.cg-nav-tab:hover {
color: #444;
}
.cg-nav-tab.active {
color: var(--md-primary-fg-color);
border-bottom-color: var(--md-primary-fg-color);
}
.cg-panels {
flex: 1;
overflow-y: auto;
padding: 16px 20px;
}
.cg-panel {
display: none;
}
.cg-panel.active {
display: block;
}
.cg-panel-desc {
font-size: 0.75rem;
color: #888;
margin-bottom: 12px;
line-height: 1.5;
}
.cg-panel-desc a {
color: var(--md-primary-fg-color);
}
.cg-help {
color: var(--md-primary-fg-color);
text-decoration: none;
margin-left: 4px;
vertical-align: middle;
flex-shrink: 0;
transition: color 0.15s;
}
.cg-help:hover {
color: var(--md-primary-fg-color--dark, #2a6e5f);
}
/* Right panel */
#cg-right {
flex: 1;
display: flex;
flex-direction: column;
min-width: 0;
}
.cg-output-tabs {
display: flex;
border-bottom: 1px solid #ddd;
flex-shrink: 0;
padding: 0 16px;
}
.cg-output-tab {
padding: 9px 14px;
cursor: pointer;
font-size: 0.78rem;
font-weight: 500;
border-bottom: 2px solid transparent;
margin-bottom: -1px;
color: #777;
transition: color 0.15s, border-color 0.15s;
user-select: none;
white-space: nowrap;
}
.cg-output-tab:hover {
color: #444;
}
.cg-output-tab.active {
color: var(--md-primary-fg-color);
border-bottom-color: var(--md-primary-fg-color);
}
.cg-btn-copy {
margin-left: auto;
background: none;
color: #777;
border: none;
border-bottom: 2px solid transparent;
margin-bottom: -1px;
padding: 9px 10px;
cursor: pointer;
line-height: 1;
display: flex;
align-items: center;
justify-content: center;
transition: color 0.15s;
}
.cg-btn-copy:hover {
color: #333;
}
.cg-output-wrap {
flex: 1;
overflow: auto;
padding: 16px 20px;
display: flex;
flex-direction: column;
}
.cg-output-wrap pre {
margin: 0;
padding: 8px 10px;
background: #f5f5f5;
color: #333;
border: 1px solid #ddd;
border-radius: 6px;
overflow-x: auto;
font-size: 0.76rem;
line-height: 1.5;
flex: 1;
white-space: pre;
}
.cg-empty-msg {
color: #888;
font-style: italic;
}
.cg-warning {
padding: 6px 10px;
margin-top: 8px;
background: #fff3cd;
color: #856404;
border: 1px solid #ffc107;
border-radius: 4px;
font-size: 0.76rem;
}
/* Form fields */
.cg-field {
margin-bottom: 0;
padding: 8px 12px;
}
.cg-field:nth-child(odd) {
background: #f8f8f8;
}
.cg-field:nth-child(even) {
background: #fff;
}
.cg-field > label {
display: block;
font-weight: 500;
margin-bottom: 4px;
font-size: 0.78rem;
color: #555;
}
.cg-field input[type="text"],
.cg-field input[type="password"],
.cg-field select {
width: 100%;
padding: 6px 8px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 0.78rem;
font-family: inherit;
box-sizing: border-box;
background: #fff;
}
.cg-field input[type="text"]:focus,
.cg-field input[type="password"]:focus,
.cg-field select:focus {
border-color: var(--md-primary-fg-color);
outline: none;
box-shadow: 0 0 0 2px rgba(51, 133, 116, 0.15);
}
.cg-checkbox {
display: flex;
align-items: center;
gap: 6px;
margin-bottom: 10px;
}
.cg-checkbox input[type="checkbox"] {
accent-color: var(--md-primary-fg-color);
}
.cg-checkbox label {
font-weight: 500;
font-size: 0.78rem;
margin: 0;
cursor: pointer;
}
.cg-radio-group {
display: flex;
flex-direction: column;
gap: 4px;
}
.cg-radio-group label {
display: flex;
align-items: center;
gap: 4px;
font-weight: 400;
font-size: 0.78rem;
cursor: pointer;
}
.cg-radio-group input[type="radio"] {
accent-color: var(--md-primary-fg-color);
}
/* Inline field: label + control side by side */
.cg-inline-field {
display: flex;
align-items: center;
gap: 12px;
}
.cg-inline-field > label {
margin-bottom: 0;
width: 60%;
flex-shrink: 0;
}
.cg-panel:not(#cg-panel-general) .cg-inline-field > label {
width: 50%;
}
.cg-inline-field > input[type="text"],
.cg-inline-field > select {
padding: 4px 10px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 0.75rem;
font-family: inherit;
box-sizing: border-box;
line-height: 1.4;
background: #fff;
}
.cg-inline-field > input[type="text"]:focus,
.cg-inline-field > select:focus {
border-color: var(--md-primary-fg-color);
outline: none;
box-shadow: 0 0 0 2px rgba(51, 133, 116, 0.15);
}
#cg-email-in-section {
margin-top: 20px;
}
.cg-pg-label {
font-size: 0.75rem;
color: #888;
font-style: italic;
}
/* Button group toggle */
.cg-btn-group {
display: flex;
flex-shrink: 0;
}
.cg-btn-group label {
cursor: pointer;
margin: 0;
}
.cg-btn-group input[type="radio"] {
display: none;
}
.cg-btn-group span {
display: block;
padding: 4px 14px;
font-size: 0.75rem;
font-weight: 500;
line-height: 1.4;
border: 1px solid #ccc;
color: #555;
background: #fff;
transition: background 0.15s, color 0.15s, border-color 0.15s;
user-select: none;
}
.cg-btn-group label:first-child span {
border-radius: 4px 0 0 4px;
}
.cg-btn-group label:last-child span {
border-radius: 0 4px 4px 0;
}
.cg-btn-group label + label span {
margin-left: -1px;
}
.cg-btn-group input[type="radio"]:checked + span {
background: var(--md-primary-fg-color);
color: #fff;
border-color: var(--md-primary-fg-color);
position: relative;
z-index: 1;
}
.cg-feature-grid {
display: flex;
flex-direction: column;
gap: 5px;
}
.cg-feature-row {
display: flex;
align-items: center;
gap: 8px;
}
.cg-feature-grid label {
display: flex;
align-items: center;
gap: 6px;
font-size: 0.78rem;
cursor: pointer;
}
.cg-feature-grid input[type="checkbox"] {
accent-color: var(--md-primary-fg-color);
}
.cg-btn-configure {
background: none;
border: 1px solid var(--md-primary-fg-color);
border-radius: 10px;
color: var(--md-primary-fg-color);
font-size: 0.68rem;
font-family: inherit;
padding: 1px 10px;
cursor: pointer;
white-space: nowrap;
transition: background 0.15s, color 0.15s;
}
.cg-btn-configure:hover {
background: var(--md-primary-fg-color);
color: #fff;
}
/* Repeatable rows */
.cg-repeatable-row {
display: flex;
gap: 6px;
align-items: center;
margin-bottom: 6px;
flex-wrap: wrap;
}
.cg-repeatable-row input,
.cg-repeatable-row select {
flex: 1;
min-width: 80px;
padding: 5px 6px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 0.75rem;
font-family: inherit;
box-sizing: border-box;
}
.cg-repeatable-row input:focus,
.cg-repeatable-row select:focus {
border-color: var(--md-primary-fg-color);
outline: none;
}
.cg-repeatable-row input:disabled,
.cg-repeatable-row select:disabled {
background: #eee;
color: #999;
cursor: not-allowed;
}
.cg-btn-remove {
background: none;
border: 1px solid #ccc;
border-radius: 4px;
cursor: pointer;
font-size: 0.9rem;
padding: 5px 8px;
color: #999;
line-height: 1;
}
.cg-btn-remove:hover {
background: #fee;
border-color: #c66;
color: #c33;
}
.cg-btn-add {
background: none;
border: 1px dashed #bbb;
border-radius: 4px;
cursor: pointer;
padding: 5px 10px;
font-size: 0.75rem;
color: #777;
margin-top: 2px;
}
.cg-btn-add:hover {
border-color: var(--md-primary-fg-color);
color: var(--md-primary-fg-color);
}
/* Dark mode */
body[data-md-color-scheme="slate"] .cg-modal-dialog {
background: #1e1e2e;
}
body[data-md-color-scheme="slate"] .cg-modal-header {
border-bottom-color: #444;
}
body[data-md-color-scheme="slate"] .cg-modal-title {
color: #ddd;
}
body[data-md-color-scheme="slate"] .cg-modal-desc {
color: #777;
}
body[data-md-color-scheme="slate"] .cg-modal-reset {
border-color: #555;
color: #888;
}
body[data-md-color-scheme="slate"] .cg-modal-reset:hover {
border-color: #888;
color: #ddd;
}
body[data-md-color-scheme="slate"] .cg-modal-close {
color: #777;
}
body[data-md-color-scheme="slate"] .cg-modal-close:hover {
color: #ddd;
}
body[data-md-color-scheme="slate"] #cg-left {
border-right-color: #444;
}
body[data-md-color-scheme="slate"] .cg-nav {
border-bottom-color: #444;
}
body[data-md-color-scheme="slate"] .cg-nav-tab {
color: #888;
}
body[data-md-color-scheme="slate"] .cg-nav-tab:hover {
color: #bbb;
}
body[data-md-color-scheme="slate"] .cg-output-tabs {
border-bottom-color: #444;
}
body[data-md-color-scheme="slate"] .cg-output-tab {
color: #888;
}
body[data-md-color-scheme="slate"] .cg-output-tab:hover {
color: #bbb;
}
body[data-md-color-scheme="slate"] .cg-btn-copy {
color: #777;
}
body[data-md-color-scheme="slate"] .cg-btn-copy:hover {
color: #bbb;
}
body[data-md-color-scheme="slate"] .cg-output-wrap pre {
background: #161620;
color: #ddd;
border-color: #444;
}
body[data-md-color-scheme="slate"] .cg-field:nth-child(odd) {
background: #232334;
}
body[data-md-color-scheme="slate"] .cg-field:nth-child(even) {
background: #1e1e2e;
}
body[data-md-color-scheme="slate"] .cg-panel-desc {
color: #777;
}
body[data-md-color-scheme="slate"] .cg-field > label {
color: #aaa;
}
body[data-md-color-scheme="slate"] .cg-field input[type="text"],
body[data-md-color-scheme="slate"] .cg-field input[type="password"],
body[data-md-color-scheme="slate"] .cg-field select {
background: #2a2a3a;
border-color: #555;
color: #ddd;
}
body[data-md-color-scheme="slate"] .cg-btn-group span {
background: #2a2a3a;
border-color: #555;
color: #aaa;
}
body[data-md-color-scheme="slate"] .cg-btn-group input[type="radio"]:checked + span {
background: var(--md-primary-fg-color);
color: #fff;
border-color: var(--md-primary-fg-color);
}
body[data-md-color-scheme="slate"] .cg-checkbox label {
color: #ccc;
}
body[data-md-color-scheme="slate"] .cg-radio-group label {
color: #ccc;
}
body[data-md-color-scheme="slate"] .cg-feature-grid label {
color: #ccc;
}
body[data-md-color-scheme="slate"] .cg-repeatable-row input,
body[data-md-color-scheme="slate"] .cg-repeatable-row select {
background: #2a2a3a;
border-color: #555;
color: #ddd;
}
body[data-md-color-scheme="slate"] .cg-repeatable-row input:disabled,
body[data-md-color-scheme="slate"] .cg-repeatable-row select:disabled {
background: #1a1a28;
color: #666;
}
body[data-md-color-scheme="slate"] .cg-btn-remove {
border-color: #555;
color: #888;
}
body[data-md-color-scheme="slate"] .cg-btn-add {
border-color: #555;
color: #888;
}
body[data-md-color-scheme="slate"] .cg-warning {
background: #3a2e00;
color: #ffc107;
border-color: #665200;
}
/* Mobile toggle bar (hidden on desktop) */
.cg-mobile-toggle {
display: none;
}
/* Responsive */
@media (max-width: 900px) {
.cg-modal-dialog {
inset: 0;
border-radius: 0;
}
.cg-modal-header {
padding: 8px 16px;
}
.cg-modal-title {
font-size: 0.85rem;
}
.cg-modal-desc {
display: none;
}
.cg-modal-body {
flex-direction: column;
}
.cg-mobile-toggle {
display: flex;
flex-shrink: 0;
border-bottom: 1px solid #ddd;
}
.cg-mobile-toggle-btn {
flex: 1;
padding: 8px 0;
border: none;
background: #f5f5f5;
font-size: 0.78rem;
font-weight: 500;
font-family: inherit;
color: #777;
cursor: pointer;
transition: background 0.15s, color 0.15s;
}
.cg-mobile-toggle-btn.active {
background: #fff;
color: var(--md-primary-fg-color);
box-shadow: inset 0 -2px 0 var(--md-primary-fg-color);
}
#cg-left {
border-right: none;
flex: 1;
min-height: 0;
}
#cg-right {
flex: 1;
display: none;
min-height: 0;
}
#cg-right.cg-mobile-active {
display: flex;
}
#cg-left.cg-mobile-hidden {
display: none;
}
.cg-nav {
overflow-x: auto;
flex-wrap: nowrap;
-webkit-overflow-scrolling: touch;
}
.cg-inline-field {
flex-direction: column;
align-items: stretch;
gap: 4px;
}
.cg-inline-field > label {
width: 100%;
}
.cg-panel:not(#cg-panel-general) .cg-inline-field > label {
width: 100%;
}
}
/* Dark mode mobile toggle */
body[data-md-color-scheme="slate"] .cg-mobile-toggle {
border-bottom-color: #444;
}
body[data-md-color-scheme="slate"] .cg-mobile-toggle-btn {
background: #2a2a3a;
color: #888;
}
body[data-md-color-scheme="slate"] .cg-mobile-toggle-btn.active {
background: #1e1e2e;
color: var(--md-primary-fg-color);
}

BIN
docs/static/img/config-generator.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 318 KiB

1220
docs/static/js/bcrypt.js vendored Normal file

File diff suppressed because it is too large Load Diff

1357
docs/static/js/config-generator.js vendored Normal file

File diff suppressed because it is too large Load Diff

66
go.mod
View File

@@ -1,25 +1,25 @@
module heckel.io/ntfy/v2
go 1.24.6
go 1.25.0
require (
cloud.google.com/go/firestore v1.21.0 // indirect
cloud.google.com/go/storage v1.59.2 // indirect
cloud.google.com/go/storage v1.61.3 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
github.com/emersion/go-smtp v0.18.0
github.com/gabriel-vasile/mimetype v1.4.13
github.com/gorilla/websocket v1.5.3
github.com/mattn/go-sqlite3 v1.14.33
github.com/mattn/go-sqlite3 v1.14.34
github.com/olebedev/when v1.1.0
github.com/stretchr/testify v1.11.1
github.com/urfave/cli/v2 v2.27.7
golang.org/x/crypto v0.47.0
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0
golang.org/x/time v0.14.0
google.golang.org/api v0.265.0
golang.org/x/crypto v0.49.0
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
golang.org/x/time v0.15.0
google.golang.org/api v0.271.0
gopkg.in/yaml.v2 v2.4.0
)
@@ -30,17 +30,18 @@ require github.com/pkg/errors v0.9.1 // indirect
require (
firebase.google.com/go/v4 v4.19.0
github.com/SherClockHolmes/webpush-go v1.4.0
github.com/jackc/pgx/v5 v5.8.0
github.com/microcosm-cc/bluemonday v1.0.27
github.com/prometheus/client_golang v1.23.2
github.com/stripe/stripe-go/v74 v74.30.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
golang.org/x/sys v0.42.0
golang.org/x/text v0.35.0
)
require (
cel.dev/expr v0.25.1 // indirect
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/auth v0.18.1 // indirect
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
cloud.google.com/go/iam v1.5.3 // indirect
@@ -57,8 +58,8 @@ require (
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect
@@ -68,35 +69,38 @@ require (
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.18.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/prometheus/procfs v0.20.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
go.opentelemetry.io/otel v1.40.0 // indirect
go.opentelemetry.io/otel/metric v1.40.0 // indirect
go.opentelemetry.io/otel/sdk v1.40.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
go.opentelemetry.io/otel/trace v1.40.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/net v0.49.0 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.42.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/otel v1.42.0 // indirect
go.opentelemetry.io/otel/metric v1.42.0 // indirect
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.yaml.in/yaml/v2 v2.4.4 // indirect
golang.org/x/net v0.52.0 // indirect
google.golang.org/appengine/v2 v2.0.6 // indirect
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c // indirect
google.golang.org/grpc v1.79.2 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

141
go.sum
View File

@@ -2,8 +2,8 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568 h1:PJt3KrySfZkKdcEV2wlyNkfAPbMZGjtnv5oLrT4tWPg=
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568/go.mod h1:/Tt0rLCp4FHXEBtdyYqvIZPcJzbpJ/fmqtgIaXseDK4=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
@@ -12,14 +12,14 @@ cloud.google.com/go/firestore v1.21.0 h1:BhopUsx7kh6NFx77ccRsHhrtkbJUmDAxNY3uapW
cloud.google.com/go/firestore v1.21.0/go.mod h1:1xH6HNcnkf/gGyR8udd6pFO4Z7GWJSwLKQMx/u6UrP4=
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
cloud.google.com/go/logging v1.13.1 h1:O7LvmO0kGLaHY/gq8cV7T0dyp6zJhYAOtZPX4TF3QtY=
cloud.google.com/go/logging v1.13.1/go.mod h1:XAQkfkMBxQRjQek96WLPNze7vsOmay9H5PqfsNYDqvw=
cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA=
cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak=
cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8=
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
cloud.google.com/go/storage v1.59.2 h1:gmOAuG1opU8YvycMNpP+DvHfT9BfzzK5Cy+arP+Nocw=
cloud.google.com/go/storage v1.59.2/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI=
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8=
@@ -58,14 +58,14 @@ github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 h1:oP4q0fw+fOSWn3
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
github.com/emersion/go-smtp v0.17.0 h1:tq90evlrcyqRfE6DSXaWVH54oX6OuZOQECEmhWBMEtI=
github.com/emersion/go-smtp v0.17.0/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ=
github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329 h1:K+fnvUM0VZ7ZFJf0n4L/BRlnsb9pL/GuDG6FqaH+PwM=
github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329/go.mod h1:Alz8LEClvR7xKsrq3qzoc4N0guvVNSS8KmSChGYr9hs=
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ=
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
@@ -96,14 +96,22 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc=
github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI=
github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -112,8 +120,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
@@ -133,8 +141,8 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
@@ -144,6 +152,7 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
@@ -156,36 +165,36 @@ github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBi
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/rxAWVlHsHIZ3fT2sA=
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 h1:XmiuHzgJt067+a6kwyAzkhXooYVv3/TOw9cM2VfJgUM=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0/go.mod h1:KDgtbWKTQs4bM+VPUr6WlL9m/WXcmkCcBlIzqxPGzmI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0 h1:wm/Q0GAAykXv83wzcKzGGqAnnfLFyFe7RslekZuv+VI=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0/go.mod h1:ra3Pa40+oKjvYh+ZD3EdxFZZB0xdMfuileHAm4nNN7w=
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
go.opentelemetry.io/contrib/detectors/gcp v1.42.0 h1:kpt2PEJuOuqYkPcktfJqWWDjTEd/FNgrxcniL7kQrXQ=
go.opentelemetry.io/contrib/detectors/gcp v1.42.0/go.mod h1:W9zQ439utxymRrXsUOzZbFX4JhLxXU4+ZnCt8GG7yA8=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
@@ -200,10 +209,10 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -211,8 +220,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -225,8 +234,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -236,8 +245,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -249,10 +258,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
@@ -263,18 +272,18 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/api v0.265.0 h1:FZvfUdI8nfmuNrE34aOWFPmLC+qRBEiNm3JdivTvAAU=
google.golang.org/api v0.265.0/go.mod h1:uAvfEl3SLUj/7n6k+lJutcswVojHPp2Sp08jWCu8hLY=
google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY=
google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q=
google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw=
google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI=
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20 h1:/CU1zrxTpGylJJbe3Ru94yy6sZRbzALq2/oxl3pGB3U=
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20/go.mod h1:Tt+08/KdKEt3l8x3Pby3HLQxMB3uk/MzaQ4ZIv0ORTs=
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 h1:7ei4lp52gK1uSejlA8AZl5AJjeLUOHBQscRQZUgAcu0=
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20/go.mod h1:ZdbssH/1SOVnjnDlXzxDHK2MCidiqXtbYccJNzNYPEE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c h1:ZhFDeBMmFc/4g8/GwxnJ4rzB3O4GwQVNr+8Mh7Y5z4g=
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c/go.mod h1:hf4r/rBuzaTkLUWRO03771Xvcs6P5hwdQK3UUEJjqo0=
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c h1:OyQPd6I3pN/9gDxz6L13kYGJgqkpdrAohJRBeXyxlgI=
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c/go.mod h1:X2gu9Qwng7Nn009s/r3RUxqkzQNqOrAy79bluY7ojIg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c h1:xgCzyF2LFIO/0X2UAoVRiXKU5Xg6VjToG4i2/ecSswk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU=
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=

592
message/cache.go Normal file
View File

@@ -0,0 +1,592 @@
package message
import (
"database/sql"
"encoding/json"
"errors"
"net/netip"
"strings"
"sync"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
const (
tagMessageCache = "message_cache"
)
var errNoRows = errors.New("no rows found")
// queries holds the database-specific SQL queries
type queries struct {
insertMessage string
deleteMessage string
selectScheduledMessageIDsBySeqID string
deleteScheduledBySequenceID string
updateMessagesForTopicExpiry string
selectMessagesByID string
selectMessagesSinceTime string
selectMessagesSinceTimeScheduled string
selectMessagesSinceID string
selectMessagesSinceIDScheduled string
selectMessagesLatest string
selectMessagesDue string
selectMessagesExpired string
updateMessagePublished string
selectMessagesCount string
selectTopics string
updateAttachmentDeleted string
selectAttachmentsExpired string
selectAttachmentsSizeBySender string
selectAttachmentsSizeByUserID string
selectStats string
updateStats string
updateMessageTime string
}
// Cache stores published messages
type Cache struct {
db *db.DB
queue *util.BatchingQueue[*model.Message]
nop bool
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
queries queries
}
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
var queue *util.BatchingQueue[*model.Message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
}
c := &Cache{
db: db,
queue: queue,
nop: nop,
mu: mu,
queries: queries,
}
go c.processMessageBatches()
return c
}
func (c *Cache) maybeLock() {
if c.mu != nil {
c.mu.Lock()
}
}
func (c *Cache) maybeUnlock() {
if c.mu != nil {
c.mu.Unlock()
}
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *Cache) AddMessage(m *model.Message) error {
if c.queue != nil {
c.queue.Enqueue(m)
return nil
}
return c.addMessages([]*model.Message{m})
}
// AddMessages synchronously stores a batch of messages to the message cache
func (c *Cache) AddMessages(ms []*model.Message) error {
return c.addMessages(ms)
}
func (c *Cache) addMessages(ms []*model.Message) error {
c.maybeLock()
defer c.maybeUnlock()
if c.nop {
return nil
}
if len(ms) == 0 {
return nil
}
start := time.Now()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(c.queries.insertMessage)
if err != nil {
return err
}
defer stmt.Close()
for _, m := range ms {
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
return model.ErrUnexpectedMessageType
}
published := m.Time <= time.Now().Unix()
tags := util.SanitizeUTF8(strings.Join(m.Tags, ","))
var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
if m.Attachment != nil {
attachmentName = util.SanitizeUTF8(m.Attachment.Name)
attachmentType = util.SanitizeUTF8(m.Attachment.Type)
attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires
attachmentURL = util.SanitizeUTF8(m.Attachment.URL)
}
var actionsStr string
if len(m.Actions) > 0 {
actionsBytes, err := json.Marshal(m.Actions)
if err != nil {
return err
}
actionsStr = string(actionsBytes)
}
var sender string
if m.Sender.IsValid() {
sender = m.Sender.String()
}
_, err := stmt.Exec(
m.ID,
m.SequenceID,
m.Time,
m.Event,
m.Expires,
util.SanitizeUTF8(m.Topic),
util.SanitizeUTF8(m.Message),
util.SanitizeUTF8(m.Title),
m.Priority,
tags,
util.SanitizeUTF8(m.Click),
util.SanitizeUTF8(m.Icon),
actionsStr,
attachmentName,
attachmentType,
attachmentSize,
attachmentExpires,
attachmentURL,
attachmentDeleted, // Always zero
sender,
m.User,
util.SanitizeUTF8(m.ContentType),
m.Encoding,
published,
)
if err != nil {
return err
}
}
if err := tx.Commit(); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
return err
}
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
return nil
}
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
if since.IsNone() {
return make([]*model.Message, 0), nil
} else if since.IsLatest() {
return c.messagesLatest(topic)
} else if since.IsID() {
return c.messagesSinceID(topic, since, scheduled)
}
return c.messagesSinceTime(topic, since, scheduled)
}
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
rdb := c.db.ReadOnly()
if scheduled {
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
} else {
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
rdb := c.db.ReadOnly()
if scheduled {
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
} else {
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
// MessagesDue returns all messages that are due for publishing
func (c *Cache) MessagesDue() ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
if err != nil {
return nil, err
}
return readMessages(rows)
}
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
func (c *Cache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
// Message returns the message with the given ID, or ErrMessageNotFound if not found
func (c *Cache) Message(id string) (*model.Message, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, model.ErrMessageNotFound
}
defer rows.Close()
return readMessage(rows)
}
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
return err
}
// MarkPublished marks a message as published
func (c *Cache) MarkPublished(m *model.Message) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
return err
}
// MessagesCount returns the total number of messages in the cache
func (c *Cache) MessagesCount() (int, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// Topics returns a list of all topics with messages in the cache
func (c *Cache) Topics() ([]string, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
if err != nil {
return nil, err
}
defer rows.Close()
topics := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
topics = append(topics, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return topics, nil
}
// DeleteMessages deletes the messages with the given IDs
func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
}
}
return nil
})
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock()
defer c.maybeUnlock()
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
}
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock()
defer c.maybeUnlock()
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
return nil
})
}
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
func (c *Cache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
}
}
return nil
})
}
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close()
var size int64
if !rows.Next() {
return 0, errors.New("no rows found")
}
if err := rows.Scan(&size); err != nil {
return 0, err
} else if err := rows.Err(); err != nil {
return 0, err
}
return size, nil
}
// UpdateStats updates the total message count statistic
func (c *Cache) UpdateStats(messages int64) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
// Stats returns the total message count statistic
func (c *Cache) Stats() (messages int64, err error) {
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
// Close closes the underlying database connection
func (c *Cache) Close() error {
return c.db.Close()
}
func (c *Cache) processMessageBatches() {
if c.queue == nil {
return
}
for messages := range c.queue.Dequeue() {
if err := c.addMessages(messages); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
}
}
}
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
defer rows.Close()
messages := make([]*model.Message, 0)
for rows.Next() {
m, err := readMessage(rows)
if err != nil {
return nil, err
}
messages = append(messages, m)
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func readMessage(rows *sql.Rows) (*model.Message, error) {
var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
err := rows.Scan(
&id,
&sequenceID,
&timestamp,
&event,
&expires,
&topic,
&msg,
&title,
&priority,
&tagsStr,
&click,
&icon,
&actionsStr,
&attachmentName,
&attachmentType,
&attachmentSize,
&attachmentExpires,
&attachmentURL,
&sender,
&user,
&contentType,
&encoding,
)
if err != nil {
return nil, err
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
}
var actions []*model.Action
if actionsStr != "" {
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
return nil, err
}
}
senderIP, err := netip.ParseAddr(sender)
if err != nil {
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
}
var att *model.Attachment
if attachmentName != "" && attachmentURL != "" {
att = &model.Attachment{
Name: attachmentName,
Type: attachmentType,
Size: attachmentSize,
Expires: attachmentExpires,
URL: attachmentURL,
}
}
return &model.Message{
ID: id,
SequenceID: sequenceID,
Time: timestamp,
Expires: expires,
Event: event,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
Icon: icon,
Actions: actions,
Attachment: att,
Sender: senderIP,
User: user,
ContentType: contentType,
Encoding: encoding,
}, nil
}

111
message/cache_postgres.go Normal file
View File

@@ -0,0 +1,111 @@
package message
import (
"time"
"heckel.io/ntfy/v2/db"
)
// PostgreSQL runtime query constants
const (
postgresInsertMessageQuery = `
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
`
postgresDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
postgresSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
postgresDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
postgresUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
postgresSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE mid = $1
`
postgresSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`
postgresSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`
postgresSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1
AND id > COALESCE((SELECT id FROM message WHERE mid = $2), 0)
AND published = TRUE
ORDER BY time, id
`
postgresSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1
AND (id > COALESCE((SELECT id FROM message WHERE mid = $2), 0) OR published = FALSE)
ORDER BY time, id
`
postgresSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`
postgresSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`
postgresSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
postgresUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
postgresSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
postgresSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
postgresUpdateAttachmentDeletedQuery = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
postgresSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
postgresSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
postgresSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
)
var postgresQueries = queries{
insertMessage: postgresInsertMessageQuery,
deleteMessage: postgresDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: postgresDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: postgresUpdateMessagesForTopicExpiryQuery,
selectMessagesByID: postgresSelectMessagesByIDQuery,
selectMessagesSinceTime: postgresSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: postgresSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: postgresSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: postgresSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: postgresSelectMessagesLatestQuery,
selectMessagesDue: postgresSelectMessagesDueQuery,
selectMessagesExpired: postgresSelectMessagesExpiredQuery,
updateMessagePublished: postgresUpdateMessagePublishedQuery,
selectMessagesCount: postgresSelectMessagesCountQuery,
selectTopics: postgresSelectTopicsQuery,
updateAttachmentDeleted: postgresUpdateAttachmentDeletedQuery,
selectAttachmentsExpired: postgresSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: postgresSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: postgresSelectAttachmentsSizeByUserIDQuery,
selectStats: postgresSelectStatsQuery,
updateStats: postgresUpdateStatsQuery,
updateMessageTime: postgresUpdateMessageTimeQuery,
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
if err := setupPostgres(d.Primary()); err != nil {
return nil, err
}
return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil
}

View File

@@ -0,0 +1,85 @@
package message
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
// Initial PostgreSQL schema
const (
postgresCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS message (
id BIGSERIAL PRIMARY KEY,
mid TEXT NOT NULL,
sequence_id TEXT NOT NULL,
time BIGINT NOT NULL,
event TEXT NOT NULL,
expires BIGINT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size BIGINT NOT NULL,
attachment_expires BIGINT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
sender TEXT NOT NULL,
user_id TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
CREATE TABLE IF NOT EXISTS message_stats (
key TEXT PRIMARY KEY,
value BIGINT
);
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
`
)
// PostgreSQL schema management queries
const (
postgresCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewPostgresDB(db)
} else if schemaVersion > postgresCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
}
return nil
}
func setupNewPostgresDB(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

143
message/cache_sqlite.go Normal file
View File

@@ -0,0 +1,143 @@
package message
import (
"database/sql"
"fmt"
"path/filepath"
"sync"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
// SQLite runtime query constants
const (
sqliteInsertMessageQuery = `
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
sqliteSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`
sqliteSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`
sqliteSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) AND published = 1
ORDER BY time, id
`
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) OR published = 0)
ORDER BY time, id
`
sqliteSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`
sqliteSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
sqliteUpdateAttachmentDeletedQuery = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
)
var sqliteQueries = queries{
insertMessage: sqliteInsertMessageQuery,
deleteMessage: sqliteDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
selectMessagesByID: sqliteSelectMessagesByIDQuery,
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
selectMessagesDue: sqliteSelectMessagesDueQuery,
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
selectMessagesCount: sqliteSelectMessagesCountQuery,
selectTopics: sqliteSelectTopicsQuery,
updateAttachmentDeleted: sqliteUpdateAttachmentDeletedQuery,
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
selectStats: sqliteSelectStatsQuery,
updateStats: sqliteUpdateStatsQuery,
updateMessageTime: sqliteUpdateMessageTimeQuery,
}
// NewSQLiteStore creates a SQLite file-backed cache
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) {
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
}
d, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(d, startupQueries, cacheDuration); err != nil {
return nil, err
}
return newCache(db.New(&db.Host{DB: d}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
}
// NewMemStore creates an in-memory cache
func NewMemStore() (*Cache, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
}
// NewNopStore creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func NewNopStore() (*Cache, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
}
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
// sql database, so if the stdlib's sql engine happens to open another connection and
// you've only specified ":memory:", that connection will see a brand new database.
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database."
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}

View File

@@ -0,0 +1,453 @@
package message
import (
"database/sql"
"fmt"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
)
// Initial SQLite schema
const (
sqliteCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
sequence_id TEXT NOT NULL,
time INT NOT NULL,
event TEXT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
)
// Schema version management for SQLite
const (
sqliteCurrentSchemaVersion = 14
sqliteCreateSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// Schema migrations for SQLite
const (
// 0 -> 1
sqliteMigrate0To1AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
`
// 1 -> 2
sqliteMigrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
sqliteMigrate2To3AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
`
// 3 -> 4
sqliteMigrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
sqliteMigrate4To5AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
`
// 5 -> 6
sqliteMigrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
sqliteMigrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
sqliteMigrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
sqliteMigrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
sqliteMigrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
sqliteMigrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
sqliteMigrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
sqliteMigrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
// 13 -> 14
sqliteMigrate13To14AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
`
)
var (
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: sqliteMigrateFrom0,
1: sqliteMigrateFrom1,
2: sqliteMigrateFrom2,
3: sqliteMigrateFrom3,
4: sqliteMigrateFrom4,
5: sqliteMigrateFrom5,
6: sqliteMigrateFrom6,
7: sqliteMigrateFrom7,
8: sqliteMigrateFrom8,
9: sqliteMigrateFrom9,
10: sqliteMigrateFrom10,
11: sqliteMigrateFrom11,
12: sqliteMigrateFrom12,
13: sqliteMigrateFrom13,
}
)
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
return err
}
// If 'messages' table does not exist, this must be a new database
var messagesCount int
if err := db.QueryRow(sqliteSelectMessagesCountQuery).Scan(&messagesCount); err != nil {
return setupNewSQLite(db)
}
// If 'messages' table exists (schema >= 0), check 'schemaVersion' table
var schemaVersion int
db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion) // Error means schema version is zero!
// Do migrations
if schemaVersion == sqliteCurrentSchemaVersion {
return nil
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
fn, ok := sqliteMigrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewSQLite(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}
func sqliteMigrateFrom0(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, 1); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom1(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom2(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom3(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom4(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom5(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom6(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 7); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom7(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 8); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom8(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 9); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return nil
})
}

View File

@@ -0,0 +1,292 @@
package message_test
import (
"database/sql"
"fmt"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func TestSqliteStore_Migration_From0(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 0" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
s := newSqliteTestStoreFromFile(t, filename, "")
checkSqliteSchemaVersion(t, filename)
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
require.Equal(t, "some message 5", messages[5].Message)
require.Equal(t, "", messages[5].Title)
require.Nil(t, messages[5].Tags)
require.Equal(t, 0, messages[5].Priority)
}
func TestSqliteStore_Migration_From1(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 1" schema
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(512) NOT NULL,
title VARCHAR(256) NOT NULL,
priority INT NOT NULL,
tags VARCHAR(256) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
s := newSqliteTestStoreFromFile(t, filename, "")
checkSqliteSchemaVersion(t, filename)
// Add delayed message
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(delayedMessage))
// 10, not 11!
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
// 11!
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists
verifyDB, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer verifyDB.Close()
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err)
require.True(t, rows.Next())
var indexName string
require.Nil(t, rows.Scan(&indexName))
require.Equal(t, "idx_topic", indexName)
require.Nil(t, rows.Close())
}
func TestSqliteStore_Migration_From9(t *testing.T) {
// This primarily tests the awkward migration that introduces the "expires" column.
// The migration logic has to update the column, using the existing "cache-duration" value.
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 9" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
sender TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
insertQuery := `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
for i := 0; i < 10; i++ {
_, err = db.Exec(
insertQuery,
fmt.Sprintf("abcd%d", i),
time.Now().Unix(),
"mytopic",
fmt.Sprintf("some message %d", i),
"", // title
0, // priority
"", // tags
"", // click
"", // icon
"", // actions
"", // attachment_name
"", // attachment_type
0, // attachment_size
0, // attachment_expires
"", // attachment_url
"9.9.9.9", // sender
"", // encoding
1, // published
)
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
cacheDuration := 17 * time.Hour
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
checkSqliteSchemaVersion(t, filename)
// Check version
verifyDB, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer verifyDB.Close()
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
require.Nil(t, err)
require.True(t, rows.Next())
var version int
require.Nil(t, rows.Scan(&version))
require.Equal(t, 14, version)
require.Nil(t, rows.Close())
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
for _, m := range messages {
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
}
}
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
filename := newSqliteTestStoreFile(t)
startupQueries := `pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;`
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.FileExists(t, filename+"-wal")
require.FileExists(t, filename+"-shm")
}
func TestSqliteStore_StartupQueries_None(t *testing.T) {
filename := newSqliteTestStoreFile(t)
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.NoFileExists(t, filename+"-wal")
require.NoFileExists(t, filename+"-shm")
}
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
filename := newSqliteTestStoreFile(t)
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
require.Error(t, err)
}
func TestNopStore(t *testing.T) {
s, err := message.NewNopStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Empty(t, messages)
topics, err := s.Topics()
require.Nil(t, err)
require.Empty(t, topics)
}
func newSqliteTestStoreFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache {
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func checkSqliteSchemaVersion(t *testing.T, filename string) {
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer db.Close()
rows, err := db.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err)
require.True(t, rows.Next())
var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, 14, schemaVersion)
require.Nil(t, rows.Close())
}

967
message/cache_test.go Normal file
View File

@@ -0,0 +1,967 @@
package message_test
import (
"net/netip"
"path/filepath"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func newSqliteTestStore(t *testing.T) *message.Cache {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newMemTestStore(t *testing.T) *message.Cache {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newTestPostgresStore(t *testing.T) *message.Cache {
testDB := dbtest.CreateTestPostgres(t)
store, err := message.NewPostgresStore(testDB, 0, 0)
require.Nil(t, err)
return store
}
func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) {
t.Run("sqlite", func(t *testing.T) {
f(t, newSqliteTestStore(t))
})
t.Run("mem", func(t *testing.T) {
f(t, newMemTestStore(t))
})
t.Run("postgres", func(t *testing.T) {
f(t, newTestPostgresStore(t))
})
}
func TestStore_Messages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
require.Nil(t, s.AddMessage(m2))
// Adding invalid
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
// count
count, err := s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 3, count)
// mytopic: since all
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, model.MessageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: since all
messages, _ = s.Messages("example", model.SinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: since all
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
require.Empty(t, messages)
})
}
func TestStore_MessagesLock(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
})
}
func TestStore_MessagesScheduled(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := model.NewDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = s.MessagesDue()
require.Empty(t, messages)
})
}
func TestStore_Topics(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
topics, err := s.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Contains(t, topics, "topic1")
require.Contains(t, topics, "topic2")
})
}
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m := model.NewDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, s.AddMessage(m))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
})
}
func TestStore_MessagesSinceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := model.NewDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := model.NewDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := model.NewDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := model.NewDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
require.Nil(t, s.AddMessage(m4))
require.Nil(t, s.AddMessage(m5))
require.Nil(t, s.AddMessage(m6))
require.Nil(t, s.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
})
}
func TestStore_Prune(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
now := time.Now().Unix()
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := model.NewDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
count, err := s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 3, count)
expiredMessageIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
count, err = s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 1, count)
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
})
}
func TestStore_Attachments(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, s.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &model.Attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
})
}
func TestStore_AttachmentsExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
})
}
func TestStore_Sender(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, s.AddMessage(m1))
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, s.AddMessage(m2))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
})
}
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Create a scheduled (unpublished) message
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, s.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, s.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
})
}
func TestStore_MessageByID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message
m := model.NewDefaultMessage("mytopic", "some message")
m.Title = "some title"
m.Priority = 4
m.Tags = []string{"tag1", "tag2"}
require.Nil(t, s.AddMessage(m))
// Retrieve by ID
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "some message", retrieved.Message)
require.Equal(t, "some title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
// Non-existent ID returns ErrMessageNotFound
_, err = s.Message("doesnotexist")
require.Equal(t, model.ErrMessageNotFound, err)
})
}
func TestStore_MarkPublished(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a scheduled message (future time -> unpublished)
m := model.NewDefaultMessage("mytopic", "scheduled message")
m.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
// Verify it does not appear in non-scheduled queries
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(messages))
// Verify it does appear in scheduled queries
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Mark as published
require.Nil(t, s.MarkPublished(m))
// Now it should appear in non-scheduled queries too
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "scheduled message", messages[0].Message)
})
}
func TestStore_ExpireMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add messages to two topics
m1 := model.NewDefaultMessage("topic1", "message 1")
m1.Expires = time.Now().Add(time.Hour).Unix()
m2 := model.NewDefaultMessage("topic1", "message 2")
m2.Expires = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("topic2", "message 3")
m3.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
// Verify all messages exist
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Expire topic1 messages
require.Nil(t, s.ExpireMessages("topic1"))
// topic1 messages should now be expired (expires set to past)
expiredIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Equal(t, 2, len(expiredIDs))
sort.Strings(expiredIDs)
expectedIDs := []string{m1.ID, m2.ID}
sort.Strings(expectedIDs)
require.Equal(t, expectedIDs, expiredIDs)
// topic2 should be unaffected
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 3", messages[0].Message)
})
}
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message with an expired attachment (file needs cleanup)
m1 := model.NewDefaultMessage("mytopic", "old file")
m1.ID = "msg1"
m1.SequenceID = "msg1"
m1.Expires = time.Now().Add(time.Hour).Unix()
m1.Attachment = &model.Attachment{
Name: "old.pdf",
Type: "application/pdf",
Size: 50000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/old.pdf",
}
require.Nil(t, s.AddMessage(m1))
// Add a message with another expired attachment
m2 := model.NewDefaultMessage("mytopic", "another old file")
m2.ID = "msg2"
m2.SequenceID = "msg2"
m2.Expires = time.Now().Add(time.Hour).Unix()
m2.Attachment = &model.Attachment{
Name: "another.pdf",
Type: "application/pdf",
Size: 30000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/another.pdf",
}
require.Nil(t, s.AddMessage(m2))
// Both should show as expired attachments needing cleanup
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 2, len(ids))
// Mark msg1's attachment as deleted (file cleaned up)
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
// Now only msg2 should show as needing cleanup
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "msg2", ids[0])
// Mark msg2 too
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
// No more expired attachments to clean up
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 0, len(ids))
// Messages themselves still exist
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
})
}
func TestStore_Stats(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Initial stats should be zero
messages, err := s.Stats()
require.Nil(t, err)
require.Equal(t, int64(0), messages)
// Update stats
require.Nil(t, s.UpdateStats(42))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(42), messages)
// Update again (overwrites)
require.Nil(t, s.UpdateStats(100))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(100), messages)
})
}
func TestStore_AddMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Batch add multiple messages
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "batch 1"),
model.NewDefaultMessage("mytopic", "batch 2"),
model.NewDefaultMessage("othertopic", "batch 3"),
}
require.Nil(t, s.AddMessages(msgs))
// Verify all were inserted
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "batch 3", messages[0].Message)
// Empty batch should succeed
require.Nil(t, s.AddMessages([]*model.Message{}))
// Batch with invalid event type should fail
badMsgs := []*model.Message{
model.NewKeepaliveMessage("mytopic"),
}
require.NotNil(t, s.AddMessages(badMsgs))
})
}
func TestStore_MessagesDue(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message scheduled in the past (i.e. it's due now)
m1 := model.NewDefaultMessage("mytopic", "due message")
m1.Time = time.Now().Add(-time.Second).Unix()
// Set expires in the future so it doesn't get pruned
m1.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
// Add a message scheduled in the future (not due)
m2 := model.NewDefaultMessage("mytopic", "future message")
m2.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m2))
// Mark m1 as published so it won't be "due"
// (MessagesDue returns unpublished messages whose time <= now)
// m1 is auto-published (time <= now), so it should not be due
// m2 is unpublished (time in future), not due yet
due, err := s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Add a message that was explicitly scheduled in the past but time has "arrived"
// We need to manipulate the database to create a truly "due" message:
// a message with published=false and time <= now
m3 := model.NewDefaultMessage("mytopic", "truly due message")
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
require.Nil(t, s.AddMessage(m3))
// Not due yet
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Wait for it to become due
time.Sleep(3 * time.Second)
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 1, len(due))
require.Equal(t, "truly due message", due[0].Message)
})
}
func TestStore_MessageFieldRoundTrip(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Create a message with all fields populated
m := model.NewDefaultMessage("mytopic", "hello world")
m.SequenceID = "custom_seq_id"
m.Title = "A Title"
m.Priority = 4
m.Tags = []string{"warning", "srv01"}
m.Click = "https://example.com/click"
m.Icon = "https://example.com/icon.png"
m.Actions = []*model.Action{
{
ID: "action1",
Action: "view",
Label: "Open Site",
URL: "https://example.com",
Clear: true,
},
{
ID: "action2",
Action: "http",
Label: "Call Webhook",
URL: "https://example.com/hook",
Method: "PUT",
Headers: map[string]string{"X-Token": "secret"},
Body: `{"key":"value"}`,
},
}
m.ContentType = "text/markdown"
m.Encoding = "base64"
m.Sender = netip.MustParseAddr("9.8.7.6")
m.User = "u_TestUser123"
require.Nil(t, s.AddMessage(m))
// Retrieve and verify every field
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
require.Equal(t, m.Time, retrieved.Time)
require.Equal(t, m.Expires, retrieved.Expires)
require.Equal(t, model.MessageEvent, retrieved.Event)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "hello world", retrieved.Message)
require.Equal(t, "A Title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
require.Equal(t, "https://example.com/click", retrieved.Click)
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
require.Equal(t, "text/markdown", retrieved.ContentType)
require.Equal(t, "base64", retrieved.Encoding)
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
require.Equal(t, "u_TestUser123", retrieved.User)
// Verify actions round-trip
require.Equal(t, 2, len(retrieved.Actions))
require.Equal(t, "action1", retrieved.Actions[0].ID)
require.Equal(t, "view", retrieved.Actions[0].Action)
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
require.Equal(t, true, retrieved.Actions[0].Clear)
require.Equal(t, "action2", retrieved.Actions[1].ID)
require.Equal(t, "http", retrieved.Actions[1].Action)
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
require.Equal(t, "PUT", retrieved.Actions[1].Method)
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
})
}
func TestStore_AddMessage_InvalidUTF8(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// 0xc9 0x43: Latin-1 "ÉC" — 0xc9 starts a 2-byte UTF-8 sequence but 0x43 ('C') is not a continuation byte
m := model.NewDefaultMessage("mytopic", "\xc9Cas du serveur")
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "\uFFFDCas du serveur", messages[0].Message)
// 0xae: Latin-1 "®" — isolated byte above 0x7F, not a valid UTF-8 start for single byte
m2 := model.NewDefaultMessage("mytopic", "Product\xae Pro")
require.Nil(t, s.AddMessage(m2))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "Product\uFFFD Pro", messages[1].Message)
// 0xe8 0x6d 0x65: Latin-1 "ème" — 0xe8 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
m3 := model.NewDefaultMessage("mytopic", "probl\xe8me critique")
require.Nil(t, s.AddMessage(m3))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "probl\uFFFDme critique", messages[2].Message)
// 0xb2: Latin-1 "²" — isolated byte in 0x80-0xBF range (UTF-8 continuation byte without lead)
m4 := model.NewDefaultMessage("mytopic", "CO\xb2 level high")
require.Nil(t, s.AddMessage(m4))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "CO\uFFFD level high", messages[3].Message)
// 0xe9 0x6d 0x61: Latin-1 "éma" — 0xe9 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
m5 := model.NewDefaultMessage("mytopic", "th\xe9matique")
require.Nil(t, s.AddMessage(m5))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "th\uFFFDmatique", messages[4].Message)
// 0xed 0x64 0x65: Latin-1 "íde" — 0xed starts a 3-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
m6 := model.NewDefaultMessage("mytopic", "vid\xed\x64eo surveillance")
require.Nil(t, s.AddMessage(m6))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "vid\uFFFDdeo surveillance", messages[5].Message)
// 0xf3 0x6e 0x3a 0x20: Latin-1 "ón: " — 0xf3 starts a 4-byte UTF-8 sequence but 0x6e ('n') is not a continuation byte
m7 := model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta")
require.Nil(t, s.AddMessage(m7))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "notificaci\uFFFDn: alerta", messages[6].Message)
// 0xb7: Latin-1 "·" — isolated continuation byte
m8 := model.NewDefaultMessage("mytopic", "item\xb7value")
require.Nil(t, s.AddMessage(m8))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "item\uFFFDvalue", messages[7].Message)
// 0xa8: Latin-1 "¨" — isolated continuation byte
m9 := model.NewDefaultMessage("mytopic", "na\xa8ve")
require.Nil(t, s.AddMessage(m9))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "na\uFFFDve", messages[8].Message)
// 0xdf 0x64: Latin-1 "ßd" — 0xdf starts a 2-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
m10 := model.NewDefaultMessage("mytopic", "gro\xdf\x64ruck")
require.Nil(t, s.AddMessage(m10))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "gro\uFFFDdruck", messages[9].Message)
// 0xe4 0x67 0x74: Latin-1 "ägt" — 0xe4 starts a 3-byte UTF-8 sequence but 0x67 ('g') is not a continuation byte
m11 := model.NewDefaultMessage("mytopic", "tr\xe4gt Last")
require.Nil(t, s.AddMessage(m11))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "tr\uFFFDgt Last", messages[10].Message)
// 0xe9 0x65 0x20: Latin-1 "ée " — 0xe9 starts a 3-byte UTF-8 sequence but 0x65 ('e') is not a continuation byte
m12 := model.NewDefaultMessage("mytopic", "journ\xe9\x65 termin\xe9\x65")
require.Nil(t, s.AddMessage(m12))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "journ\uFFFDe termin\uFFFDe", messages[11].Message)
})
}
func TestStore_AddMessage_NullByte(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// 0x00: NUL byte — valid UTF-8 but rejected by PostgreSQL
m := model.NewDefaultMessage("mytopic", "hello\x00world")
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "helloworld", messages[0].Message)
})
}
func TestStore_AddMessage_InvalidUTF8InTitleAndTags(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Invalid UTF-8 can arrive via HTTP headers (Title, Tags) which bypass body validation
m := model.NewDefaultMessage("mytopic", "valid message")
m.Title = "\xc9clipse du syst\xe8me"
m.Tags = []string{"probl\xe8me", "syst\xe9me"}
m.Click = "https://example.com/\xae"
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", messages[0].Title)
require.Equal(t, "probl\uFFFDme", messages[0].Tags[0])
require.Equal(t, "syst\uFFFDme", messages[0].Tags[1])
require.Equal(t, "https://example.com/\uFFFD", messages[0].Click)
})
}
func TestStore_AddMessage_InvalidUTF8BatchDoesNotDropValidMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Previously, a single invalid message would roll back the entire batch transaction.
// Sanitization ensures all messages in a batch are written successfully.
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "valid message 1"),
model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta"),
model.NewDefaultMessage("mytopic", "valid message 3"),
}
require.Nil(t, s.AddMessages(msgs))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 3, len(messages))
})
}

View File

@@ -42,8 +42,11 @@ extra:
link: https://github.com/binwiederhier
extra_javascript:
- static/js/extra.js
- static/js/bcrypt.js
- static/js/config-generator.js
extra_css:
- static/css/extra.css
- static/css/config-generator.css
markdown_extensions:
- admonition

232
model/model.go Normal file
View File

@@ -0,0 +1,232 @@
package model
import (
"errors"
"net/netip"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
// List of possible events
const (
OpenEvent = "open"
KeepaliveEvent = "keepalive"
MessageEvent = "message"
MessageDeleteEvent = "message_delete"
MessageClearEvent = "message_clear"
PollRequestEvent = "poll_request"
)
// MessageIDLength is the length of a randomly generated message ID
const MessageIDLength = 12
// Errors for message operations
var (
ErrUnexpectedMessageType = errors.New("unexpected message type")
ErrMessageNotFound = errors.New("message not found")
)
// Message represents a message published to a topic
type Message struct {
ID string `json:"id"` // Random message ID
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*Action `json:"actions,omitempty"`
Attachment *Attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
// Context returns a log context for the message
func (m *Message) Context() log.Context {
fields := map[string]any{
"topic": m.Topic,
"message_id": m.ID,
"message_sequence_id": m.SequenceID,
"message_time": m.Time,
"message_event": m.Event,
"message_body_size": len(m.Message),
}
if m.Sender.IsValid() {
fields["message_sender"] = m.Sender.String()
}
if m.User != "" {
fields["message_user"] = m.User
}
return fields
}
// SanitizeUTF8 replaces invalid UTF-8 sequences and strips NUL bytes from all user-supplied
// string fields. This is called early in the publish path so that all downstream consumers
// (Firebase, WebPush, SMTP, cache) receive clean UTF-8 strings.
func (m *Message) SanitizeUTF8() {
m.Topic = util.SanitizeUTF8(m.Topic)
m.Message = util.SanitizeUTF8(m.Message)
m.Title = util.SanitizeUTF8(m.Title)
m.Click = util.SanitizeUTF8(m.Click)
m.Icon = util.SanitizeUTF8(m.Icon)
m.ContentType = util.SanitizeUTF8(m.ContentType)
for i, tag := range m.Tags {
m.Tags[i] = util.SanitizeUTF8(tag)
}
if m.Attachment != nil {
m.Attachment.Name = util.SanitizeUTF8(m.Attachment.Name)
m.Attachment.Type = util.SanitizeUTF8(m.Attachment.Type)
m.Attachment.URL = util.SanitizeUTF8(m.Attachment.URL)
}
}
// ForJSON returns a copy of the message suitable for JSON output.
// It clears the SequenceID if it equals the ID to reduce redundancy.
func (m *Message) ForJSON() *Message {
if m.SequenceID == m.ID {
clone := *m
clone.SequenceID = ""
return &clone
}
return m
}
// Attachment represents a file attachment on a message
type Attachment struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"`
URL string `json:"url"`
}
// Action represents a user-defined action on a message
type Action struct {
ID string `json:"id"`
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
Label string `json:"label"` // action button label
Clear bool `json:"clear"` // clear notification after successful execution
URL string `json:"url,omitempty"` // used in "view" and "http" actions
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
Body string `json:"body,omitempty"` // used in "http" action
Intent string `json:"intent,omitempty"` // used in "broadcast" action
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
Value string `json:"value,omitempty"` // used in "copy" action
}
// NewAction creates a new action with initialized maps
func NewAction() *Action {
return &Action{
Headers: make(map[string]string),
Extras: make(map[string]string),
}
}
// NewMessage creates a new message with the current timestamp
func NewMessage(event, topic, msg string) *Message {
return &Message{
ID: util.RandomString(MessageIDLength),
Time: time.Now().Unix(),
Event: event,
Topic: topic,
Message: msg,
}
}
// NewOpenMessage is a convenience method to create an open message
func NewOpenMessage(topic string) *Message {
return NewMessage(OpenEvent, topic, "")
}
// NewKeepaliveMessage is a convenience method to create a keepalive message
func NewKeepaliveMessage(topic string) *Message {
return NewMessage(KeepaliveEvent, topic, "")
}
// NewDefaultMessage is a convenience method to create a notification message
func NewDefaultMessage(topic, msg string) *Message {
return NewMessage(MessageEvent, topic, msg)
}
// NewActionMessage creates a new action message (message_delete or message_clear)
func NewActionMessage(event, topic, sequenceID string) *Message {
m := NewMessage(event, topic, "")
m.SequenceID = sequenceID
return m
}
// NewPollRequestMessage is a convenience method to create a poll request message
func NewPollRequestMessage(topic, pollID string) *Message {
m := NewMessage(PollRequestEvent, topic, "New message")
m.PollID = pollID
return m
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, MessageIDLength)
}
// SinceMarker represents a point in time or message ID from which to retrieve messages
type SinceMarker struct {
time time.Time
id string
}
// NewSinceTime creates a new SinceMarker from a Unix timestamp
func NewSinceTime(timestamp int64) SinceMarker {
return SinceMarker{time.Unix(timestamp, 0), ""}
}
// NewSinceID creates a new SinceMarker from a message ID
func NewSinceID(id string) SinceMarker {
return SinceMarker{time.Unix(0, 0), id}
}
// IsAll returns true if this is the "all messages" marker
func (t SinceMarker) IsAll() bool {
return t == SinceAllMessages
}
// IsNone returns true if this is the "no messages" marker
func (t SinceMarker) IsNone() bool {
return t == SinceNoMessages
}
// IsLatest returns true if this is the "latest message" marker
func (t SinceMarker) IsLatest() bool {
return t == SinceLatestMessage
}
// IsID returns true if this marker references a specific message ID
func (t SinceMarker) IsID() bool {
return t.id != "" && t.id != SinceLatestMessage.id
}
// Time returns the time component of the marker
func (t SinceMarker) Time() time.Time {
return t.time
}
// ID returns the message ID component of the marker
func (t SinceMarker) ID() string {
return t.id
}
// Common SinceMarker values for subscribing to messages
var (
SinceAllMessages = SinceMarker{time.Unix(0, 0), ""}
SinceNoMessages = SinceMarker{time.Unix(1, 0), ""}
SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"}
)

488
s3/client.go Normal file
View File

@@ -0,0 +1,488 @@
// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces,
// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP
// requests with AWS Signature V4 signing, no AWS SDK dependency required.
package s3
import (
"bytes"
"context"
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
"encoding/base64"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects,
// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix
// are fixed at construction time. All operations target the same bucket and prefix.
//
// Fields must not be modified after the Client is passed to any method or goroutine.
type Client struct {
AccessKey string // AWS access key ID
SecretKey string // AWS secret access key
Region string // e.g. "us-east-1"
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
Bucket string // S3 bucket name
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
HTTPClient *http.Client // if nil, http.DefaultClient is used
}
// New creates a new S3 client from the given Config.
func New(config *Config) *Client {
return &Client{
AccessKey: config.AccessKey,
SecretKey: config.SecretKey,
Region: config.Region,
Endpoint: config.Endpoint,
Bucket: config.Bucket,
Prefix: config.Prefix,
PathStyle: config.PathStyle,
}
}
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
// configured prefix. The body size does not need to be known in advance.
//
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
// into memory.
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
first := make([]byte, partSize)
n, err := io.ReadFull(body, first)
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
}
if err != nil {
return fmt.Errorf("s3: PutObject read: %w", err)
}
combined := io.MultiReader(bytes.NewReader(first), body)
return c.putObjectMultipart(ctx, key, combined)
}
// GetObject downloads an object. The key is automatically prefixed with the client's configured
// prefix. The caller must close the returned ReadCloser.
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
}
if resp.StatusCode/100 != 2 {
err := parseError(resp)
resp.Body.Close()
return nil, 0, err
}
return resp.Body, resp.ContentLength, nil
}
// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
// caller is responsible for batching if needed.
//
// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
// in the response, they are returned as a combined error.
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
var body bytes.Buffer
body.WriteString("<Delete><Quiet>true</Quiet>")
for _, key := range keys {
body.WriteString("<Object><Key>")
xml.EscapeText(&body, []byte(c.objectKey(key)))
body.WriteString("</Key></Object>")
}
body.WriteString("</Delete>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
reqURL := c.bucketURL() + "?delete="
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-MD5", contentMD5)
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: DeleteObjects: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// S3 may return HTTP 200 with per-key errors in the response body
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
}
var result deleteResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
}
if len(result.Errors) > 0 {
var msgs []string
for _, e := range result.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message))
}
return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; "))
}
return nil
}
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
query := url.Values{"list-type": {"2"}}
if prefix := c.prefixForList(); prefix != "" {
query.Set("prefix", prefix)
}
if continuationToken != "" {
query.Set("continuation-token", continuationToken)
}
if maxKeys > 0 {
query.Set("max-keys", strconv.Itoa(maxKeys))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects: %w", err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
}
if resp.StatusCode/100 != 2 {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
var result listObjectsV2Response
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("s3: ListObjects XML: %w", err)
}
objects := make([]Object, len(result.Contents))
for i, obj := range result.Contents {
objects[i] = Object(obj)
}
return &ListResult{
Objects: objects,
IsTruncated: result.IsTruncated,
NextContinuationToken: result.NextContinuationToken,
}, nil
}
// ListAllObjects returns all objects under the client's configured prefix by paginating through
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
const maxPages = 10000
var all []Object
var token string
for page := 0; page < maxPages; page++ {
result, err := c.ListObjects(ctx, token, 0)
if err != nil {
return nil, err
}
all = append(all, result.Objects...)
if !result.IsTruncated {
return all, nil
}
token = result.NextContinuationToken
}
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
}
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
if err != nil {
return fmt.Errorf("s3: PutObject request: %w", err)
}
req.ContentLength = size
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: PutObject: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
return nil
}
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
fullKey := c.objectKey(key)
// Step 1: Initiate multipart upload
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
if err != nil {
return err
}
// Step 2: Upload parts
var parts []completedPart
buf := make([]byte, partSize)
partNumber := 1
for {
n, err := io.ReadFull(body, buf)
if n > 0 {
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
if uploadErr != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return uploadErr
}
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
partNumber++
}
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
}
}
// Step 3: Complete multipart upload
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
reqURL := c.objectURL(fullKey) + "?uploads"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
}
req.ContentLength = 0
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
}
var result initiateMultipartUploadResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
}
return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("s3: UploadPart request: %w", err)
}
req.ContentLength = int64(len(data))
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: UploadPart: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
etag := resp.Header.Get("ETag")
return etag, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
var body bytes.Buffer
body.WriteString("<CompleteMultipartUpload>")
for _, p := range parts {
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
}
body.WriteString("</CompleteMultipartUpload>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// Read response body to check for errors (S3 can return 200 with an error body)
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
}
// Check if the response contains an error
var errResp ErrorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
errResp.StatusCode = resp.StatusCode
return &errResp
}
return nil
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
return
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return
}
resp.Body.Close()
}
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
now := time.Now().UTC()
datestamp := now.Format("20060102")
amzDate := now.Format("20060102T150405Z")
// Required headers
req.Header.Set("Host", c.hostHeader())
req.Header.Set("X-Amz-Date", amzDate)
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
// Canonical headers (all headers we set, sorted by lowercase key)
signedKeys := make([]string, 0, len(req.Header))
canonHeaders := make(map[string]string, len(req.Header))
for k := range req.Header {
lk := strings.ToLower(k)
signedKeys = append(signedKeys, lk)
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
}
sort.Strings(signedKeys)
signedHeadersStr := strings.Join(signedKeys, ";")
var chBuf strings.Builder
for _, k := range signedKeys {
chBuf.WriteString(k)
chBuf.WriteByte(':')
chBuf.WriteString(canonHeaders[k])
chBuf.WriteByte('\n')
}
// Canonical request
canonicalRequest := strings.Join([]string{
req.Method,
canonicalURI(req.URL),
canonicalQueryString(req.URL.Query()),
chBuf.String(),
signedHeadersStr,
payloadHash,
}, "\n")
// String to sign
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
// Signing key
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
[]byte(c.Region)),
[]byte("s3")),
[]byte("aws4_request"))
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
req.Header.Set("Authorization", fmt.Sprintf(
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.AccessKey, credentialScope, signedHeadersStr, signature,
))
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
// objectKey prepends the configured prefix to the given key.
func (c *Client) objectKey(key string) string {
if c.Prefix != "" {
return c.Prefix + "/" + key
}
return key
}
// prefixForList returns the prefix to use in ListObjectsV2 requests,
// with a trailing slash so that only objects under the prefix directory are returned.
func (c *Client) prefixForList() string {
if c.Prefix != "" {
return c.Prefix + "/"
}
return ""
}
// bucketURL returns the base URL for bucket-level operations.
func (c *Client) bucketURL() string {
if c.PathStyle {
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
}
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
}
// objectURL returns the full URL for an object (key should already include the prefix).
// Each path segment is URI-encoded to handle special characters in keys.
func (c *Client) objectURL(key string) string {
segments := strings.Split(key, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return c.bucketURL() + "/" + strings.Join(segments, "/")
}
// hostHeader returns the value for the Host header.
func (c *Client) hostHeader() string {
if c.PathStyle {
return c.Endpoint
}
return c.Bucket + "." + c.Endpoint
}

860
s3/client_test.go Normal file
View File

@@ -0,0 +1,860 @@
package s3
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"sort"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() (*httptest.Server, *mockS3Server) {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m), m
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
type listObjectsResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []listObject `xml:"Contents"`
// Pagination support
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
contToken := r.URL.Query().Get("continuation-token")
m.mu.RLock()
var allKeys []string
for key := range m.objects {
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
allKeys = append(allKeys, objKey)
}
}
m.mu.RUnlock()
sort.Strings(allKeys)
// Simple continuation token: it's the key to start after
startIdx := 0
if contToken != "" {
for i, k := range allKeys {
if k == contToken {
startIdx = i + 1
break
}
}
}
maxKeys := 1000
if mk := r.URL.Query().Get("max-keys"); mk != "" {
fmt.Sscanf(mk, "%d", &maxKeys)
}
endIdx := startIdx + maxKeys
truncated := false
nextToken := ""
if endIdx < len(allKeys) {
truncated = true
nextToken = allKeys[endIdx-1]
allKeys = allKeys[startIdx:endIdx]
} else {
allKeys = allKeys[startIdx:]
}
m.mu.RLock()
var contents []listObject
for _, objKey := range allKeys {
body := m.objects[bucketPath+"/"+objKey]
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
}
m.mu.RUnlock()
resp := listObjectsResponse{
Contents: contents,
IsTruncated: truncated,
NextContinuationToken: nextToken,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
func (m *mockS3Server) objectCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.objects)
}
// --- Helper to create a test client pointing at mock server ---
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
return &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
}
}
// --- URL parsing tests ---
func TestParseURL_Success(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "attachments", cfg.Prefix)
require.Equal(t, "us-east-1", cfg.Region)
require.Equal(t, "AKID", cfg.AccessKey)
require.Equal(t, "SECRET", cfg.SecretKey)
require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint)
require.False(t, cfg.PathStyle)
}
func TestParseURL_NoPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "", cfg.Prefix)
}
func TestParseURL_WithEndpoint(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "prefix", cfg.Prefix)
require.Equal(t, "s3.example.com", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointHTTP(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000")
require.Nil(t, err)
require.Equal(t, "localhost:9000", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointTrailingSlash(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/")
require.Nil(t, err)
require.Equal(t, "s3.example.com", cfg.Endpoint)
}
func TestParseURL_NestedPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "a/b/c", cfg.Prefix)
}
func TestParseURL_MissingRegion(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@my-bucket")
require.Error(t, err)
require.Contains(t, err.Error(), "region")
}
func TestParseURL_MissingCredentials(t *testing.T) {
_, err := ParseURL("s3://my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "access key")
}
func TestParseURL_MissingSecretKey(t *testing.T) {
_, err := ParseURL("s3://AKID@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "secret key")
}
func TestParseURL_WrongScheme(t *testing.T) {
_, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "scheme")
}
func TestParseURL_EmptyBucket(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "bucket")
}
// --- Unit tests: URL construction ---
func TestClient_BucketURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
}
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
}
func TestClient_ObjectURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_HostHeader_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "s3.example.com", c.hostHeader())
}
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
}
func TestClient_ObjectKey(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/file123", c.objectKey("file123"))
c2 := &Client{Prefix: ""}
require.Equal(t, "file123", c2.objectKey("file123"))
}
func TestClient_PrefixForList(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/", c.prefixForList())
c2 := &Client{Prefix: ""}
require.Equal(t, "", c2.prefixForList())
}
// --- Integration tests using mock S3 server ---
func TestClient_PutGetObject(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
require.Equal(t, int64(11), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestClient_PutGetObject_WithPrefix(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "hello", string(data))
}
func TestClient_GetObject_NotFound(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
_, _, err := client.GetObject(context.Background(), "nonexistent")
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
}
func TestClient_DeleteObjects(t *testing.T) {
server, mock := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put several objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
require.Nil(t, err)
}
require.Equal(t, 5, mock.objectCount())
// Delete some
err := client.DeleteObjects(ctx, []string{"key-1", "key-3"})
require.Nil(t, err)
require.Equal(t, 3, mock.objectCount())
// Verify deleted ones are gone
_, _, err = client.GetObject(ctx, "key-1")
require.Error(t, err)
_, _, err = client.GetObject(ctx, "key-3")
require.Error(t, err)
// Verify remaining ones are still there
reader, _, err := client.GetObject(ctx, "key-0")
require.Nil(t, err)
reader.Close()
}
func TestClient_ListObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
ctx := context.Background()
// Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ {
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "")
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
require.Nil(t, err)
// List with prefix client: should only see 3
result, err := client.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 3)
require.False(t, result.IsTruncated)
// List with no-prefix client: should see all 4
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 4)
}
func TestClient_ListObjects_Pagination(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put 5 objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// List with max-keys=2
result, err := client.ListObjects(ctx, "", 2)
require.Nil(t, err)
require.Len(t, result.Objects, 2)
require.True(t, result.IsTruncated)
require.NotEmpty(t, result.NextContinuationToken)
// Get next page
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result2.Objects, 2)
require.True(t, result2.IsTruncated)
// Get last page
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result3.Objects, 1)
require.False(t, result3.IsTruncated)
}
func TestClient_ListAllObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
for i := 0; i < 10; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
}
func TestClient_PutObject_LargeBody(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 1 MB object
data := make([]byte, 1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "large", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large")
require.Nil(t, err)
require.Equal(t, int64(1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
data := make([]byte, 12*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "multipart")
require.Nil(t, err)
require.Equal(t, int64(12*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ExactPartSize(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully)
data := make([]byte, 5*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "exact")
require.Nil(t, err)
require.Equal(t, int64(5*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "nested", string(data))
}
// --- Scale test: 20k objects (ntfy-adjacent) ---
func TestClient_ListAllObjects_20k(t *testing.T) {
if testing.Short() {
t.Skip("skipping 20k object test in short mode")
}
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "attachments")
ctx := context.Background()
const numObjects = 20000
const batchSize = 500
// Insert 20k objects in batches to keep it fast
for batch := 0; batch < numObjects/batchSize; batch++ {
for i := 0; i < batchSize; i++ {
idx := batch*batchSize + i
key := fmt.Sprintf("%08d", idx)
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
}
// List all 20k objects with pagination
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects)
// Verify total size
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
require.Equal(t, int64(numObjects), totalSize)
// Delete 1000 objects (simulating attachment expiry cleanup)
keys := make([]string, 1000)
for i := range keys {
keys[i] = fmt.Sprintf("%08d", i)
}
err = client.DeleteObjects(ctx, keys)
require.Nil(t, err)
// List again: should have 19000
objects, err = client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects-1000)
}
// --- Real S3 integration test ---
//
// Set the following environment variables to run this test against a real S3 bucket:
//
// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET
//
// Optional:
//
// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. "nyc3.digitaloceanspaces.com")
// S3_PATH_STYLE: set to "true" for path-style addressing
// S3_PREFIX: key prefix to use (default: "ntfy-s3-test")
func TestClient_RealBucket(t *testing.T) {
accessKey := os.Getenv("S3_ACCESS_KEY")
secretKey := os.Getenv("S3_SECRET_KEY")
region := os.Getenv("S3_REGION")
bucket := os.Getenv("S3_BUCKET")
if accessKey == "" || secretKey == "" || region == "" || bucket == "" {
t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET")
}
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
}
pathStyle := os.Getenv("S3_PATH_STYLE") == "true"
prefix := os.Getenv("S3_PREFIX")
if prefix == "" {
prefix = "ntfy-s3-test"
}
client := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix,
PathStyle: pathStyle,
}
ctx := context.Background()
// Clean up any leftover objects from previous runs
existing, err := client.ListAllObjects(ctx)
require.Nil(t, err)
if len(existing) > 0 {
keys := make([]string, len(existing))
for i, obj := range existing {
// Strip the prefix since DeleteObjects will re-add it
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
}
// Batch delete in groups of 1000
for i := 0; i < len(keys); i += 1000 {
end := i + 1000
if end > len(keys) {
end = len(keys)
}
err := client.DeleteObjects(ctx, keys[i:end])
require.Nil(t, err)
}
}
t.Run("PutGetDelete", func(t *testing.T) {
key := "test-object"
content := "hello from ntfy s3 test"
// Put
err := client.PutObject(ctx, key, strings.NewReader(content))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(content)), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, content, string(data))
// Delete
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
// Get after delete should fail
_, _, err = client.GetObject(ctx, key)
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
})
t.Run("ListObjects", func(t *testing.T) {
// Use a sub-prefix client for isolation
listClient := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix + "/list-test",
PathStyle: pathStyle,
}
// Put 10 objects
for i := 0; i < 10; i++ {
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
require.Nil(t, err)
}
// List
objects, err := listClient.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
// Clean up
keys := make([]string, 10)
for i := range keys {
keys[i] = fmt.Sprintf("%d", i)
}
err = listClient.DeleteObjects(ctx, keys)
require.Nil(t, err)
})
t.Run("LargeObject", func(t *testing.T) {
key := "large-object"
data := make([]byte, 5*1024*1024) // 5 MB
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, key, bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(data)), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
})
}

76
s3/types.go Normal file
View File

@@ -0,0 +1,76 @@
package s3
import "fmt"
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
type Config struct {
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
PathStyle bool
Bucket string
Prefix string
Region string
AccessKey string
SecretKey string
}
// Object represents an S3 object returned by list operations.
type Object struct {
Key string
Size int64
}
// ListResult holds the response from a ListObjectsV2 call.
type ListResult struct {
Objects []Object
IsTruncated bool
NextContinuationToken string
}
// ErrorResponse is returned when S3 responds with a non-2xx status code.
type ErrorResponse struct {
StatusCode int
Code string `xml:"Code"`
Message string `xml:"Message"`
Body string `xml:"-"` // raw response body
}
func (e *ErrorResponse) Error() string {
if e.Code != "" {
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
}
return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body)
}
// listObjectsV2Response is the XML response from S3 ListObjectsV2
type listObjectsV2Response struct {
Contents []listObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
type listObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}
// deleteResult is the XML response from S3 DeleteObjects
type deleteResult struct {
Errors []deleteError `xml:"Error"`
}
type deleteError struct {
Key string `xml:"Key"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
type initiateMultipartUploadResult struct {
UploadID string `xml:"UploadId"`
}
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
type completedPart struct {
PartNumber int
ETag string
}

166
s3/util.go Normal file
View File

@@ -0,0 +1,166 @@
package s3
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
const (
// SHA-256 hash of the empty string, used as the payload hash for bodiless requests
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
// Sent as the payload hash for streaming uploads where the body is not buffered in memory
unsignedPayload = "UNSIGNED-PAYLOAD"
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
maxResponseBytes = 10 * 1024 * 1024
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
// part size of 5 MB for all parts except the last.
partSize = 5 * 1024 * 1024
)
// ParseURL parses an S3 URL of the form:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
//
// When endpoint is specified, path-style addressing is enabled automatically.
func ParseURL(s3URL string) (*Config, error) {
u, err := url.Parse(s3URL)
if err != nil {
return nil, fmt.Errorf("s3: invalid URL: %w", err)
}
if u.Scheme != "s3" {
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
}
if u.Host == "" {
return nil, fmt.Errorf("s3: bucket name must be specified as host")
}
bucket := u.Host
prefix := strings.TrimPrefix(u.Path, "/")
accessKey := u.User.Username()
secretKey, _ := u.User.Password()
if accessKey == "" || secretKey == "" {
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
}
region := u.Query().Get("region")
if region == "" {
return nil, fmt.Errorf("s3: region query parameter is required")
}
endpointParam := u.Query().Get("endpoint")
var endpoint string
var pathStyle bool
if endpointParam != "" {
// Custom endpoint: strip scheme prefix to extract host[:port]
ep := strings.TrimRight(endpointParam, "/")
ep = strings.TrimPrefix(ep, "https://")
ep = strings.TrimPrefix(ep, "http://")
endpoint = ep
pathStyle = true
} else {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
pathStyle = false
}
return &Config{
Endpoint: endpoint,
PathStyle: pathStyle,
Bucket: bucket,
Prefix: prefix,
Region: region,
AccessKey: accessKey,
SecretKey: secretKey,
}, nil
}
// parseError reads an S3 error response and returns an *ErrorResponse.
func parseError(resp *http.Response) error {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: reading error response: %w", err)
}
return parseErrorFromBytes(resp.StatusCode, body)
}
func parseErrorFromBytes(statusCode int, body []byte) error {
errResp := &ErrorResponse{
StatusCode: statusCode,
Body: string(body),
}
// Try to parse XML error; if it fails, we still have StatusCode and Body
_ = xml.Unmarshal(body, errResp)
return errResp
}
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
// percent-encoded per RFC 3986; forward slashes are preserved.
func canonicalURI(u *url.URL) string {
p := u.Path
if p == "" {
return "/"
}
segments := strings.Split(p, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return strings.Join(segments, "/")
}
// canonicalQueryString builds the query string for the canonical request. Keys and values
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
func canonicalQueryString(values url.Values) string {
if len(values) == 0 {
return ""
}
keys := make([]string, 0, len(values))
for k := range values {
keys = append(keys, k)
}
sort.Strings(keys)
var pairs []string
for _, k := range keys {
ek := uriEncode(k)
vs := make([]string, len(values[k]))
copy(vs, values[k])
sort.Strings(vs)
for _, v := range vs {
pairs = append(pairs, ek+"="+uriEncode(v))
}
}
return strings.Join(pairs, "&")
}
// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved
// characters (A-Z a-z 0-9 - _ . ~).
func uriEncode(s string) string {
var buf strings.Builder
for i := 0; i < len(s); i++ {
b := s[i]
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '~' {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "%%%02X", b)
}
}
return buf.String()
}
func sha256Hex(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}

181
s3/util_test.go Normal file
View File

@@ -0,0 +1,181 @@
package s3
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestURIEncode(t *testing.T) {
// Unreserved characters are not encoded
require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz"))
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
require.Equal(t, "0123456789", uriEncode("0123456789"))
require.Equal(t, "-_.~", uriEncode("-_.~"))
// Spaces use %20, not +
require.Equal(t, "hello%20world", uriEncode("hello world"))
// Slashes are encoded (canonicalURI handles slash splitting separately)
require.Equal(t, "a%2Fb", uriEncode("a/b"))
// Special characters
require.Equal(t, "%2B", uriEncode("+"))
require.Equal(t, "%3D", uriEncode("="))
require.Equal(t, "%26", uriEncode("&"))
require.Equal(t, "%40", uriEncode("@"))
require.Equal(t, "%23", uriEncode("#"))
// Mixed
require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt"))
require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes"))
// Empty string
require.Equal(t, "", uriEncode(""))
}
func TestCanonicalURI(t *testing.T) {
// Simple path
u, _ := url.Parse("https://example.com/bucket/key")
require.Equal(t, "/bucket/key", canonicalURI(u))
// Root path
u, _ = url.Parse("https://example.com/")
require.Equal(t, "/", canonicalURI(u))
// Empty path
u, _ = url.Parse("https://example.com")
require.Equal(t, "/", canonicalURI(u))
// Path with special characters
u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces")
require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u))
// Nested path
u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt")
require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u))
}
func TestCanonicalQueryString(t *testing.T) {
// Multiple keys sorted alphabetically
vals := url.Values{
"prefix": {"test/"},
"list-type": {"2"},
}
require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals))
// Empty values
require.Equal(t, "", canonicalQueryString(url.Values{}))
// Single key
require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}}))
// Key with multiple values (sorted)
vals = url.Values{"key": {"b", "a"}}
require.Equal(t, "key=a&key=b", canonicalQueryString(vals))
// Keys requiring encoding
vals = url.Values{"continuation-token": {"abc+def"}}
require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals))
}
func TestSHA256Hex(t *testing.T) {
// SHA-256 of empty string
require.Equal(t, emptyPayloadHash, sha256Hex([]byte("")))
// SHA-256 of known value
require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello")))
}
func TestHmacSHA256(t *testing.T) {
// Known test vector: HMAC-SHA256("key", "message")
result := hmacSHA256([]byte("key"), []byte("message"))
require.Len(t, result, 32) // SHA-256 produces 32 bytes
require.NotEqual(t, make([]byte, 32), result)
// Same inputs should produce same output
result2 := hmacSHA256([]byte("key"), []byte("message"))
require.Equal(t, result, result2)
// Different inputs should produce different output
result3 := hmacSHA256([]byte("different-key"), []byte("message"))
require.NotEqual(t, result, result3)
}
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, emptyPayloadHash)
// All required SigV4 headers must be set
require.NotEmpty(t, req.Header.Get("Host"))
require.NotEmpty(t, req.Header.Get("X-Amz-Date"))
require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256"))
// Authorization header must have correct format
auth := req.Header.Get("Authorization")
require.Contains(t, auth, "AWS4-HMAC-SHA256")
require.Contains(t, auth, "Credential=AKID/")
require.Contains(t, auth, "/us-east-1/s3/aws4_request")
require.Contains(t, auth, "SignedHeaders=")
require.Contains(t, auth, "Signature=")
}
func TestSignV4_UnsignedPayload(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, unsignedPayload)
require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256"))
}
func TestSignV4_DifferentRegions(t *testing.T) {
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
c1.signV4(req1, emptyPayloadHash)
req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil)
c2.signV4(req2, emptyPayloadHash)
// Different regions should produce different signatures
require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization"))
}
func TestParseError_XMLResponse(t *testing.T) {
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
err := parseErrorFromBytes(404, xmlBody)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
require.Equal(t, "The specified key does not exist.", errResp.Message)
}
func TestParseError_NonXMLResponse(t *testing.T) {
err := parseErrorFromBytes(500, []byte("internal server error"))
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 500, errResp.StatusCode)
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
require.Contains(t, errResp.Body, "internal server error")
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"unicode/utf8"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -39,7 +40,7 @@ type actionParser struct {
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
func parseActions(s string) (actions []*action, err error) {
func parseActions(s string) (actions []*model.Action, err error) {
// Parse JSON or simple format
s = strings.TrimSpace(s)
if strings.HasPrefix(s, "[") {
@@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) {
}
// parseActionsFromJSON converts a JSON array into an array of actions
func parseActionsFromJSON(s string) ([]*action, error) {
actions := make([]*action, 0)
func parseActionsFromJSON(s string) ([]*model.Action, error) {
actions := make([]*model.Action, 0)
if err := json.Unmarshal([]byte(s), &actions); err != nil {
return nil, fmt.Errorf("JSON error: %w", err)
}
@@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) {
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
func parseActionsFromSimple(s string) ([]*action, error) {
func parseActionsFromSimple(s string) ([]*model.Action, error) {
if !utf8.ValidString(s) {
return nil, errors.New("invalid utf-8 string")
}
@@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) {
}
// Parse loops trough parseAction() until the end of the string is reached
func (p *actionParser) Parse() ([]*action, error) {
actions := make([]*action, 0)
func (p *actionParser) Parse() ([]*model.Action, error) {
actions := make([]*model.Action, 0)
for !p.eof() {
a, err := p.parseAction()
if err != nil {
@@ -134,8 +135,8 @@ func (p *actionParser) Parse() ([]*action, error) {
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
// and then uses populateAction to interpret the keys/values. The function terminates
// when EOF or ";" is reached.
func (p *actionParser) parseAction() (*action, error) {
a := newAction()
func (p *actionParser) parseAction() (*model.Action, error) {
a := model.NewAction()
section := 0
for {
key, value, last, err := p.parseSection()
@@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) {
// populateAction is the "business logic" of the parser. It applies the key/value
// pair to the action instance.
func populateAction(newAction *action, section int, key, value string) error {
func populateAction(newAction *model.Action, section int, key, value string) error {
// Auto-expand keys based on their index
if key == "" && section == 0 {
key = "action"

View File

@@ -95,6 +95,8 @@ type Config struct {
ListenUnixMode fs.FileMode
KeyFile string
CertFile string
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
DatabaseReplicaURLs []string // PostgreSQL read replica connection strings
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
@@ -110,6 +112,7 @@ type Config struct {
AuthBcryptCost int
AuthStatsQueueWriterInterval time.Duration
AttachmentCacheDir string
AttachmentS3URL string
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration
@@ -199,6 +202,7 @@ func NewConfig() *Config {
ListenUnixMode: 0,
KeyFile: "",
CertFile: "",
DatabaseURL: "",
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,

View File

@@ -142,6 +142,7 @@ var (
errHTTPBadRequestTemplateFileNotFound = &errHTTP{40047, http.StatusBadRequest, "invalid request: template file not found", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestTemplateFileInvalid = &errHTTP{40048, http.StatusBadRequest, "invalid request: template file invalid", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestSequenceIDInvalid = &errHTTP{40049, http.StatusBadRequest, "invalid request: sequence ID invalid", "https://ntfy.sh/docs/publish/#updating-deleting-notifications", nil}
errHTTPBadRequestEmailAddressInvalid = &errHTTP{40050, http.StatusBadRequest, "invalid request: invalid e-mail address", "https://ntfy.sh/docs/publish/#e-mail-notifications", nil}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}

View File

@@ -1,76 +0,0 @@
package server
import (
"bytes"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
"os"
"strings"
"testing"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining())
}
func TestFileCache_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), c.Size())
require.Equal(t, int64(2248), c.Remaining())
}
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024)
require.Nil(t, err)
return dir, cache
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -23,7 +24,6 @@ const (
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagTwilio = "twilio"
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"
@@ -35,7 +35,7 @@ const (
)
var (
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage}
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage, http.StatusRequestEntityTooLarge}
rateLimitingErrorCodes = []int{http.StatusTooManyRequests, http.StatusRequestEntityTooLarge}
)
@@ -55,12 +55,12 @@ func logvr(v *visitor, r *http.Request) *log.Event {
}
// logvrm creates a new log event with HTTP request, visitor fields and message fields
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event {
return logvr(v, r).With(m)
}
// logvrm creates a new log event with visitor fields and message fields
func logvm(v *visitor, m *message) *log.Event {
func logvm(v *visitor, m *model.Message) *log.Event {
return logv(v).With(m)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,825 +0,0 @@
package server
import (
"database/sql"
"fmt"
"github.com/stretchr/testify/assert"
"net/netip"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSqliteCache_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t))
}
func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemTestCache(t))
}
func testCacheMessages(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
require.Nil(t, c.AddMessage(m2))
// Adding invalid
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
// mytopic: count
counts, err := c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
// mytopic: since all
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, messageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = c.Messages("mytopic", newSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["example"])
// example: since all
messages, _ = c.Messages("example", sinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: count
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 0, counts["doesnotexist"])
// non-existing: since all
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
require.Empty(t, messages)
}
func TestSqliteCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestCache(t))
}
func testCacheMessagesLock(t *testing.T, c *messageCache) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
}
func TestSqliteCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestCache(t))
}
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := newDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := newDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = c.MessagesDue()
require.Empty(t, messages)
}
func TestSqliteCache_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestCache(t))
}
func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemTestCache(t))
}
func testCacheTopics(t *testing.T, c *messageCache) {
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
topics, err := c.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Equal(t, "topic1", topics["topic1"].ID)
require.Equal(t, "topic2", topics["topic2"].ID)
}
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, c.AddMessage(m))
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
}
func TestSqliteCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestCache(t))
}
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := newDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := newDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := newDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := newDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := newDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
require.Nil(t, c.AddMessage(m4))
require.Nil(t, c.AddMessage(m5))
require.Nil(t, c.AddMessage(m6))
require.Nil(t, c.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
}
func TestSqliteCache_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestCache(t))
}
func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemTestCache(t))
}
func testCachePrune(t *testing.T, c *messageCache) {
now := time.Now().Unix()
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := newDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
counts, err := c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessageIDs, err := c.MessagesExpired()
require.Nil(t, err)
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"])
require.Equal(t, 0, counts["another_topic"])
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
}
func TestSqliteCache_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestCache(t))
}
func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, c.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, c.AddMessage(m))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
}
func TestSqliteCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t))
}
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 0" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
require.Equal(t, "some message 5", messages[5].Message)
require.Equal(t, "", messages[5].Title)
require.Nil(t, messages[5].Tags)
require.Equal(t, 0, messages[5].Priority)
}
func TestSqliteCache_Migration_From1(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 1" schema
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(512) NOT NULL,
title VARCHAR(256) NOT NULL,
priority INT NOT NULL,
tags VARCHAR(256) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db)
// Add delayed message
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, c.AddMessage(delayedMessage))
// 10, not 11!
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
// 11!
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err)
require.True(t, rows.Next())
var indexName string
require.Nil(t, rows.Scan(&indexName))
require.Equal(t, "idx_topic", indexName)
}
func TestSqliteCache_Migration_From9(t *testing.T) {
// This primarily tests the awkward migration that introduces the "expires" column.
// The migration logic has to update the column, using the existing "cache-duration" value.
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 8" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
sender TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
insertQuery := `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
for i := 0; i < 10; i++ {
_, err = db.Exec(
insertQuery,
fmt.Sprintf("abcd%d", i),
time.Now().Unix(),
"mytopic",
fmt.Sprintf("some message %d", i),
"", // title
0, // priority
"", // tags
"", // click
"", // icon
"", // actions
"", // attachment_name
"", // attachment_type
0, // attachment_size
0, // attachment_type
"", // attachment_url
"9.9.9.9", // sender
"", // encoding
1, // published
)
require.Nil(t, err)
}
// Create cache to trigger migration
cacheDuration := 17 * time.Hour
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err)
checkSchemaVersion(t, c.db)
// Check version
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
require.Nil(t, err)
require.True(t, rows.Next())
var version int
require.Nil(t, rows.Scan(&version))
require.Equal(t, currentSchemaVersion, version)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
for _, m := range messages {
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
}
}
func TestSqliteCache_StartupQueries_WAL(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := `pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;`
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.FileExists(t, filename+"-wal")
require.FileExists(t, filename+"-shm")
}
func TestSqliteCache_StartupQueries_None(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := ""
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.NoFileExists(t, filename+"-wal")
require.NoFileExists(t, filename+"-shm")
}
func TestSqliteCache_StartupQueries_Fail(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := `xx error`
_, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Error(t, err)
}
func TestSqliteCache_Sender(t *testing.T) {
testSender(t, newSqliteTestCache(t))
}
func TestMemCache_Sender(t *testing.T) {
testSender(t, newMemTestCache(t))
}
func testSender(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, c.AddMessage(m1))
m2 := newDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, c.AddMessage(m2))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
}
func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newSqliteTestCache(t))
}
func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newMemTestCache(t))
}
func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) {
// Create a scheduled (unpublished) message
scheduledMsg := newDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, c.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := newDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, c.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := newDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = c.Messages("othertopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = c.Messages("othertopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
}
func checkSchemaVersion(t *testing.T, db *sql.DB) {
rows, err := db.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err)
require.True(t, rows.Next())
var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, currentSchemaVersion, schemaVersion)
require.Nil(t, rows.Close())
}
func TestMemCache_NopCache(t *testing.T) {
c, _ := newNopCache()
require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Empty(t, messages)
topics, err := c.Topics()
require.Nil(t, err)
require.Empty(t, topics)
}
func newSqliteTestCache(t *testing.T) *messageCache {
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
if err != nil {
t.Fatal(err)
}
return c
}
func newSqliteTestCacheFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
return c
}
func newMemTestCache(t *testing.T) *messageCache {
c, err := newMemCache()
require.Nil(t, err)
return c
}

View File

@@ -32,16 +32,23 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/attachment"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"heckel.io/ntfy/v2/util/sprig"
"heckel.io/ntfy/v2/webpush"
)
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite
httpServer *http.Server
httpsServer *http.Server
httpMetricsServer *http.Server
@@ -56,9 +63,9 @@ type Server struct {
messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
webPush *webPushStore // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
messageCache *message.Cache // Database that stores the messages
webPush *webpush.Store // Database that stores web push subscriptions
fileCache attachment.Store // Attachment store (file system or S3)
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
@@ -90,6 +97,7 @@ var (
matrixPushPath = "/_matrix/push/v1/notify"
metricsPath = "/metrics"
apiHealthPath = "/v1/health"
apiVersionPath = "/v1/version"
apiConfigPath = "/v1/config"
apiStatsPath = "/v1/stats"
apiWebPushPath = "/v1/webpush"
@@ -115,6 +123,7 @@ var (
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
urlRegex = regexp.MustCompile(`^https?://`)
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
emailAddressRegex = regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
//go:embed site
webFs embed.FS
@@ -171,36 +180,64 @@ func New(conf *Config) (*Server, error) {
if payments.Available && conf.StripeSecretKey != "" {
stripe = newStripeAPI()
}
messageCache, err := createMessageCache(conf)
// Open shared PostgreSQL connection pool if configured
var pool *db.DB
if conf.DatabaseURL != "" {
primary, err := pg.Open(conf.DatabaseURL)
if err != nil {
return nil, err
}
var replicas []*db.Host
for _, replicaURL := range conf.DatabaseReplicaURLs {
r, err := pg.OpenReplica(replicaURL)
if err != nil {
// Close already-opened replicas before returning
for _, opened := range replicas {
opened.DB.Close()
}
primary.DB.Close()
return nil, fmt.Errorf("failed to open database replica: %w", err)
}
replicas = append(replicas, r)
}
pool = db.New(primary, replicas)
}
messageCache, err := createMessageCache(conf, pool)
if err != nil {
return nil, err
}
var webPush *webPushStore
var wp *webpush.Store
if conf.WebPushPublicKey != "" {
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
if pool != nil {
wp, err = webpush.NewPostgresStore(pool)
} else {
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
}
if err != nil {
return nil, err
}
}
topics, err := messageCache.Topics()
topicIDs, err := messageCache.Topics()
if err != nil {
return nil, err
}
topics := make(map[string]*topic, len(topicIDs))
for _, id := range topicIDs {
topics[id] = newTopic(id)
}
messages, err := messageCache.Stats()
if err != nil {
return nil, err
}
var fileCache *fileCache
if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
if err != nil {
return nil, err
}
fileCache, err := createAttachmentStore(conf)
if err != nil {
return nil, err
}
var userManager *user.Manager
if conf.AuthFile != "" {
if conf.AuthFile != "" || pool != nil {
authConfig := &user.Config{
Filename: conf.AuthFile,
DatabaseURL: conf.DatabaseURL,
StartupQueries: conf.AuthStartupQueries,
DefaultAccess: conf.AuthDefault,
ProvisionEnabled: true, // Enable provisioning of users and access
@@ -210,7 +247,11 @@ func New(conf *Config) (*Server, error) {
BcryptCost: conf.AuthBcryptCost,
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
}
userManager, err = user.NewManager(authConfig)
if pool != nil {
userManager, err = user.NewPostgresManager(pool, authConfig)
} else {
userManager, err = user.NewSQLiteManager(conf.AuthFile, conf.AuthStartupQueries, authConfig)
}
if err != nil {
return nil, err
}
@@ -231,8 +272,9 @@ func New(conf *Config) (*Server, error) {
}
s := &Server{
config: conf,
db: pool,
messageCache: messageCache,
webPush: webPush,
webPush: wp,
fileCache: fileCache,
firebaseClient: firebaseClient,
smtpSender: mailer,
@@ -247,13 +289,24 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config) (*messageCache, error) {
func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
if conf.CacheDuration == 0 {
return newNopCache()
return message.NewNopStore()
} else if pool != nil {
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" {
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
}
return newMemCache()
return message.NewMemStore()
}
func createAttachmentStore(conf *Config) (attachment.Store, error) {
if conf.AttachmentS3URL != "" {
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
} else if conf.AttachmentCacheDir != "" {
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
}
return nil, nil
}
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
@@ -388,6 +441,9 @@ func (s *Server) closeDatabases() {
if s.webPush != nil {
s.webPush.Close()
}
if s.db != nil {
s.db.Close()
}
}
// handle is the main entry point for all HTTP requests
@@ -467,6 +523,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiVersionPath {
return s.ensureAdmin(s.handleVersion)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiConfigPath {
return s.handleConfig(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
@@ -702,7 +760,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor)
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" {
if s.fileCache == nil {
return errHTTPInternalError
}
matches := fileRegex.FindStringSubmatch(r.URL.Path)
@@ -710,16 +768,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
return errHTTPInternalErrorInvalidPath
}
messageID := matches[1]
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
stat, err := os.Stat(file)
reader, size, err := s.fileCache.Read(messageID)
if err != nil {
return errHTTPNotFound.Fields(log.Context{
"message_id": messageID,
"error_context": "filesystem",
"error_context": "attachment_store",
})
}
defer reader.Close()
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
if r.Method == http.MethodHead {
return nil
}
@@ -728,11 +786,11 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
// - and also uses the higher bandwidth limits of a paying user
m, err := s.messageCache.Message(messageID)
if errors.Is(err, errMessageNotFound) {
if errors.Is(err, model.ErrMessageNotFound) {
if s.config.CacheBatchTimeout > 0 {
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
// and messages are persisted asynchronously, retry fetching from the database
m, err = util.Retry(func() (*message, error) {
m, err = util.Retry(func() (*model.Message, error) {
return s.messageCache.Message(messageID)
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
}
@@ -755,19 +813,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
} else if m.Sender.IsValid() {
bandwidthVisitor = s.visitor(m.Sender, nil)
}
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
if !bandwidthVisitor.BandwidthAllowed(size) {
return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
}
// Actually send file
f, err := os.Open(file)
if err != nil {
return err
}
defer f.Close()
if m.Attachment.Name != "" {
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
}
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader)
return err
}
@@ -778,7 +831,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
return writeMatrixDiscoveryResponse(w)
}
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) {
start := time.Now()
t, err := fromContext[*topic](r, contextTopic)
if err != nil {
@@ -792,7 +845,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if err != nil {
return nil, err
}
m := newDefaultMessage(t.ID, "")
m := model.NewDefaultMessage(t.ID, "")
cache, firebase, email, call, template, unifiedpush, priorityStr, e := s.parsePublishParams(r, m)
if e != nil {
return nil, e.With(t)
@@ -817,7 +870,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
}
}
if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID)
m = model.NewPollRequestMessage(t.ID, m.PollID)
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
@@ -830,6 +883,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if m.Message == "" {
m.Message = emptyMessageBody
}
m.SanitizeUTF8()
delayed := m.Time > time.Now().Unix()
ev := logvrm(v, r, m).
Tag(tagPublish).
@@ -906,7 +960,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err
}
minc(metricMessagesPublishedSuccess)
return s.writeJSON(w, m.forJSON())
return s.writeJSON(w, m.ForJSON())
}
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -935,11 +989,11 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *
}
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
return s.handleActionMessage(w, r, v, messageDeleteEvent)
return s.handleActionMessage(w, r, v, model.MessageDeleteEvent)
}
func (s *Server) handleClear(w http.ResponseWriter, r *http.Request, v *visitor) error {
return s.handleActionMessage(w, r, v, messageClearEvent)
return s.handleActionMessage(w, r, v, model.MessageClearEvent)
}
func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string) error {
@@ -959,7 +1013,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
return e.With(t)
}
// Create an action message with the given event type
m := newActionMessage(event, t.ID, sequenceID)
m := model.NewActionMessage(event, t.ID, sequenceID)
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
@@ -975,7 +1029,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
if event == messageDeleteEvent {
if event == model.MessageDeleteEvent {
// Delete any existing scheduled message with the same sequence ID
deletedIDs, err := s.messageCache.DeleteScheduledBySequenceID(t.ID, sequenceID)
if err != nil {
@@ -996,10 +1050,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
s.mu.Lock()
s.messages++
s.mu.Unlock()
return s.writeJSON(w, m.forJSON())
return s.writeJSON(w, m.ForJSON())
}
func (s *Server) sendToFirebase(v *visitor, m *message) {
func (s *Server) sendToFirebase(v *visitor, m *model.Message) {
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
if err := s.firebaseClient.Send(v, m); err != nil {
minc(metricFirebasePublishedFailure)
@@ -1013,7 +1067,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
minc(metricFirebasePublishedSuccess)
}
func (s *Server) sendEmail(v *visitor, m *message, email string) {
func (s *Server) sendEmail(v *visitor, m *model.Message, email string) {
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
if err := s.smtpSender.Send(v, m, email); err != nil {
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
@@ -1023,7 +1077,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) {
minc(metricEmailsPublishedSuccess)
}
func (s *Server) forwardPollRequest(v *visitor, m *message) {
func (s *Server) forwardPollRequest(v *visitor, m *model.Message) {
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
@@ -1055,7 +1109,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
}
}
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
if err != nil {
@@ -1082,7 +1136,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
filename := readParam(r, "x-filename", "filename", "file", "f")
attach := readParam(r, "x-attach", "attach", "a")
if attach != "" || filename != "" {
m.Attachment = &attachment{}
m.Attachment = &model.Attachment{}
}
if filename != "" {
m.Attachment.Name = filename
@@ -1112,6 +1166,9 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
m.Icon = icon
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" && !emailAddressRegex.MatchString(email) {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailAddressInvalid
}
if s.smtpSender == nil && email != "" {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailDisabled
}
@@ -1203,8 +1260,8 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
// 7. curl -T file.txt ntfy.sh/mytopic
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
if m.Event == pollRequestEvent { // Case 1
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
if m.Event == model.PollRequestEvent { // Case 1
return s.handleBodyDiscard(body)
} else if unifiedpush {
return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
@@ -1226,7 +1283,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
return err
}
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error {
if utf8.Valid(body.PeekedBytes) {
m.Message = string(body.PeekedBytes) // Do not trim
} else {
@@ -1236,7 +1293,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
return nil
}
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error {
if !utf8.Valid(body.PeekedBytes) {
return errHTTPBadRequestMessageNotUTF8.With(m)
}
@@ -1249,7 +1306,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
return nil
}
func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
if err != nil {
return err
@@ -1274,7 +1331,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
// The template file must be in the templates directory, or in the configured template directory.
func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error {
func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error {
if !templateNameRegex.MatchString(templateName) {
return errHTTPBadRequestTemplateFileNotFound
}
@@ -1316,7 +1373,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
// message, title, and priority parameters.
func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error {
func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error {
var err error
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
return err
@@ -1357,8 +1414,8 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" {
return errHTTPBadRequestAttachmentsDisallowed.With(m)
}
vinfo, err := v.Info()
@@ -1381,7 +1438,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
}
}
if m.Attachment == nil {
m.Attachment = &attachment{}
m.Attachment = &model.Attachment{}
}
var ext string
m.Attachment.Expires = attachmentExpiry
@@ -1408,9 +1465,9 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
}
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
encoder := func(msg *model.Message) (string, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
return "", err
}
return buf.String(), nil
@@ -1419,12 +1476,12 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
}
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
encoder := func(msg *model.Message) (string, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
return "", err
}
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
}
return fmt.Sprintf("data: %s\n", buf.String()), nil
@@ -1433,8 +1490,8 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
}
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
if msg.Event == messageEvent { // only handle default events
encoder := func(msg *model.Message) (string, error) {
if msg.Event == model.MessageEvent { // only handle default events
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
}
return "\n", nil // "keepalive" and "open" events just send an empty line
@@ -1469,7 +1526,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
closed = true
wlock.Unlock()
}()
sub := func(v *visitor, msg *message) error {
sub := func(v *visitor, msg *model.Message) error {
if !filters.Pass(msg) {
return nil
}
@@ -1512,7 +1569,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
@@ -1535,7 +1592,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
for _, t := range topics {
t.Keepalive()
}
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
if err := sub(v, model.NewKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err
}
}
@@ -1631,7 +1688,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
}
})
sub := func(v *visitor, msg *message) error {
sub := func(v *visitor, msg *model.Message) error {
if !filters.Pass(msg) {
return nil
}
@@ -1661,7 +1718,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
@@ -1678,7 +1735,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return nil
}
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) {
poll = readBoolParam(r, false, "x-poll", "poll", "po")
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
since, err = parseSince(r, poll)
@@ -1759,11 +1816,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
// marker, returning only messages that are newer than the marker.
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() {
return nil
}
messages := make([]*message, 0)
messages := make([]*model.Message, 0)
for _, t := range topics {
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
if err != nil {
@@ -1786,32 +1843,32 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
//
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
// "all" for all messages, or "latest" for the most recent message for a topic
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
since := readParam(r, "x-since", "since", "si")
// Easy cases (empty, all, none)
if since == "" {
if poll {
return sinceAllMessages, nil
return model.SinceAllMessages, nil
}
return sinceNoMessages, nil
return model.SinceNoMessages, nil
} else if since == "all" {
return sinceAllMessages, nil
return model.SinceAllMessages, nil
} else if since == "latest" {
return sinceLatestMessage, nil
return model.SinceLatestMessage, nil
} else if since == "none" {
return sinceNoMessages, nil
return model.SinceNoMessages, nil
}
// ID, timestamp, duration
if validMessageID(since) {
return newSinceID(since), nil
if model.ValidMessageID(since) {
return model.NewSinceID(since), nil
} else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
return newSinceTime(s), nil
return model.NewSinceTime(s), nil
} else if d, err := time.ParseDuration(since); err == nil {
return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
return model.NewSinceTime(time.Now().Add(-1 * d).Unix()), nil
}
return sinceNoMessages, errHTTPBadRequestSinceInvalid
return model.SinceNoMessages, errHTTPBadRequestSinceInvalid
}
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@@ -1967,14 +2024,14 @@ func (s *Server) runFirebaseKeepaliver() {
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
s.sendToFirebase(v, model.NewKeepaliveMessage(firebaseControlTopic))
/*
FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
Given that it's not really necessary to poll, turning it off for now should not have any impact.
case <-time.After(s.config.FirebasePollInterval):
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
s.sendToFirebase(v, model.NewKeepaliveMessage(firebasePollTopic))
*/
case <-s.closeChan:
return
@@ -2017,7 +2074,7 @@ func (s *Server) sendDelayedMessages() error {
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error {
logvm(v, m).Debug("Sending delayed message")
s.mu.RLock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
@@ -2290,9 +2347,7 @@ func (s *Server) updateAndWriteStats(messagesCount int64) {
s.messagesHistory = s.messagesHistory[1:]
}
s.mu.Unlock()
go func() {
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
}
}()
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
}
}

View File

@@ -38,8 +38,32 @@
#
# firebase-key-file: <filename>
# If "database-url" is set, ntfy will use PostgreSQL for all database-backed stores (message cache,
# user manager, and web push subscriptions) instead of SQLite. When set, the "cache-file",
# "auth-file", and "web-push-file" options must not be set.
#
# Note: Setting "database-url" implicitly enables authentication and access control.
# The default access is "read-write" (see "auth-default-access").
#
# The URL supports standard PostgreSQL parameters (sslmode, connect_timeout, sslcert, etc.),
# as well as ntfy-specific connection pool parameters:
# pool_max_conns=10 - Maximum number of open connections (default: 10)
# pool_max_idle_conns=N - Maximum number of idle connections
# pool_conn_max_lifetime=5m - Maximum lifetime of a connection (Go duration)
# pool_conn_max_idle_time=1m - Maximum idle time of a connection (Go duration)
#
# See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# for the full list of supported PostgreSQL connection parameters.
#
# Examples:
# database-url: "postgres://user:pass@host:5432/ntfy"
# database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50"
#
# database-url: <connection-string>
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
# This allows for service restarts without losing messages in support of the since= parameter.
# Not required if "database-url" is set (messages are stored in PostgreSQL instead).
#
# The "cache-duration" parameter defines the duration for which messages will be buffered
# before they are deleted. This is required to support the "since=..." and "poll=1" parameter.
@@ -77,6 +101,8 @@
# If set, access to the ntfy server and API can be controlled on a granular level using
# the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs.
#
# Note: If "database-url" is set, auth is implicitly enabled and "auth-file" must not be set.
#
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
@@ -133,6 +159,7 @@
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
#
# attachment-cache-dir:
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1"
# attachment-total-size-limit: "5G"
# attachment-file-size-limit: "15M"
# attachment-expiry-duration: "3h"
@@ -197,6 +224,7 @@
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
# Not required if "database-url" is set (subscriptions are stored in PostgreSQL instead).
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
# - web-push-startup-queries is an optional list of queries to run on startup
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d)

View File

@@ -3,13 +3,15 @@ package server
import (
"encoding/json"
"errors"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"net/http"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
const (
@@ -454,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPUnauthorized
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
return errHTTPConflictTopicReserved
} else if u.IsUser() {
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
if err != nil {
return err
}
if !hasReservation {
reservations, err := s.userManager.ReservationsCount(u.Name)
if err != nil {
return err
} else if reservations >= u.Tier.ReservationLimit {
return errHTTPTooManyRequestsLimitReservations
}
}
}
// Actually add the reservation
// Actually add the reservation (with limit check inside the transaction to avoid races)
logvr(v, r).
Tag(tagAccount).
Fields(log.Context{
@@ -476,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
"everyone": everyone.String(),
}).
Debug("Adding topic reservation")
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
var limit int64
if u.IsUser() && u.Tier != nil {
limit = u.Tier.ReservationLimit
}
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil {
if errors.Is(err, user.ErrTooManyReservations) {
return errHTTPTooManyRequestsLimitReservations
}
return err
}
// Kill existing subscribers
@@ -529,22 +525,15 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
} else if len(removedTopics) == 0 {
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", "))
if err := s.messageCache.ExpireMessages(removedTopics...); err != nil {
return err
}
go s.pruneMessages()
@@ -641,7 +630,7 @@ func (s *Server) publishSyncEvent(v *visitor) error {
if err != nil {
return err
}
m := newDefaultMessage(syncTopic.ID, string(messageBytes))
m := model.NewDefaultMessage(syncTopic.ID, string(messageBytes))
if err := syncTopic.Publish(v, m); err != nil {
return err
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,14 @@ import (
"net/http"
)
func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request, v *visitor) error {
return s.writeJSON(w, &apiVersionResponse{
Version: s.config.BuildVersion,
Commit: s.config.BuildCommit,
Date: s.config.BuildDate,
})
}
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
users, err := s.userManager.Users()
if err != nil {

View File

@@ -1,6 +1,7 @@
package server
import (
"encoding/json"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
@@ -9,393 +10,452 @@ import (
"time"
)
func TestVersion_Admin(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.BuildVersion = "1.2.3"
c.BuildCommit = "abcdef0"
c.BuildDate = "2026-02-08T00:00:00Z"
s := newTestServer(t, c)
defer s.closeDatabases()
// Create admin and regular user
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Admin can access /v1/version
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
var versionResponse apiVersionResponse
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
require.Equal(t, "1.2.3", versionResponse.Version)
require.Equal(t, "abcdef0", versionResponse.Commit)
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
// Non-admin user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Unauthenticated user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", nil)
require.Equal(t, 401, rr.Code)
})
}
func TestUser_AddRemove(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check user was deleted
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "emma", users[1].Name)
require.Equal(t, user.Everyone, users[2].Name)
// Reject invalid user change
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check user was deleted
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "emma", users[1].Name)
require.Equal(t, user.Everyone, users[2].Name)
// Reject invalid user change
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
}
func TestUser_AddWithPasswordHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check that user can login with password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, user.RoleAdmin, users[0].Role)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
}
func TestUser_ChangeUserPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Change password via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_ChangeUserTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users again
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
// Check new tier
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "not-ben"),
})
require.Equal(t, 200, rr.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
// Try to change password via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 403, rr.Code)
}
func TestUser_AddRemove_Failures(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Cannot create user with invalid username
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestAccess_AllowReset(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
// Check that user can login with password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, user.RoleAdmin, users[0].Role)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
})
}
func TestUser_ChangeUserPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Change password via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestUser_ChangeUserTier(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users again
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
})
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
// Check new tier
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
})
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "not-ben"),
})
require.Equal(t, 200, rr.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
// Try to change password via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 403, rr.Code)
})
}
func TestUser_AddRemove_Failures(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Cannot create user with invalid username
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestAccess_AllowReset(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
})
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
})
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
})
})
}

View File

@@ -10,6 +10,7 @@ import (
"firebase.google.com/go/v4/messaging"
"fmt"
"google.golang.org/api/option"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"strings"
@@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
}
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
if !v.FirebaseAllowed() {
return errFirebaseTemporarilyBanned
}
@@ -121,11 +122,11 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
// to Firebase here. This is mainly for iOS to support self-hosted servers.
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) {
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
var apnsConfig *messaging.APNSConfig
switch m.Event {
case keepaliveEvent, openEvent:
case model.KeepaliveEvent, model.OpenEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -133,7 +134,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
"topic": m.Topic,
}
apnsConfig = createAPNSBackgroundConfig(data)
case pollRequestEvent:
case model.PollRequestEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -143,7 +144,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
"poll_id": m.PollID,
}
apnsConfig = createAPNSAlertConfig(m, data)
case messageDeleteEvent, messageClearEvent:
case model.MessageDeleteEvent, model.MessageClearEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -152,7 +153,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
"sequence_id": m.SequenceID,
}
apnsConfig = createAPNSBackgroundConfig(data)
case messageEvent:
case model.MessageEvent:
if auther != nil {
// If "anonymous read" for a topic is not allowed, we cannot send the message along
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
@@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
// Extension in iOS can modify the message.
func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig {
func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig {
apnsData := make(map[string]any)
for k, v := range data {
apnsData[k] = v
@@ -296,8 +297,8 @@ func maybeTruncateAPNSBodyMessage(s string) string {
//
// This empties all the fields that are not needed for a poll request and just sets the required fields,
// most importantly, the PollID.
func toPollRequest(m *message) *message {
pr := newPollRequestMessage(m.Topic, m.ID)
func toPollRequest(m *model.Message) *model.Message {
pr := model.NewPollRequestMessage(m.Topic, m.ID)
pr.ID = m.ID
pr.Time = m.Time
pr.Priority = m.Priority // Keep priority

View File

@@ -4,6 +4,7 @@ package server
import (
"errors"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
)
@@ -21,7 +22,7 @@ var (
type firebaseClient struct {
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
return errFirebaseNotAvailable
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"net/netip"
"strings"
@@ -63,7 +64,7 @@ func (s *testFirebaseSender) Messages() []*messaging.Message {
}
func TestToFirebaseMessage_Keepalive(t *testing.T) {
m := newKeepaliveMessage("mytopic")
m := model.NewKeepaliveMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -94,7 +95,7 @@ func TestToFirebaseMessage_Keepalive(t *testing.T) {
}
func TestToFirebaseMessage_Open(t *testing.T) {
m := newOpenMessage("mytopic")
m := model.NewOpenMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -125,13 +126,13 @@ func TestToFirebaseMessage_Open(t *testing.T) {
}
func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
m := newDefaultMessage("mytopic", "this is a message")
m := model.NewDefaultMessage("mytopic", "this is a message")
m.Priority = 4
m.Tags = []string{"tag 1", "tag2"}
m.Click = "https://google.com"
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
m.Title = "some title"
m.Actions = []*action{
m.Actions = []*model.Action{
{
ID: "123",
Action: "view",
@@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
},
},
}
m.Attachment = &attachment{
m.Attachment = &model.Attachment{
Name: "some file.jpg",
Type: "image/jpeg",
Size: 12345,
@@ -219,7 +220,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
}
func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
m := newDefaultMessage("mytopic", "this is a message")
m := model.NewDefaultMessage("mytopic", "this is a message")
m.Priority = 5
fbm, err := toFirebaseMessage(m, &testAuther{Allow: false}) // Not allowed!
require.Nil(t, err)
@@ -250,7 +251,7 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
}
func TestToFirebaseMessage_PollRequest(t *testing.T) {
m := newPollRequestMessage("mytopic", "fOv6k1QbCzo6")
m := model.NewPollRequestMessage("mytopic", "fOv6k1QbCzo6")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -344,18 +345,18 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
visitor := newVisitor(newTestConfig(t, ""), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages()))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.Messages()))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.Messages()))
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 0, len(sender.Messages()))
}

View File

@@ -17,15 +17,10 @@ func (s *Server) execManager() {
s.pruneMessages()
s.pruneAndNotifyWebPushSubscriptions()
// Message count per topic
var messagesCached int
messageCounts, err := s.messageCache.MessageCounts()
// Message count
messagesCached, err := s.messageCache.MessagesCount()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
messageCounts = make(map[string]int) // Empty, so we can continue
}
for _, count := range messageCounts {
messagesCached += count
log.Tag(tagManager).Err(err).Warn("Cannot get messages count")
}
// Remove subscriptions without subscribers
@@ -104,6 +99,9 @@ func (s *Server) execManager() {
mset(metricUsers, usersCount)
mset(metricSubscribers, subscribers)
mset(metricTopics, topicsCount)
if s.fileCache != nil {
mset(metricAttachmentsTotalSize, s.fileCache.Size())
}
}
func (s *Server) pruneVisitors() {

View File

@@ -2,27 +2,30 @@ package server
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
"testing"
)
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T, databaseURL string) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t, databaseURL)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
// Publish a message
rr := request(t, s, "POST", "/mytopic", "hi", nil)
require.Equal(t, 200, rr.Code)
m := toMessage(t, rr.Body.String())
// Publish a message
rr := request(t, s, "POST", "/mytopic", "hi", nil)
require.Equal(t, 200, rr.Code)
m := toMessage(t, rr.Body.String())
// Expire message
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
// Expire message
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
// Does not panic
s.pruneMessages()
// Does not panic
s.pruneMessages()
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, errMessageNotFound, err)
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, model.ErrMessageNotFound, err)
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
//go:build !race
package server
const raceEnabled = false

View File

@@ -0,0 +1,5 @@
//go:build race
package server
const raceEnabled = true

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,7 @@ import (
"text/template"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
// Failures will be logged, but not returned to the caller.
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) {
u, sender := v.User(), m.Sender.String()
if u != nil {
sender = u.Name

View File

@@ -14,217 +14,224 @@ import (
)
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
if code.Load() != nil {
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
if code.Load() != nil {
t.Fatal("Should be only called once")
}
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
code.Store(util.String("123456"))
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
if verified.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
verified.Store(true)
} else {
t.Fatal("Unexpected path:", r.URL.Path)
}
}))
defer twilioVerifyServer.Close()
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
code.Store(util.String("123456"))
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
if verified.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
verified.Store(true)
} else {
t.Fatal("Unexpected path:", r.URL.Path)
}
}))
defer twilioVerifyServer.Close()
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioVerifyService = "VA1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioVerifyService = "VA1234567890"
s := newTestServer(t, c)
// Send verification code for phone number
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return *code.Load() == "123456"
})
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
// Add phone number with code
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return verified.Load()
})
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 1, len(phoneNumbers))
require.Equal(t, "+12223334444", phoneNumbers[0])
// Send verification code for phone number
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return *code.Load() == "123456"
})
// Do the thing
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
// Add phone number with code
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return verified.Load()
})
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 1, len(phoneNumbers))
require.Equal(t, "+12223334444", phoneNumbers[0])
// Remove the phone number
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Do the thing
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes",
// Verify the phone number is gone from the DB
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 0, len(phoneNumbers))
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
// Remove the phone number
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Verify the phone number is gone from the DB
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 0, len(phoneNumbers))
}
func TestServer_Twilio_Call_Success(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes", // <<<------
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes", // <<<------
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
<Response>
<Pause length="1"/>
<Say language="de-DE" loop="3">
@@ -240,88 +247,97 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
</Say>
<Say language="de-DE">Auf Wiederhören.</Say>
</Response>`))
s := newTestServer(t, c)
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
// Do the thing
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+invalid",
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+invalid",
})
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+123123",
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+123123",
})
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
})
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
}

View File

@@ -11,7 +11,9 @@ import (
"github.com/SherClockHolmes/webpush-go"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
wpush "heckel.io/ntfy/v2/webpush"
)
const (
@@ -82,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
if err != nil {
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
return
}
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON()))
if err != nil {
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
return
@@ -128,7 +130,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
if err != nil {
return err
}
warningSent := make([]*webPushSubscription, 0)
warningSent := make([]*wpush.Subscription, 0)
for _, subscription := range subscriptions {
if err := s.sendWebPushNotification(subscription, payload); err != nil {
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
@@ -143,7 +145,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
return nil
}
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
func (s *Server) sendWebPushNotification(sub *wpush.Subscription, message []byte, contexters ...log.Contexter) error {
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
payload := &webpush.Subscription{
Endpoint: sub.Endpoint,

View File

@@ -4,6 +4,8 @@ package server
import (
"net/http"
"heckel.io/ntfy/v2/model"
)
const (
@@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
return errHTTPNotFound
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
// Nothing to see here
}

View File

@@ -5,10 +5,6 @@ package server
import (
"encoding/json"
"fmt"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"io"
"net/http"
"net/http/httptest"
@@ -18,6 +14,11 @@ import (
"sync/atomic"
"testing"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
const (
@@ -25,237 +26,262 @@ const (
)
func TestServer_WebPush_Enabled(t *testing.T) {
conf := newTestConfig(t)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfig(t, databaseURL)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
s2 := newTestServer(t, conf2)
conf2 := newTestConfig(t, databaseURL)
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf3 := newTestConfigWithWebPush(t)
s3 := newTestServer(t, conf3)
conf3 := newTestConfigWithWebPush(t, databaseURL)
s3 := newTestServer(t, conf3)
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
})
}
func TestServer_WebPush_Disabled(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
})
}
func TestServer_WebPush_TopicAdd(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
})
}
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
})
}
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
})
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_Delete(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Publish(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
waitFor(t, func() bool {
return received.Load()
})
})
}
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
request(t, s, "POST", "/test-topic", "web push test", nil)
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
waitFor(t, func() bool {
return received.Load()
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
func TestServer_WebPush_Expiry(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
endpoint := pushService.URL + "/push-receive"
addSubscription(t, s, endpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
require.Nil(t, err)
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
waitFor(t, func() bool {
return received.Load()
})
waitFor(t, func() bool {
return received.Load()
})
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
require.Nil(t, err)
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
})
})
}
@@ -281,11 +307,13 @@ func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLen
require.Len(t, subs, expectedLength)
}
func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
func newTestConfigWithWebPush(t *testing.T, databaseURL string) *Config {
conf := newTestConfig(t, databaseURL)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
require.Nil(t, err)
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
if conf.DatabaseURL == "" {
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
}
conf.WebPushEmailAddress = "testing@example.com"
conf.WebPushPrivateKey = privateKey
conf.WebPushPublicKey = publicKey

View File

@@ -12,11 +12,12 @@ import (
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
type mailer interface {
Send(v *visitor, m *message, to string) error
Send(v *visitor, m *model.Message, to string) error
Counts() (total int64, success int64, failure int64)
}
@@ -27,7 +28,7 @@ type smtpSender struct {
mu sync.Mutex
}
func (s *smtpSender) Send(v *visitor, m *message, to string) error {
func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error {
return s.withCount(v, m, func() error {
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
if err != nil {
@@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) {
return s.success + s.failure, s.success, s.failure
}
func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error {
err := fn()
s.mu.Lock()
defer s.mu.Unlock()
@@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
return err
}
func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) {
func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) {
topicURL := baseURL + "/" + m.Topic
subject := m.Title
if subject == "" {

View File

@@ -1,12 +1,14 @@
package server
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
)
func TestFormatMail_Basic(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustEmojis(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustOtherTags(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustPriority(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_UTF8Subject(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_WithAllTheThings(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",

View File

@@ -19,6 +19,7 @@ import (
"github.com/emersion/go-smtp"
"github.com/microcosm-cc/bluemonday"
"heckel.io/ntfy/v2/model"
)
var (
@@ -33,6 +34,7 @@ var (
var (
onlySpacesRegex = regexp.MustCompile(`(?m)^\s+$`)
consecutiveNewLinesRegex = regexp.MustCompile(`\n{3,}`)
htmlLineBreakRegex = regexp.MustCompile(`(?i)<br\s*/?>`)
)
const (
@@ -158,7 +160,7 @@ func (s *smtpSession) Data(r io.Reader) error {
if len(body) > conf.MessageSizeLimit {
body = body[:conf.MessageSizeLimit]
}
m := newDefaultMessage(s.topic, body)
m := model.NewDefaultMessage(s.topic, body)
subject := strings.TrimSpace(msg.Header.Get("Subject"))
if subject != "" {
dec := mime.WordDecoder{}
@@ -183,7 +185,7 @@ func (s *smtpSession) Data(r io.Reader) error {
})
}
func (s *smtpSession) publishMessage(m *message) error {
func (s *smtpSession) publishMessage(m *model.Message) error {
// Extract remote address (for rate limiting)
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
if err != nil {
@@ -327,6 +329,9 @@ func readHTMLMailBody(reader io.Reader, transferEncoding string) (string, error)
if err != nil {
return "", err
}
// Convert <br> tags to newlines before stripping HTML, so that line breaks
// in HTML emails (e.g. from Synology DSM, and other appliances) are preserved.
body = htmlLineBreakRegex.ReplaceAllString(body, "\n")
stripped := bluemonday.
StrictPolicy().
AddSpaceWhenStrippingTag(true).

View File

@@ -694,7 +694,8 @@ home automation setup
Now the light is on
If you don&#39;t want to receive this message anymore, stop the push
services in your FRITZ!Box .
services in your FRITZ!Box .
Here you can see the active push services: &#34;System &gt; Push Service&#34;.
This mail has ben sent by your FRITZ!Box automatically.`
@@ -1354,9 +1355,11 @@ Congratulations! You have successfully set up the email notification on Synology
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/synology", r.URL.Path)
require.Equal(t, "[Synology NAS] Test Message from Litts_NAS", r.Header.Get("Title"))
actual := readAll(t, r.Body)
expected := `Congratulations! You have successfully set up the email notification on Synology_NAS. For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/. (If you cannot connect to the server, please contact the administrator.) From Synology_NAS`
require.Equal(t, expected, actual)
expected := "Congratulations! You have successfully set up the email notification on Synology_NAS.\n" +
"For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/.\n" +
"(If you cannot connect to the server, please contact the administrator.)\n\n" +
"From Synology_NAS"
require.Equal(t, expected, readAll(t, r.Body))
})
conf.SMTPServerDomain = "mydomain.me"
conf.SMTPServerAddrPrefix = ""
@@ -1365,6 +1368,36 @@ Congratulations! You have successfully set up the email notification on Synology
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_HTMLEmail_BrTagsPreserved(t *testing.T) {
email := `EHLO example.com
MAIL FROM: nas@example.com
RCPT TO: ntfy-alerts@ntfy.sh
DATA
Content-Type: text/html; charset=utf-8
Content-Transfer-Encoding: 8bit
Subject: Task Scheduler: daily-backup
Task Scheduler has completed a scheduled task.<BR><BR>Task: daily-backup<BR>Start time: Mon, 01 Jan 2026 02:00:00 +0000<BR>Stop time: Mon, 01 Jan 2024 02:03:00 +0000<BR>Current status: 0 (Normal)<BR>Standard output/error:<BR>OK<BR><BR>From MyNAS
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/alerts", r.URL.Path)
require.Equal(t, "Task Scheduler: daily-backup", r.Header.Get("Title"))
expected := "Task Scheduler has completed a scheduled task.\n\n" +
"Task: daily-backup\n" +
"Start time: Mon, 01 Jan 2026 02:00:00 +0000\n" +
"Stop time: Mon, 01 Jan 2024 02:03:00 +0000\n" +
"Current status: 0 (Normal)\n" +
"Standard output/error:\n" +
"OK\n\n" +
"From MyNAS"
require.Equal(t, expected, readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
@@ -1411,7 +1444,7 @@ what's up
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
conf = newTestConfig(t)
conf = newTestConfig(t, "")
conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-"

View File

@@ -6,6 +6,7 @@ import (
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -33,7 +34,7 @@ type topicSubscriber struct {
}
// subscriber is a function that is called for every new message on a topic
type subscriber func(v *visitor, msg *message) error
type subscriber func(v *visitor, msg *model.Message) error
// newTopic creates a new topic
func newTopic(id string) *topic {
@@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) {
}
// Publish asynchronously publishes to all subscribers
func (t *topic) Publish(v *visitor, m *message) error {
func (t *topic) Publish(v *visitor, m *model.Message) error {
go func() {
// We want to lock the topic as short as possible, so we make a shallow copy of the
// subscribers map here. Actually sending out the messages then doesn't have to lock.

View File

@@ -7,10 +7,11 @@ import (
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
)
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}
canceled1 := atomic.Bool{}
@@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
func TestTopic_CancelSubscribersUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}
canceled1 := atomic.Bool{}
@@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) {
cancel: func() {},
}
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}

View File

@@ -2,219 +2,35 @@ package server
import (
"net/http"
"net/netip"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
// List of possible events
const (
openEvent = "open"
keepaliveEvent = "keepalive"
messageEvent = "message"
messageDeleteEvent = "message_delete"
messageClearEvent = "message_clear"
pollRequestEvent = "poll_request"
)
const (
messageIDLength = 12
)
// message represents a message published to a topic
type message struct {
ID string `json:"id"` // Random message ID
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
func (m *message) Context() log.Context {
fields := map[string]any{
"topic": m.Topic,
"message_id": m.ID,
"message_sequence_id": m.SequenceID,
"message_time": m.Time,
"message_event": m.Event,
"message_body_size": len(m.Message),
}
if m.Sender.IsValid() {
fields["message_sender"] = m.Sender.String()
}
if m.User != "" {
fields["message_user"] = m.User
}
return fields
}
// forJSON returns a copy of the message suitable for JSON output.
// It clears the SequenceID if it equals the ID to reduce redundancy.
func (m *message) forJSON() *message {
if m.SequenceID == m.ID {
clone := *m
clone.SequenceID = ""
return &clone
}
return m
}
type attachment struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"`
URL string `json:"url"`
}
type action struct {
ID string `json:"id"`
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
Label string `json:"label"` // action button label
Clear bool `json:"clear"` // clear notification after successful execution
URL string `json:"url,omitempty"` // used in "view" and "http" actions
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
Body string `json:"body,omitempty"` // used in "http" action
Intent string `json:"intent,omitempty"` // used in "broadcast" action
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
Value string `json:"value,omitempty"` // used in "copy" action
}
func newAction() *action {
return &action{
Headers: make(map[string]string),
Extras: make(map[string]string),
}
}
// publishMessage is used as input when publishing as JSON
type publishMessage struct {
Topic string `json:"topic"`
SequenceID string `json:"sequence_id"`
Title string `json:"title"`
Message string `json:"message"`
Priority int `json:"priority"`
Tags []string `json:"tags"`
Click string `json:"click"`
Icon string `json:"icon"`
Actions []action `json:"actions"`
Attach string `json:"attach"`
Markdown bool `json:"markdown"`
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
Delay string `json:"delay"`
Topic string `json:"topic"`
SequenceID string `json:"sequence_id"`
Title string `json:"title"`
Message string `json:"message"`
Priority int `json:"priority"`
Tags []string `json:"tags"`
Click string `json:"click"`
Icon string `json:"icon"`
Actions []model.Action `json:"actions"`
Attach string `json:"attach"`
Markdown bool `json:"markdown"`
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
Delay string `json:"delay"`
}
// messageEncoder is a function that knows how to encode a message
type messageEncoder func(msg *message) (string, error)
// newMessage creates a new message with the current timestamp
func newMessage(event, topic, msg string) *message {
return &message{
ID: util.RandomString(messageIDLength),
Time: time.Now().Unix(),
Event: event,
Topic: topic,
Message: msg,
}
}
// newOpenMessage is a convenience method to create an open message
func newOpenMessage(topic string) *message {
return newMessage(openEvent, topic, "")
}
// newKeepaliveMessage is a convenience method to create a keepalive message
func newKeepaliveMessage(topic string) *message {
return newMessage(keepaliveEvent, topic, "")
}
// newDefaultMessage is a convenience method to create a notification message
func newDefaultMessage(topic, msg string) *message {
return newMessage(messageEvent, topic, msg)
}
// newPollRequestMessage is a convenience method to create a poll request message
func newPollRequestMessage(topic, pollID string) *message {
m := newMessage(pollRequestEvent, topic, newMessageBody)
m.PollID = pollID
return m
}
// newActionMessage creates a new action message (message_delete or message_clear)
func newActionMessage(event, topic, sequenceID string) *message {
m := newMessage(event, topic, "")
m.SequenceID = sequenceID
return m
}
func validMessageID(s string) bool {
return util.ValidRandomString(s, messageIDLength)
}
type sinceMarker struct {
time time.Time
id string
}
func newSinceTime(timestamp int64) sinceMarker {
return sinceMarker{time.Unix(timestamp, 0), ""}
}
func newSinceID(id string) sinceMarker {
return sinceMarker{time.Unix(0, 0), id}
}
func (t sinceMarker) IsAll() bool {
return t == sinceAllMessages
}
func (t sinceMarker) IsNone() bool {
return t == sinceNoMessages
}
func (t sinceMarker) IsLatest() bool {
return t == sinceLatestMessage
}
func (t sinceMarker) IsID() bool {
return t.id != "" && t.id != "latest"
}
func (t sinceMarker) Time() time.Time {
return t.time
}
func (t sinceMarker) ID() string {
return t.id
}
var (
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
)
type messageEncoder func(msg *model.Message) (string, error)
type queryFilter struct {
ID string
@@ -246,8 +62,8 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
}, nil
}
func (q *queryFilter) Pass(msg *message) bool {
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
func (q *queryFilter) Pass(msg *model.Message) bool {
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
return true // filters only apply to messages
} else if q.ID != "" && msg.ID != q.ID {
return false
@@ -320,6 +136,12 @@ type apiHealthResponse struct {
Healthy bool `json:"healthy"`
}
type apiVersionResponse struct {
Version string `json:"version"`
Commit string `json:"commit"`
Date string `json:"date"`
}
type apiStatsResponse struct {
Messages int64 `json:"messages"`
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
@@ -564,12 +386,12 @@ const (
)
type webPushPayload struct {
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *message `json:"message"`
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *model.Message `json:"message"`
}
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload {
return &webPushPayload{
Event: webPushMessageEvent,
SubscriptionID: subscriptionID,
@@ -587,22 +409,6 @@ func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
}
}
type webPushSubscription struct {
ID string
Endpoint string
Auth string
P256dh string
UserID string
}
func (w *webPushSubscription) Context() log.Context {
return map[string]any{
"web_push_subscription_id": w.ID,
"web_push_subscription_user_id": w.UserID,
"web_push_subscription_endpoint": w.Endpoint,
}
}
// https://developer.mozilla.org/en-US/docs/Web/Manifest
type webManifestResponse struct {
Name string `json:"name"`

View File

@@ -8,6 +8,7 @@ import (
"golang.org/x/time/rate"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -53,7 +54,7 @@ const (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
messageCache *messageCache
messageCache *message.Cache
userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil
@@ -114,7 +115,7 @@ const (
visitorLimitBasisTier = visitorLimitBasis("tier")
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails, calls int64
if user != nil {
messages = user.Stats.Messages

View File

@@ -1,285 +0,0 @@
package server
import (
"database/sql"
"errors"
"heckel.io/ntfy/v2/util"
"net/netip"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
const (
subscriptionIDPrefix = "wps_"
subscriptionIDLength = 10
subscriptionEndpointLimitPerSubscriberIP = 10
)
var (
errWebPushNoRows = errors.New("no rows found")
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
const (
createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
// Schema management queries
const (
currentWebPushSchemaVersion = 1
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
type webPushStore struct {
db *sql.DB
}
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupWebPushDB(db); err != nil {
return nil, err
}
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
}, nil
}
func setupWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(selectWebPushSchemaVersionQuery)
if err != nil {
return setupNewWebPushDB(db)
}
return rows.Close()
}
func setupNewWebPushDB(db *sql.DB) error {
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil {
return err
}
defer rowsCount.Close()
var subscriptionCount int
if !rowsCount.Next() {
return errWebPushNoRows
}
if err := rowsCount.Scan(&subscriptionCount); err != nil {
return err
}
if err := rowsCount.Close(); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
if err != nil {
return err
}
defer rows.Close()
var subscriptionID string
if rows.Next() {
if err := rows.Scan(&subscriptionID); err != nil {
return err
}
} else {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return errWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
}
if err := rows.Close(); err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
return err
}
}
return tx.Commit()
}
// SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
return tx.Commit()
}
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
subscriptions := make([]*webPushSubscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &webPushSubscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
return err
}
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return errWebPushUserIDCannotBeEmpty
}
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
return err
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
if err != nil {
return err
}
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
return err
}
// Close closes the underlying database connection
func (c *webPushStore) Close() error {
return c.db.Close()
}

View File

@@ -1,199 +0,0 @@
package server
import (
"fmt"
"github.com/stretchr/testify/require"
"net/netip"
"path/filepath"
"testing"
"time"
)
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
subs2, err := webPush.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
}
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
// Another one for the same endpoint should be fine
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different endpoint it should fail
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different IP address it should be fine again
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
}
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics, and another with one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
// Update the first subscription to have only one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
}
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Mark them as warning sent
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
require.Nil(t, err)
defer rows.Close()
var endpoint string
require.True(t, rows.Next())
require.Nil(t, rows.Scan(&endpoint))
require.Nil(t, err)
require.Equal(t, testWebPushEndpoint, endpoint)
require.False(t, rows.Next())
}
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Should not be cleaned up yet
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// Run expiration
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
}
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as expired
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Run expiration
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// List again, should be 0
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func newTestWebPushStore(t *testing.T) *webPushStore {
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
return webPush
}

View File

@@ -3,7 +3,7 @@ package test
import (
"fmt"
"heckel.io/ntfy/v2/server"
"math/rand"
"net"
"net/http"
"path/filepath"
"testing"
@@ -16,7 +16,7 @@ func StartServer(t *testing.T) (*server.Server, int) {
// StartServerWithConfig starts a server.Server with a random port and waits for the server to be up
func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, int) {
port := 10000 + rand.Intn(30000)
port := findAvailablePort(t)
conf.ListenHTTP = fmt.Sprintf(":%d", port)
conf.AttachmentCacheDir = t.TempDir()
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
@@ -33,6 +33,17 @@ func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, i
return s, port
}
// findAvailablePort asks the OS for a free port by binding to :0
func findAvailablePort(t *testing.T) int {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
return port
}
// StopServer stops the test server and waits for the port to be down
func StopServer(t *testing.T, s *server.Server, port int) {
s.Stop()

5
tools/loadtest/go.mod Normal file
View File

@@ -0,0 +1,5 @@
module loadtest
go 1.25.2
require github.com/gorilla/websocket v1.5.3

2
tools/loadtest/go.sum Normal file
View File

@@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

543
tools/loadtest/main.go Normal file
View File

@@ -0,0 +1,543 @@
// Load test program for ntfy staging server.
// Replicates production traffic patterns derived from access.log analysis.
//
// Traffic profile (from ~5M requests over 20 hours):
// ~71 req/sec average, ~4,300 req/min
// 49.6% poll requests (GET /TOPIC/json?poll=1&since=ID)
// 21.4% publish POST (POST /TOPIC with small body)
// 6.2% subscribe stream (GET /TOPIC/json?since=X, long-lived)
// 4.1% config check (GET /v1/config)
// 2.3% other topic GET (GET /TOPIC)
// 2.2% account check (GET /v1/account)
// 1.9% websocket sub (GET /TOPIC/ws?since=X)
// 1.5% publish PUT (PUT /TOPIC with small body)
// 1.5% raw subscribe (GET /TOPIC/raw?since=X)
// 1.1% json subscribe (GET /TOPIC/json, no since)
// 0.7% SSE subscribe (GET /TOPIC/sse?since=X)
// remaining: static, PATCH, OPTIONS, etc. (omitted)
package main
import (
"context"
"crypto/rand"
"encoding/hex"
"flag"
"fmt"
"io"
"math/big"
mrand "math/rand"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
var (
baseURL string
username string
password string
rps float64
scale float64
numTopics int
subStreams int
wsStreams int
sseStreams int
rawStreams int
duration time.Duration
totalRequests atomic.Int64
totalErrors atomic.Int64
activeStreams atomic.Int64
// Error tracking by category
errMu sync.Mutex
recentErrors []string // last N unique error messages
errorCounts = make(map[string]int64)
)
func main() {
flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server")
flag.StringVar(&username, "user", "", "Username for authentication")
flag.StringVar(&password, "pass", "", "Password for authentication")
flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)")
flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)")
flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use")
flag.IntVar(&subStreams, "sub-streams", 200, "Number of concurrent JSON streaming subscriptions")
flag.IntVar(&wsStreams, "ws-streams", 50, "Number of concurrent WebSocket subscriptions")
flag.IntVar(&sseStreams, "sse-streams", 20, "Number of concurrent SSE subscriptions")
flag.IntVar(&rawStreams, "raw-streams", 30, "Number of concurrent raw subscriptions")
flag.DurationVar(&duration, "duration", 10*time.Minute, "Test duration")
flag.Parse()
rps *= scale
subStreams = int(float64(subStreams) * scale)
wsStreams = int(float64(wsStreams) * scale)
sseStreams = int(float64(sseStreams) * scale)
rawStreams = int(float64(rawStreams) * scale)
topics := generateTopics(numTopics)
fmt.Printf("ntfy load test\n")
fmt.Printf(" Target: %s\n", baseURL)
fmt.Printf(" RPS: %.1f\n", rps)
fmt.Printf(" Scale: %.1fx\n", scale)
fmt.Printf(" Topics: %d\n", numTopics)
fmt.Printf(" Sub streams: %d json, %d ws, %d sse, %d raw\n", subStreams, wsStreams, sseStreams, rawStreams)
fmt.Printf(" Duration: %s\n", duration)
fmt.Println()
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
// Also handle Ctrl+C
sigCtx, sigCancel := signal.NotifyContext(ctx, os.Interrupt)
defer sigCancel()
ctx = sigCtx
client := &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
},
}
// Long-lived streaming client (no timeout)
streamClient := &http.Client{
Timeout: 0,
Transport: &http.Transport{
MaxIdleConns: 500,
MaxIdleConnsPerHost: 500,
IdleConnTimeout: 0,
},
}
var wg sync.WaitGroup
// Start long-lived streaming subscriptions
for i := 0; i < subStreams; i++ {
wg.Add(1)
go func() {
defer wg.Done()
streamSubscription(ctx, streamClient, topics, "json")
}()
}
for i := 0; i < wsStreams; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wsSubscription(ctx, topics)
}()
}
for i := 0; i < sseStreams; i++ {
wg.Add(1)
go func() {
defer wg.Done()
streamSubscription(ctx, streamClient, topics, "sse")
}()
}
for i := 0; i < rawStreams; i++ {
wg.Add(1)
go func() {
defer wg.Done()
streamSubscription(ctx, streamClient, topics, "raw")
}()
}
// Start request generators based on traffic weights
// Weights from log analysis (normalized to sum ~100):
// poll=49.6, publish_post=21.4, config=4.1, other_get=2.3, account=2.2, publish_put=1.5
// Total short-lived weight ≈ 81.1
type requestType struct {
name string
weight float64
fn func(ctx context.Context, client *http.Client, topics []string)
}
types := []requestType{
{"poll", 49.6, doPoll},
{"publish_post", 21.4, doPublishPost},
{"config", 4.1, doConfig},
{"other_get", 2.3, doOtherGet},
{"account", 2.2, doAccountCheck},
{"publish_put", 1.5, doPublishPut},
}
totalWeight := 0.0
for _, t := range types {
totalWeight += t.weight
}
for _, t := range types {
t := t
typeRPS := rps * (t.weight / totalWeight)
if typeRPS < 0.1 {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
runAtRate(ctx, typeRPS, func() {
t.fn(ctx, client, topics)
})
}()
}
// Stats reporter
wg.Add(1)
go func() {
defer wg.Done()
reportStats(ctx)
}()
wg.Wait()
fmt.Printf("\nDone. Total requests: %d, errors: %d\n", totalRequests.Load(), totalErrors.Load())
}
func trackError(category string, err error) {
totalErrors.Add(1)
key := fmt.Sprintf("%s: %s", category, truncateErr(err))
errMu.Lock()
errorCounts[key]++
errMu.Unlock()
}
func trackErrorMsg(category string, msg string) {
totalErrors.Add(1)
key := fmt.Sprintf("%s: %s", category, msg)
errMu.Lock()
errorCounts[key]++
errMu.Unlock()
}
func truncateErr(err error) string {
s := err.Error()
if len(s) > 120 {
s = s[:120] + "..."
}
return s
}
func setAuth(req *http.Request) {
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
}
func generateTopics(n int) []string {
topics := make([]string, n)
for i := 0; i < n; i++ {
b := make([]byte, 8)
rand.Read(b)
topics[i] = "loadtest-" + hex.EncodeToString(b)
}
return topics
}
func pickTopic(topics []string) string {
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(topics))))
return topics[n.Int64()]
}
func randomSince() string {
b := make([]byte, 6)
rand.Read(b)
return hex.EncodeToString(b)
}
func randomMessage() string {
messages := []string{
"Test notification",
"Server backup completed successfully",
"Deployment finished",
"Alert: disk usage above 80%",
"Build #1234 passed",
"New order received",
"Temperature sensor reading: 72F",
"Cron job completed",
}
return messages[mrand.Intn(len(messages))]
}
// runAtRate executes fn at approximately the given rate per second
func runAtRate(ctx context.Context, rate float64, fn func()) {
interval := time.Duration(float64(time.Second) / rate)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
go fn()
}
}
}
// --- Short-lived request types ---
func doPoll(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s/json?poll=1&since=%s", baseURL, topic, randomSince())
doGet(ctx, client, url)
}
func doPublishPost(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(randomMessage()))
if err != nil {
trackError("publish_post_req", err)
return
}
setAuth(req)
// Some messages have titles/priorities like real traffic
if mrand.Float32() < 0.3 {
req.Header.Set("X-Title", "Load Test")
}
if mrand.Float32() < 0.1 {
req.Header.Set("X-Priority", fmt.Sprintf("%d", mrand.Intn(5)+1))
}
resp, err := client.Do(req)
totalRequests.Add(1)
if err != nil {
trackError("publish_post", err)
return
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode >= 400 {
trackErrorMsg("publish_post_http", fmt.Sprintf("status %d", resp.StatusCode))
}
}
func doPublishPut(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(randomMessage()))
if err != nil {
trackError("publish_put_req", err)
return
}
setAuth(req)
resp, err := client.Do(req)
totalRequests.Add(1)
if err != nil {
trackError("publish_put", err)
return
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode >= 400 {
trackErrorMsg("publish_put_http", fmt.Sprintf("status %d", resp.StatusCode))
}
}
func doConfig(ctx context.Context, client *http.Client, topics []string) {
url := fmt.Sprintf("%s/v1/config", baseURL)
doGet(ctx, client, url)
}
func doAccountCheck(ctx context.Context, client *http.Client, topics []string) {
url := fmt.Sprintf("%s/v1/account", baseURL)
doGet(ctx, client, url)
}
func doOtherGet(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
doGet(ctx, client, url)
}
func doGet(ctx context.Context, client *http.Client, url string) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
trackError("get_req", err)
return
}
setAuth(req)
resp, err := client.Do(req)
totalRequests.Add(1)
if err != nil {
trackError("get", err)
return
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode >= 400 {
trackErrorMsg("get_http", fmt.Sprintf("status %d for %s", resp.StatusCode, url))
}
}
// --- Long-lived streaming subscriptions ---
func streamSubscription(ctx context.Context, client *http.Client, topics []string, format string) {
for {
if ctx.Err() != nil {
return
}
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s/%s?since=all", baseURL, topic, format)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
time.Sleep(time.Second)
continue
}
setAuth(req)
activeStreams.Add(1)
resp, err := client.Do(req)
if err != nil {
activeStreams.Add(-1)
if ctx.Err() == nil {
trackError("stream_"+format+"_connect", err)
}
time.Sleep(time.Second)
continue
}
if resp.StatusCode >= 400 {
trackErrorMsg("stream_"+format+"_http", fmt.Sprintf("status %d", resp.StatusCode))
resp.Body.Close()
activeStreams.Add(-1)
time.Sleep(time.Second)
continue
}
// Read from stream until context cancelled or connection drops
buf := make([]byte, 4096)
for {
_, err := resp.Body.Read(buf)
if err != nil {
if ctx.Err() == nil {
trackError("stream_"+format+"_read", err)
}
break
}
}
resp.Body.Close()
activeStreams.Add(-1)
// Reconnect with small delay (like real clients do)
select {
case <-ctx.Done():
return
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
}
}
}
func wsSubscription(ctx context.Context, topics []string) {
wsURL := strings.Replace(baseURL, "https://", "wss://", 1)
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
for {
if ctx.Err() != nil {
return
}
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s/ws?since=all", wsURL, topic)
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
var wsHeader http.Header
if username != "" && password != "" {
wsHeader = http.Header{}
req, _ := http.NewRequest("GET", url, nil)
req.SetBasicAuth(username, password)
wsHeader.Set("Authorization", req.Header.Get("Authorization"))
}
activeStreams.Add(1)
conn, _, err := dialer.DialContext(ctx, url, wsHeader)
if err != nil {
activeStreams.Add(-1)
if ctx.Err() == nil {
trackError("ws_connect", err)
}
time.Sleep(time.Second)
continue
}
// Read messages until context cancelled or error
done := make(chan struct{})
go func() {
defer close(done)
for {
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
_, _, err := conn.ReadMessage()
if err != nil {
return
}
}
}()
select {
case <-ctx.Done():
conn.Close()
activeStreams.Add(-1)
return
case <-done:
conn.Close()
activeStreams.Add(-1)
}
select {
case <-ctx.Done():
return
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
}
}
}
func reportStats(ctx context.Context) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
var lastRequests, lastErrors int64
lastTime := time.Now()
reportCount := 0
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
currentRequests := totalRequests.Load()
currentErrors := totalErrors.Load()
elapsed := now.Sub(lastTime).Seconds()
currentRPS := float64(currentRequests-lastRequests) / elapsed
errorRate := float64(currentErrors-lastErrors) / elapsed
fmt.Printf("[%s] rps=%.1f err/s=%.1f total=%d errors=%d streams=%d\n",
now.Format("15:04:05"),
currentRPS,
errorRate,
currentRequests,
currentErrors,
activeStreams.Load(),
)
// Print error breakdown every 30 seconds
reportCount++
if reportCount%6 == 0 && currentErrors > 0 {
errMu.Lock()
fmt.Printf(" Error breakdown:\n")
for k, v := range errorCounts {
fmt.Printf(" %s: %d\n", k, v)
}
errMu.Unlock()
}
lastRequests = currentRequests
lastErrors = currentErrors
lastTime = now
}
}
}

48
tools/pgimport/README.md Normal file
View File

@@ -0,0 +1,48 @@
# pgimport
One-off migration script to import ntfy data from SQLite to PostgreSQL.
This is **not** a generic migration tool. It only works with specific SQLite schema versions
(message cache v14, user db v6, web push v1) and their corresponding PostgreSQL schemas.
If your database versions differ, this tool will refuse to run.
## Build
```bash
go build -o pgimport ./tools/pgimport/
```
## Usage
```bash
# Using CLI flags
pgimport \
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
--cache-file /var/cache/ntfy/cache.db \
--auth-file /var/lib/ntfy/user.db \
--web-push-file /var/lib/ntfy/webpush.db
# Using --create-schema to set up PostgreSQL schema automatically
pgimport \
--create-schema \
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
--cache-file /var/cache/ntfy/cache.db \
--auth-file /var/lib/ntfy/user.db \
--web-push-file /var/lib/ntfy/webpush.db
# Using server.yml (flags override config values)
pgimport --config /etc/ntfy/server.yml
```
## Prerequisites
- PostgreSQL schema must already be set up, either by running ntfy with `database-url` once,
or by passing `--create-schema` to pgimport to create the initial schema automatically
- ntfy must not be running during the import
- All three SQLite files are optional; only the ones specified will be imported
## Notes
- The tool is idempotent and safe to re-run
- After importing, row counts and content are verified against the SQLite sources
- Invalid UTF-8 in messages is replaced with the Unicode replacement character

1164
tools/pgimport/main.go Normal file

File diff suppressed because it is too large Load Diff

141
tools/s3cli/main.go Normal file
View File

@@ -0,0 +1,141 @@
// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls.
//
// Usage:
//
// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT"
//
// s3cli put <key> <file> Upload a file
// s3cli put <key> - Upload from stdin
// s3cli get <key> Download to stdout
// s3cli rm <key> [<key>...] Delete one or more objects
// s3cli ls List all objects
package main
import (
"context"
"fmt"
"io"
"os"
"text/tabwriter"
"heckel.io/ntfy/v2/s3"
)
func main() {
if len(os.Args) < 2 {
usage()
}
s3URL := os.Getenv("S3_URL")
if s3URL == "" {
fail("S3_URL environment variable is required")
}
cfg, err := s3.ParseURL(s3URL)
if err != nil {
fail("invalid S3_URL: %s", err)
}
client := s3.New(cfg)
ctx := context.Background()
switch os.Args[1] {
case "put":
cmdPut(ctx, client)
case "get":
cmdGet(ctx, client)
case "rm":
cmdRm(ctx, client)
case "ls":
cmdLs(ctx, client)
default:
usage()
}
}
func cmdPut(ctx context.Context, client *s3.Client) {
if len(os.Args) != 4 {
fail("usage: s3cli put <key> <file|->\n")
}
key := os.Args[2]
path := os.Args[3]
var r io.Reader
if path == "-" {
r = os.Stdin
} else {
f, err := os.Open(path)
if err != nil {
fail("open %s: %s", path, err)
}
defer f.Close()
r = f
}
if err := client.PutObject(ctx, key, r); err != nil {
fail("put: %s", err)
}
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
}
func cmdGet(ctx context.Context, client *s3.Client) {
if len(os.Args) != 3 {
fail("usage: s3cli get <key>\n")
}
key := os.Args[2]
reader, size, err := client.GetObject(ctx, key)
if err != nil {
fail("get: %s", err)
}
defer reader.Close()
n, err := io.Copy(os.Stdout, reader)
if err != nil {
fail("read: %s", err)
}
fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size)
}
func cmdRm(ctx context.Context, client *s3.Client) {
if len(os.Args) < 3 {
fail("usage: s3cli rm <key> [<key>...]\n")
}
keys := os.Args[2:]
if err := client.DeleteObjects(ctx, keys); err != nil {
fail("rm: %s", err)
}
fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys))
}
func cmdLs(ctx context.Context, client *s3.Client) {
objects, err := client.ListAllObjects(ctx)
if err != nil {
fail("ls: %s", err)
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
var totalSize int64
for _, obj := range objects {
fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key)
totalSize += obj.Size
}
w.Flush()
fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize)
}
func usage() {
fmt.Fprintf(os.Stderr, `Usage: s3cli <command> [args...]
Commands:
put <key> <file|-> Upload a file (use - for stdin)
get <key> Download to stdout
rm <key> [keys...] Delete objects
ls List all objects
Environment:
S3_URL S3 connection URL (required)
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
`)
os.Exit(1)
}
func fail(format string, args ...any) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}

File diff suppressed because it is too large Load Diff

286
user/manager_postgres.go Normal file
View File

@@ -0,0 +1,286 @@
package user
import (
"heckel.io/ntfy/v2/db"
)
// PostgreSQL queries
const (
// User queries
postgresSelectUsersQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
ORDER BY
CASE u.role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, u.user_name
`
postgresSelectUserByIDQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = $1
`
postgresSelectUserByNameQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user_name = $1
`
postgresSelectUserByTokenQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
`
postgresSelectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = $1
`
postgresSelectUsernamesQuery = `
SELECT user_name
FROM "user"
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user_name
`
postgresSelectUserCountQuery = `SELECT COUNT(*) FROM "user"`
postgresSelectUserIDFromUsernameQuery = `SELECT id FROM "user" WHERE user_name = $1`
postgresInsertUserQuery = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
postgresUpdateUserPassQuery = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
postgresUpdateUserRoleQuery = `UPDATE "user" SET role = $1 WHERE user_name = $2`
postgresUpdateUserProvisionedQuery = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
postgresUpdateUserPrefsQuery = `UPDATE "user" SET prefs = $1 WHERE id = $2`
postgresUpdateUserStatsQuery = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
postgresUpdateUserStatsResetAllQuery = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
postgresUpdateUserTierQuery = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
postgresUpdateUserDeletedQuery = `UPDATE "user" SET deleted = $1 WHERE id = $2`
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
postgresDeleteUsersProvisionedQuery = `DELETE FROM "user" WHERE provisioned = true`
// Access queries
postgresSelectTopicPermsQuery = `
SELECT read, write
FROM user_access a
JOIN "user" u ON u.id = a.user_id
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
`
postgresSelectUserAllAccessQuery = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserAccessQuery = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserReservationsQuery = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
WHERE a_user.user_id = a_user.owner_user_id
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
ORDER BY a_user.topic
`
postgresSelectUserReservationsCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
`
postgresSelectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = $1
AND user_id = owner_user_id
`
postgresSelectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
AND topic = $2
`
postgresSelectOtherAccessCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
`
postgresUpsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES (
(SELECT id FROM "user" WHERE user_name = $1),
$2,
$3,
$4,
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
$7
)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
postgresDeleteUserAccessQuery = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
`
postgresDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = true`
postgresDeleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
AND topic = $3
`
postgresDeleteAllAccessQuery = `DELETE FROM user_access`
// Token queries
postgresSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
postgresSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
postgresSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
postgresSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
postgresUpsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
postgresDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = true`
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
postgresDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = $1
AND (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = $2
ORDER BY expires DESC
LIMIT $3
)
`
// Tier queries
postgresInsertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
postgresUpdateTierQuery = `
UPDATE tier
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
WHERE code = $13
`
postgresSelectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
postgresSelectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = $1
`
postgresSelectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
`
postgresDeleteTierQuery = `DELETE FROM tier WHERE code = $1`
// Phone queries
postgresSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = $1`
postgresInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
postgresDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
// Billing queries
postgresUpdateBillingQuery = `
UPDATE "user"
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
WHERE user_name = $7
`
)
// NewPostgresManager creates a new Manager backed by a PostgreSQL database using an existing connection pool.
var postgresQueries = queries{
selectUserByID: postgresSelectUserByIDQuery,
selectUserByName: postgresSelectUserByNameQuery,
selectUserByToken: postgresSelectUserByTokenQuery,
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
selectUsernames: postgresSelectUsernamesQuery,
selectUsers: postgresSelectUsersQuery,
selectUserCount: postgresSelectUserCountQuery,
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
insertUser: postgresInsertUserQuery,
updateUserPass: postgresUpdateUserPassQuery,
updateUserRole: postgresUpdateUserRoleQuery,
updateUserProvisioned: postgresUpdateUserProvisionedQuery,
updateUserPrefs: postgresUpdateUserPrefsQuery,
updateUserStats: postgresUpdateUserStatsQuery,
updateUserStatsResetAll: postgresUpdateUserStatsResetAllQuery,
updateUserTier: postgresUpdateUserTierQuery,
updateUserDeleted: postgresUpdateUserDeletedQuery,
deleteUser: postgresDeleteUserQuery,
deleteUserTier: postgresDeleteUserTierQuery,
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
deleteUsersProvisioned: postgresDeleteUsersProvisionedQuery,
selectTopicPerms: postgresSelectTopicPermsQuery,
selectUserAllAccess: postgresSelectUserAllAccessQuery,
selectUserAccess: postgresSelectUserAccessQuery,
selectUserReservations: postgresSelectUserReservationsQuery,
selectUserReservationsCount: postgresSelectUserReservationsCountQuery,
selectUserReservationsOwner: postgresSelectUserReservationsOwnerQuery,
selectUserHasReservation: postgresSelectUserHasReservationQuery,
selectOtherAccessCount: postgresSelectOtherAccessCountQuery,
upsertUserAccess: postgresUpsertUserAccessQuery,
deleteUserAccess: postgresDeleteUserAccessQuery,
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisionedQuery,
deleteTopicAccess: postgresDeleteTopicAccessQuery,
deleteAllAccess: postgresDeleteAllAccessQuery,
selectToken: postgresSelectTokenQuery,
selectTokens: postgresSelectTokensQuery,
selectTokenCount: postgresSelectTokenCountQuery,
selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery,
upsertToken: postgresUpsertTokenQuery,
updateToken: postgresUpdateTokenQuery,
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
deleteToken: postgresDeleteTokenQuery,
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
deleteAllProvisionedTokens: postgresDeleteAllProvisionedTokensQuery,
deleteAllToken: postgresDeleteAllTokenQuery,
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
deleteExcessTokens: postgresDeleteExcessTokensQuery,
insertTier: postgresInsertTierQuery,
selectTiers: postgresSelectTiersQuery,
selectTierByCode: postgresSelectTierByCodeQuery,
selectTierByPriceID: postgresSelectTierByPriceIDQuery,
updateTier: postgresUpdateTierQuery,
deleteTier: postgresDeleteTierQuery,
selectPhoneNumbers: postgresSelectPhoneNumbersQuery,
insertPhoneNumber: postgresInsertPhoneNumberQuery,
deletePhoneNumber: postgresDeletePhoneNumberQuery,
updateBilling: postgresUpdateBillingQuery,
}
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) {
if err := setupPostgres(d.Primary()); err != nil {
return nil, err
}
return newManager(d, postgresQueries, config)
}

View File

@@ -0,0 +1,113 @@
package user
import (
"database/sql"
"fmt"
)
// Initial PostgreSQL schema
const (
postgresCreateTablesQueries = `
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit BIGINT NOT NULL,
messages_expiry_duration BIGINT NOT NULL,
emails_limit BIGINT NOT NULL,
calls_limit BIGINT NOT NULL,
reservations_limit BIGINT NOT NULL,
attachment_file_size_limit BIGINT NOT NULL,
attachment_total_size_limit BIGINT NOT NULL,
attachment_expiry_duration BIGINT NOT NULL,
attachment_bandwidth_limit BIGINT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT,
UNIQUE(code),
UNIQUE(stripe_monthly_price_id),
UNIQUE(stripe_yearly_price_id)
);
CREATE TABLE IF NOT EXISTS "user" (
id TEXT PRIMARY KEY,
tier_id TEXT REFERENCES tier(id),
user_name TEXT NOT NULL UNIQUE,
pass TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
prefs JSONB NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned BOOLEAN NOT NULL,
stats_messages BIGINT NOT NULL DEFAULT 0,
stats_emails BIGINT NOT NULL DEFAULT 0,
stats_calls BIGINT NOT NULL DEFAULT 0,
stripe_customer_id TEXT UNIQUE,
stripe_subscription_id TEXT UNIQUE,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until BIGINT,
stripe_subscription_cancel_at BIGINT,
created BIGINT NOT NULL,
deleted BIGINT
);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
topic TEXT NOT NULL,
read BOOLEAN NOT NULL,
write BOOLEAN NOT NULL,
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
provisioned BOOLEAN NOT NULL,
PRIMARY KEY (user_id, topic)
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
label TEXT NOT NULL,
last_access BIGINT NOT NULL,
last_origin TEXT NOT NULL,
expires BIGINT NOT NULL,
provisioned BOOLEAN NOT NULL,
PRIMARY KEY (user_id, token)
);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number)
);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
ON CONFLICT (id) DO NOTHING;
`
)
// Schema table management queries for Postgres
const (
postgresCurrentSchemaVersion = 6
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'user'`
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewPostgres(db)
}
if schemaVersion > postgresCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
}
// Note: PostgreSQL migrations will be added when needed
return nil
}
func setupNewPostgres(db *sql.DB) error {
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil
}

295
user/manager_sqlite.go Normal file
View File

@@ -0,0 +1,295 @@
package user
import (
"database/sql"
"fmt"
"path/filepath"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
const (
// User queries
sqliteSelectUsersQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
ORDER BY
CASE u.role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, u.user
`
sqliteSelectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
sqliteSelectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
sqliteSelectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
sqliteSelectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
sqliteSelectUsernamesQuery = `
SELECT user
FROM user
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user
`
sqliteSelectUserCountQuery = `SELECT COUNT(*) FROM user`
sqliteSelectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
sqliteInsertUserQuery = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
sqliteUpdateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
sqliteUpdateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
sqliteUpdateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
sqliteUpdateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
sqliteUpdateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
sqliteUpdateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
sqliteUpdateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
sqliteUpdateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
sqliteDeleteUsersProvisionedQuery = `DELETE FROM user WHERE provisioned = 1`
// Access queries
sqliteSelectTopicPermsQuery = `
SELECT read, write
FROM user_access a
JOIN user u ON u.id = a.user_id
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
`
sqliteSelectUserAllAccessQuery = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
sqliteSelectUserAccessQuery = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
sqliteSelectUserReservationsQuery = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
WHERE a_user.user_id = a_user.owner_user_id
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY a_user.topic
`
sqliteSelectUserReservationsCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
`
sqliteSelectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = ?
AND user_id = owner_user_id
`
sqliteSelectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
AND topic = ?
`
sqliteSelectOtherAccessCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
`
sqliteUpsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
sqliteDeleteUserAccessQuery = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
sqliteDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
sqliteDeleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
`
sqliteDeleteAllAccessQuery = `DELETE FROM user_access`
// Token queries
sqliteSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
sqliteSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
sqliteSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
sqliteSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
sqliteUpsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
sqliteDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = 1`
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
sqliteDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = ?
AND (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = ?
ORDER BY expires DESC
LIMIT ?
)
`
// Tier queries
sqliteInsertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
sqliteUpdateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ?
`
sqliteSelectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
sqliteSelectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = ?
`
sqliteSelectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
`
sqliteDeleteTierQuery = `DELETE FROM tier WHERE code = ?`
// Phone queries
sqliteSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
sqliteInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
sqliteDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
// Billing queries
sqliteUpdateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
)
var sqliteQueries = queries{
selectUserByID: sqliteSelectUserByIDQuery,
selectUserByName: sqliteSelectUserByNameQuery,
selectUserByToken: sqliteSelectUserByTokenQuery,
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
selectUsernames: sqliteSelectUsernamesQuery,
selectUsers: sqliteSelectUsersQuery,
selectUserCount: sqliteSelectUserCountQuery,
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
insertUser: sqliteInsertUserQuery,
updateUserPass: sqliteUpdateUserPassQuery,
updateUserRole: sqliteUpdateUserRoleQuery,
updateUserProvisioned: sqliteUpdateUserProvisionedQuery,
updateUserPrefs: sqliteUpdateUserPrefsQuery,
updateUserStats: sqliteUpdateUserStatsQuery,
updateUserStatsResetAll: sqliteUpdateUserStatsResetAllQuery,
updateUserTier: sqliteUpdateUserTierQuery,
updateUserDeleted: sqliteUpdateUserDeletedQuery,
deleteUser: sqliteDeleteUserQuery,
deleteUserTier: sqliteDeleteUserTierQuery,
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
deleteUsersProvisioned: sqliteDeleteUsersProvisionedQuery,
selectTopicPerms: sqliteSelectTopicPermsQuery,
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
selectUserAccess: sqliteSelectUserAccessQuery,
selectUserReservations: sqliteSelectUserReservationsQuery,
selectUserReservationsCount: sqliteSelectUserReservationsCountQuery,
selectUserReservationsOwner: sqliteSelectUserReservationsOwnerQuery,
selectUserHasReservation: sqliteSelectUserHasReservationQuery,
selectOtherAccessCount: sqliteSelectOtherAccessCountQuery,
upsertUserAccess: sqliteUpsertUserAccessQuery,
deleteUserAccess: sqliteDeleteUserAccessQuery,
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisionedQuery,
deleteTopicAccess: sqliteDeleteTopicAccessQuery,
deleteAllAccess: sqliteDeleteAllAccessQuery,
selectToken: sqliteSelectTokenQuery,
selectTokens: sqliteSelectTokensQuery,
selectTokenCount: sqliteSelectTokenCountQuery,
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
upsertToken: sqliteUpsertTokenQuery,
updateToken: sqliteUpdateTokenQuery,
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
deleteToken: sqliteDeleteTokenQuery,
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
deleteAllProvisionedTokens: sqliteDeleteAllProvisionedTokensQuery,
deleteAllToken: sqliteDeleteAllTokenQuery,
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
insertTier: sqliteInsertTierQuery,
selectTiers: sqliteSelectTiersQuery,
selectTierByCode: sqliteSelectTierByCodeQuery,
selectTierByPriceID: sqliteSelectTierByPriceIDQuery,
updateTier: sqliteUpdateTierQuery,
deleteTier: sqliteDeleteTierQuery,
selectPhoneNumbers: sqliteSelectPhoneNumbersQuery,
insertPhoneNumber: sqliteInsertPhoneNumberQuery,
deletePhoneNumber: sqliteDeletePhoneNumberQuery,
updateBilling: sqliteUpdateBillingQuery,
}
// NewSQLiteManager creates a new Manager backed by a SQLite database
func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager, error) {
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
d, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(d); err != nil {
return nil, err
}
if err := runSQLiteStartupQueries(d, startupQueries); err != nil {
return nil, err
}
return newManager(db.New(&db.Host{DB: d}, nil), sqliteQueries, config)
}

View File

@@ -0,0 +1,465 @@
package user
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
// Initial SQLite schema
const (
sqliteCreateTablesQueries = `
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
)
const (
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
)
// Schema version table management for SQLite
const (
sqliteCurrentSchemaVersion = 6
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// Schema migrations for SQLite
const (
// 1 -> 2 (complex migration!)
sqliteMigrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery = `SELECT user FROM user_old`
sqliteMigrate1To2InsertUserNoTxQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
// 2 -> 3
sqliteMigrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
// 3 -> 4
sqliteMigrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
// 4 -> 5
sqliteMigrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
// 5 -> 6
sqliteMigrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
var (
sqliteMigrations = map[int]func(db *sql.DB) error{
1: sqliteMigrateFrom1,
2: sqliteMigrateFrom2,
3: sqliteMigrateFrom3,
4: sqliteMigrateFrom4,
5: sqliteMigrateFrom5,
}
)
func setupSQLite(db *sql.DB) error {
var schemaVersion int
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewSQLite(db)
}
if schemaVersion == sqliteCurrentSchemaVersion {
return nil
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
fn, ok := sqliteMigrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewSQLite(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteCreateTablesQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
return err
}
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}
func sqliteMigrateFrom1(sqlDB *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery)
if err != nil {
return err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTxQuery, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom2(sqlDB *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom3(sqlDB *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom4(sqlDB *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom5(sqlDB *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
return err
}
return nil
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,11 +2,12 @@ package user
import (
"errors"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
)
// User is a struct that represents a user
@@ -242,6 +243,20 @@ const (
everyoneID = "u_everyone"
)
// Config holds the configuration for the user Manager
type Config struct {
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
DatabaseURL string // Database connection string (PostgreSQL)
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
DefaultAccess Permission // Default permission if no ACL matches
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
Users []*User // Predefined users to create on startup
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
BcryptCost int // Cost of generated passwords; lowering makes testing faster
}
// Error constants used by the package
var (
ErrUnauthenticated = errors.New("unauthenticated")
@@ -259,3 +274,75 @@ var (
ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user")
ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token")
)
// queries holds the database-specific SQL queries
type queries struct {
// User queries
selectUserByID string
selectUserByName string
selectUserByToken string
selectUserByStripeCustomerID string
selectUsernames string
selectUsers string
selectUserCount string
selectUserIDFromUsername string
insertUser string
updateUserPass string
updateUserRole string
updateUserProvisioned string
updateUserPrefs string
updateUserStats string
updateUserStatsResetAll string
updateUserTier string
updateUserDeleted string
deleteUser string
deleteUserTier string
deleteUsersMarked string
deleteUsersProvisioned string
// Access queries
selectTopicPerms string
selectUserAllAccess string
selectUserAccess string
selectUserReservations string
selectUserReservationsCount string
selectUserReservationsOwner string
selectUserHasReservation string
selectOtherAccessCount string
upsertUserAccess string
deleteUserAccess string
deleteUserAccessProvisioned string
deleteTopicAccess string
deleteAllAccess string
// Token queries
selectToken string
selectTokens string
selectTokenCount string
selectAllProvisionedTokens string
upsertToken string
updateToken string
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
deleteAllProvisionedTokens string
deleteAllToken string
deleteExpiredTokens string
deleteExcessTokens string
// Tier queries
insertTier string
selectTiers string
selectTierByCode string
selectTierByPriceID string
updateTier string
deleteTier string
// Phone queries
selectPhoneNumbers string
insertPhoneNumber string
deletePhoneNumber string
// Billing queries
updateBilling string
}

View File

@@ -1,10 +1,12 @@
package user
import (
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/util"
"database/sql"
"regexp"
"strings"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/util"
)
var (
@@ -77,3 +79,37 @@ func hashPassword(password string, cost int) (string, error) {
}
return string(hash), nil
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
}
return sql.NullString{String: s, Valid: true}
}
func nullInt64(v int64) sql.NullInt64 {
if v == 0 {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: v, Valid: true}
}
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
// and escapes '_', assuming '\' as escape character.
func toSQLWildcard(s string) string {
return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
}
// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
// and removes the '\_' escape character.
func fromSQLWildcard(s string) string {
return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
}
func escapeUnderscore(s string) string {
return strings.ReplaceAll(s, "_", "\\_")
}
func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
}

281
user/util_test.go Normal file
View File

@@ -0,0 +1,281 @@
package user
import (
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestAllowedRole(t *testing.T) {
require.True(t, AllowedRole(RoleUser))
require.True(t, AllowedRole(RoleAdmin))
require.False(t, AllowedRole(RoleAnonymous))
require.False(t, AllowedRole(Role("invalid")))
require.False(t, AllowedRole(Role("")))
require.False(t, AllowedRole(Role("superadmin")))
}
func TestAllowedTopic(t *testing.T) {
// Valid topics
require.True(t, AllowedTopic("test"))
require.True(t, AllowedTopic("mytopic"))
require.True(t, AllowedTopic("topic123"))
require.True(t, AllowedTopic("my-topic"))
require.True(t, AllowedTopic("my_topic"))
require.True(t, AllowedTopic("Topic123"))
require.True(t, AllowedTopic("a"))
require.True(t, AllowedTopic(strings.Repeat("a", 64))) // Max length
// Invalid topics - wildcards not allowed
require.False(t, AllowedTopic("topic*"))
require.False(t, AllowedTopic("*"))
require.False(t, AllowedTopic("my*topic"))
// Invalid topics - special characters
require.False(t, AllowedTopic("my topic")) // Space
require.False(t, AllowedTopic("my.topic")) // Dot
require.False(t, AllowedTopic("my/topic")) // Slash
require.False(t, AllowedTopic("my@topic")) // At sign
require.False(t, AllowedTopic("my+topic")) // Plus
require.False(t, AllowedTopic("topic!")) // Exclamation
require.False(t, AllowedTopic("topic#")) // Hash
require.False(t, AllowedTopic("topic$")) // Dollar
require.False(t, AllowedTopic("topic%")) // Percent
require.False(t, AllowedTopic("topic&")) // Ampersand
require.False(t, AllowedTopic("my\\topic")) // Backslash
// Invalid topics - length
require.False(t, AllowedTopic("")) // Empty
require.False(t, AllowedTopic(strings.Repeat("a", 65))) // Too long
}
func TestAllowedTopicPattern(t *testing.T) {
// Valid patterns - same as AllowedTopic
require.True(t, AllowedTopicPattern("test"))
require.True(t, AllowedTopicPattern("mytopic"))
require.True(t, AllowedTopicPattern("topic123"))
require.True(t, AllowedTopicPattern("my-topic"))
require.True(t, AllowedTopicPattern("my_topic"))
require.True(t, AllowedTopicPattern("a"))
require.True(t, AllowedTopicPattern(strings.Repeat("a", 64))) // Max length
// Valid patterns - with wildcards
require.True(t, AllowedTopicPattern("*"))
require.True(t, AllowedTopicPattern("topic*"))
require.True(t, AllowedTopicPattern("*topic"))
require.True(t, AllowedTopicPattern("my*topic"))
require.True(t, AllowedTopicPattern("***"))
require.True(t, AllowedTopicPattern("test_*"))
require.True(t, AllowedTopicPattern("my-*-topic"))
require.True(t, AllowedTopicPattern(strings.Repeat("*", 64))) // Max length with wildcards
// Invalid patterns - special characters (other than wildcard)
require.False(t, AllowedTopicPattern("my topic")) // Space
require.False(t, AllowedTopicPattern("my.topic")) // Dot
require.False(t, AllowedTopicPattern("my/topic")) // Slash
require.False(t, AllowedTopicPattern("my@topic")) // At sign
require.False(t, AllowedTopicPattern("my+topic")) // Plus
require.False(t, AllowedTopicPattern("topic!")) // Exclamation
require.False(t, AllowedTopicPattern("topic#")) // Hash
require.False(t, AllowedTopicPattern("topic$")) // Dollar
require.False(t, AllowedTopicPattern("topic%")) // Percent
require.False(t, AllowedTopicPattern("topic&")) // Ampersand
require.False(t, AllowedTopicPattern("my\\topic")) // Backslash
// Invalid patterns - length
require.False(t, AllowedTopicPattern("")) // Empty
require.False(t, AllowedTopicPattern(strings.Repeat("a", 65))) // Too long
}
func TestValidPasswordHash(t *testing.T) {
// Valid bcrypt hashes with different versions
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
require.Nil(t, ValidPasswordHash("$2b$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", 10))
require.Nil(t, ValidPasswordHash("$2y$12$1234567890123456789012u1234567890123456789012345678901", 10))
// Valid hash with minimum cost
require.Nil(t, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 4))
// Invalid - wrong prefix
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$2c$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$3a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("bcrypt$10$hash", 10))
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("nothash", 10))
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("", 10))
// Invalid - malformed hash
require.NotNil(t, ValidPasswordHash("$2a$10$tooshort", 10))
require.NotNil(t, ValidPasswordHash("$2a$10", 10))
require.NotNil(t, ValidPasswordHash("$2a$", 10))
// Invalid - cost too low
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 10))
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$09$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
// Edge case - cost exactly at minimum
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
}
func TestValidToken(t *testing.T) {
// Valid tokens
require.True(t, ValidToken("tk_1234567890123456789012345678x"))
require.True(t, ValidToken("tk_abcdefghijklmnopqrstuvwxyzabc"))
require.True(t, ValidToken("tk_ABCDEFGHIJKLMNOPQRSTUVWXYZABC"))
require.True(t, ValidToken("tk_012345678901234567890123456ab"))
require.True(t, ValidToken("tk_-----------------------------"))
require.True(t, ValidToken("tk______________________________"))
// Invalid tokens - wrong prefix
require.False(t, ValidToken("tx_1234567890123456789012345678x"))
require.False(t, ValidToken("tk1234567890123456789012345678xy"))
require.False(t, ValidToken("token_1234567890123456789012345"))
// Invalid tokens - wrong length
require.False(t, ValidToken("tk_")) // Too short
require.False(t, ValidToken("tk_123")) // Too short
require.False(t, ValidToken("tk_123456789012345678901234567890")) // Too long (30 chars after prefix)
require.False(t, ValidToken("tk_123456789012345678901234567")) // Too short (28 chars)
// Invalid tokens - invalid characters
require.False(t, ValidToken("tk_123456789012345678901234567!@"))
require.False(t, ValidToken("tk_12345678901234567890123456 8x"))
require.False(t, ValidToken("tk_123456789012345678901234567.x"))
require.False(t, ValidToken("tk_123456789012345678901234567*x"))
// Invalid tokens - no prefix
require.False(t, ValidToken("1234567890123456789012345678901x"))
require.False(t, ValidToken(""))
}
func TestGenerateToken(t *testing.T) {
// Generate multiple tokens
tokens := make(map[string]bool)
for i := 0; i < 100; i++ {
token := GenerateToken()
// Check format
require.True(t, strings.HasPrefix(token, "tk_"), "Token should start with tk_")
require.Equal(t, 32, len(token), "Token should be 32 characters long")
// Check it's valid
require.True(t, ValidToken(token), "Generated token should be valid")
// Check it's lowercase
require.Equal(t, strings.ToLower(token), token, "Token should be lowercase")
// Check uniqueness
require.False(t, tokens[token], "Token should be unique")
tokens[token] = true
}
// Verify we got 100 unique tokens
require.Equal(t, 100, len(tokens))
}
func TestHashPassword(t *testing.T) {
password := "test-password-123"
// Hash the password
hash, err := HashPassword(password)
require.Nil(t, err)
require.NotEmpty(t, hash)
// Check it's a valid bcrypt hash
require.Nil(t, ValidPasswordHash(hash, DefaultUserPasswordBcryptCost))
// Check it starts with correct prefix
require.True(t, strings.HasPrefix(hash, "$2a$"))
// Hash the same password again - should produce different hash
hash2, err := HashPassword(password)
require.Nil(t, err)
require.NotEqual(t, hash, hash2, "Same password should produce different hashes (salt)")
// Empty password should still work
emptyHash, err := HashPassword("")
require.Nil(t, err)
require.NotEmpty(t, emptyHash)
require.Nil(t, ValidPasswordHash(emptyHash, DefaultUserPasswordBcryptCost))
}
func TestHashPassword_WithCost(t *testing.T) {
password := "test-password"
// Test with different costs
hash4, err := hashPassword(password, 4)
require.Nil(t, err)
require.True(t, strings.HasPrefix(hash4, "$2a$04$"))
hash10, err := hashPassword(password, 10)
require.Nil(t, err)
require.True(t, strings.HasPrefix(hash10, "$2a$10$"))
hash12, err := hashPassword(password, 12)
require.Nil(t, err)
require.True(t, strings.HasPrefix(hash12, "$2a$12$"))
// All should be valid
require.Nil(t, ValidPasswordHash(hash4, 4))
require.Nil(t, ValidPasswordHash(hash10, 10))
require.Nil(t, ValidPasswordHash(hash12, 12))
}
func TestUser_TierID(t *testing.T) {
// User with tier
u := &User{
Tier: &Tier{
ID: "ti_123",
Code: "pro",
},
}
require.Equal(t, "ti_123", u.TierID())
// User without tier
u2 := &User{
Tier: nil,
}
require.Equal(t, "", u2.TierID())
// Nil user
var u3 *User
require.Equal(t, "", u3.TierID())
}
func TestUser_IsAdmin(t *testing.T) {
admin := &User{Role: RoleAdmin}
require.True(t, admin.IsAdmin())
require.False(t, admin.IsUser())
user := &User{Role: RoleUser}
require.False(t, user.IsAdmin())
anonymous := &User{Role: RoleAnonymous}
require.False(t, anonymous.IsAdmin())
// Nil user
var nilUser *User
require.False(t, nilUser.IsAdmin())
}
func TestUser_IsUser(t *testing.T) {
user := &User{Role: RoleUser}
require.True(t, user.IsUser())
require.False(t, user.IsAdmin())
admin := &User{Role: RoleAdmin}
require.False(t, admin.IsUser())
anonymous := &User{Role: RoleAnonymous}
require.False(t, anonymous.IsUser())
// Nil user
var nilUser *User
require.False(t, nilUser.IsUser())
}
func TestPermission_String(t *testing.T) {
require.Equal(t, "read-write", PermissionReadWrite.String())
require.Equal(t, "read-only", PermissionRead.String())
require.Equal(t, "write-only", PermissionWrite.String())
require.Equal(t, "deny-all", PermissionDenyAll.String())
}

View File

@@ -17,6 +17,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
@@ -434,3 +435,22 @@ func Int(v int) *int {
func Time(v time.Time) *time.Time {
return &v
}
// SanitizeUTF8 ensures a string is safe to store in PostgreSQL by handling two cases:
//
// 1. Invalid UTF-8 sequences: Some clients send Latin-1/ISO-8859-1 encoded text (e.g. accented
// characters like é, ñ, ß) in HTTP headers or SMTP messages. Go treats these as raw bytes in
// strings, but PostgreSQL rejects them. Any invalid UTF-8 byte is replaced with the Unicode
// replacement character (U+FFFD, "<22>") so the message is still delivered rather than lost.
//
// 2. NUL bytes (0x00): These are valid in UTF-8 but PostgreSQL TEXT columns reject them.
// They are stripped entirely.
func SanitizeUTF8(s string) string {
if !utf8.ValidString(s) {
s = strings.ToValidUTF8(s, "\xef\xbf\xbd") // U+FFFD
}
if strings.ContainsRune(s, 0) {
s = strings.ReplaceAll(s, "\x00", "")
}
return s
}

Some files were not shown because too many files have changed in this diff Show More