Compare commits

...

76 Commits

Author SHA1 Message Date
b4b8bd925d v25.7.4
Some checks failed
Default (tags) / security (push) Successful in 12m5s
Default (tags) / test (push) Failing after 4m5s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-19 08:07:34 +00:00
5ac44b898b fix(smart-proxy): include proxy IPs in smart proxy configuration 2026-02-19 08:07:34 +00:00
9b4393b5ac v25.7.3
Some checks failed
Default (tags) / security (push) Successful in 33s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 14:35:26 +00:00
02b4ed8018 fix(metrics): centralize connection-closed reporting via ConnectionGuard and remove duplicate explicit metrics.connection_closed calls 2026-02-16 14:35:26 +00:00
e4e4b4f1ec v25.7.2
Some checks failed
Default (tags) / security (push) Successful in 33s
Default (tags) / test (push) Failing after 4m4s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 13:43:22 +00:00
d361a21543 fix(rustproxy-http): preserve original Host header when proxying and add X-Forwarded-* headers; add TLS WebSocket echo backend helper and integration test for terminate-and-reencrypt websocket 2026-02-16 13:43:22 +00:00
106713a546 v25.7.1
Some checks failed
Default (tags) / security (push) Successful in 34s
Default (tags) / test (push) Failing after 4m6s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 13:29:45 +00:00
101675b5f8 fix(proxy): use TLS to backends for terminate-and-reencrypt routes 2026-02-16 13:29:45 +00:00
9fac17bc39 v25.7.0
Some checks failed
Default (tags) / security (push) Successful in 30s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 12:11:49 +00:00
2e3cf515a4 feat(routes): add protocol-based route matching and ensure terminate-and-reencrypt routes HTTP through the full HTTP proxy; update docs and tests 2026-02-16 12:11:49 +00:00
754d32fd34 v25.6.0
Some checks failed
Default (tags) / security (push) Successful in 1m39s
Default (tags) / test (push) Failing after 5m7s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 12:02:36 +00:00
f0b7c27996 feat(rustproxy): add protocol-based routing and backend TLS re-encryption support 2026-02-16 12:02:36 +00:00
db932e8acc v25.5.0
Some checks failed
Default (tags) / security (push) Successful in 1m1s
Default (tags) / test (push) Failing after 5m5s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 03:00:39 +00:00
455d5bb757 feat(tls): add shared TLS acceptor with SNI resolver and session resumption; prefer shared acceptor and fall back to per-connection when routes specify custom TLS versions 2026-02-16 03:00:39 +00:00
fa2a27df6d v25.4.0
Some checks failed
Default (tags) / security (push) Successful in 30s
Default (tags) / test (push) Failing after 5m5s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-16 01:37:43 +00:00
7b2ccbdd11 feat(rustproxy): support dynamically loaded TLS certificates via loadCertificate IPC and include them in listener TLS configs for rebuilds and hot-swap 2026-02-16 01:37:43 +00:00
8409984fcc v25.3.1
Some checks failed
Default (tags) / security (push) Successful in 1m44s
Default (tags) / test (push) Failing after 5m8s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-15 15:05:03 +00:00
af10d189a3 fix(plugins): remove unused dependencies and simplify plugin exports 2026-02-15 15:05:03 +00:00
0b4d180cdf v25.3.0
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-14 14:02:25 +00:00
7b3545d1b5 feat(smart-proxy): add background concurrent certificate provisioning with per-domain timeouts and concurrency control 2026-02-14 14:02:25 +00:00
e837419d5d v25.2.2
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-14 12:42:20 +00:00
487a603fa3 fix(smart-proxy): start metrics polling before certificate provisioning to avoid blocking metrics collection 2026-02-14 12:42:20 +00:00
d6fdd3fc86 v25.2.1
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-14 12:28:42 +00:00
344f224c89 fix(smartproxy): no changes detected in git diff 2026-02-14 12:28:42 +00:00
6bbd2b3ee1 test(metrics): add v25.2.0 end-to-end assertions for per-IP, history, and HTTP request metrics 2026-02-14 12:24:48 +00:00
c44216df28 v25.2.0
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-14 11:15:17 +00:00
f80cdcf41c feat(metrics): add per-IP and HTTP-request metrics, propagate source IP through proxy paths, and expose new metrics to the TS adapter 2026-02-14 11:15:17 +00:00
6c84aedee1 v25.1.0
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-13 23:18:22 +00:00
1f95d2b6c4 feat(metrics): add real-time throughput sampling and byte-counting metrics 2026-02-13 23:18:22 +00:00
37372353d7 v25.0.0
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-13 21:24:16 +00:00
7afa4c4c58 BREAKING CHANGE(certs): accept a second eventComms argument in certProvisionFunction, add cert provisioning event types, and emit certificate lifecycle events 2026-02-13 21:24:16 +00:00
998662e137 v24.0.1
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-13 16:57:46 +00:00
a8f8946a4d fix(proxy): improve proxy robustness: add connect timeouts, graceful shutdown, WebSocket watchdog, and metrics guard 2026-02-13 16:57:46 +00:00
07e464fdac v24.0.0
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-13 16:32:02 +00:00
0e058594c9 BREAKING CHANGE(smart-proxy): move certificate persistence to an in-memory store and introduce consumer-managed certStore API; add default self-signed fallback cert and change ACME account handling 2026-02-13 16:32:02 +00:00
e0af82c1ef v23.1.6
Some checks failed
Default (tags) / security (push) Has been cancelled
Default (tags) / test (push) Has been cancelled
Default (tags) / release (push) Has been cancelled
Default (tags) / metadata (push) Has been cancelled
2026-02-13 13:08:30 +00:00
efe3d80713 fix(smart-proxy): disable built-in Rust ACME when a certProvisionFunction is provided and improve certificate provisioning flow 2026-02-13 13:08:30 +00:00
6b04bc612b v23.1.5
Some checks failed
Default (tags) / security (push) Successful in 38s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-13 12:02:47 +00:00
e774ec87ca fix(smart-proxy): provision certificates for wildcard domains instead of skipping them 2026-02-13 12:02:47 +00:00
cbde778f09 v23.1.4
Some checks failed
Default (tags) / security (push) Successful in 43s
Default (tags) / test (push) Failing after 4m6s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-12 22:35:25 +00:00
bc2bc874a5 fix(tests): make tests more robust and bump small dependencies 2026-02-12 22:35:25 +00:00
fdabf807b0 v23.1.3
Some checks failed
Default (tags) / security (push) Successful in 44s
Default (tags) / test (push) Failing after 4m6s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-12 20:17:32 +00:00
81e0e6b4d8 fix(rustproxy): install default rustls crypto provider early; detect and skip raw fast-path for HTTP connections and return proper HTTP 502 when no route matches 2026-02-12 20:17:32 +00:00
28fa69bf59 v23.1.2
Some checks failed
Default (tags) / security (push) Successful in 36s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-11 13:48:30 +00:00
5019658032 fix(core): use node: scoped builtin imports and add route unit tests 2026-02-11 13:48:30 +00:00
a9fe365c78 v23.1.1
Some checks failed
Default (tags) / security (push) Successful in 39s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-11 12:52:45 +00:00
32e0410227 fix(rust-proxy): increase rust proxy bridge maxPayloadSize to 100 MB and bump dependencies 2026-02-11 12:52:45 +00:00
fd56064495 v23.1.0
Some checks failed
Default (tags) / security (push) Successful in 35s
Default (tags) / test (push) Failing after 4m1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-10 09:43:40 +00:00
3b7e6a6ed7 feat(rust-bridge): integrate tsrust to build and locate cross-compiled Rust binaries; refactor rust-proxy bridge to use typed IPC and streamline process handling; add @push.rocks/smartrust and update build/dev dependencies 2026-02-10 09:43:40 +00:00
131ed8949a v23.0.0
Some checks failed
Default (tags) / security (push) Successful in 52s
Default (tags) / test (push) Failing after 48s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-09 17:11:37 +00:00
7b3009dc53 BREAKING CHANGE(proxies/nftables-proxy): remove nftables-proxy implementation, models, and utilities from the repository 2026-02-09 17:11:37 +00:00
db2e2fb76e v22.6.0
Some checks failed
Default (tags) / security (push) Successful in 40s
Default (tags) / test (push) Failing after 48s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-09 16:25:33 +00:00
f7605e042e feat(smart-proxy): add socket-handler relay, fast-path port-only forwarding, metrics and bridge improvements, and various TS/Rust integration fixes 2026-02-09 16:25:33 +00:00
41efdb47f8 v22.5.0
Some checks failed
Default (tags) / security (push) Successful in 38s
Default (tags) / test (push) Failing after 55s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-02-09 10:55:46 +00:00
1df3b7af4a feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates 2026-02-09 10:55:46 +00:00
a31fee41df v22.4.2
Some checks failed
Default (tags) / security (push) Successful in 39s
Default (tags) / test (push) Failing after 48s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-01-31 02:01:23 +00:00
9146d7c758 fix(tests): shorten long-lived connection test timeouts and update certificate metadata timestamps 2026-01-31 02:01:23 +00:00
fb0584e68d v22.4.1
Some checks failed
Default (tags) / security (push) Successful in 38s
Default (tags) / test (push) Failing after 49s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-01-30 19:52:36 +00:00
2068b7a1ad fix(smartproxy): improve certificate manager mocking in tests, enhance IPv6 validation, and record initial bytes for connection metrics 2026-01-30 19:52:36 +00:00
1d1e5062a6 v22.4.0
Some checks failed
Default (tags) / security (push) Successful in 38s
Default (tags) / test (push) Failing after 47s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-01-30 10:44:28 +00:00
c2dd7494d6 feat(smart-proxy): calculate when SNI is required for TLS routing and allow session tickets for single-target passthrough routes; add tests, docs, and npm metadata updates 2026-01-30 10:44:28 +00:00
ea3b8290d2 v22.3.0
Some checks failed
Default (tags) / security (push) Successful in 38s
Default (tags) / test (push) Failing after 47s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-01-30 09:22:41 +00:00
9b1adb1d7a feat(docs): update README with installation, improved feature table, expanded quick-start, ACME/email example, API options interface, and clarified licensing/trademark text 2026-01-30 09:22:41 +00:00
90e8f92e86 v22.2.0
Some checks failed
Default (tags) / security (push) Successful in 41s
Default (tags) / test (push) Failing after 49s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-01-30 04:06:32 +00:00
9697ab3078 feat(proxies): introduce nftables command executor and utilities, default certificate provider, expanded route/socket helper modules, and security improvements 2026-01-30 04:06:32 +00:00
f25be4c55a v22.1.1
Some checks failed
Default (tags) / security (push) Successful in 43s
Default (tags) / test (push) Failing after 49s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2025-12-09 21:39:49 +00:00
05c5635a13 fix(tests): Normalize route configurations in tests to use name (remove id) and standardize route names 2025-12-09 21:39:49 +00:00
788fdd79c5 v22.1.0
Some checks failed
Default (tags) / security (push) Successful in 44s
Default (tags) / test (push) Failing after 49s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2025-12-09 13:07:29 +00:00
9c25bf0a27 feat(smart-proxy): Improve connection/rate-limit atomicity, SNI parsing, HttpProxy & ACME orchestration, and routing utilities 2025-12-09 13:07:29 +00:00
a0b23a8e7e v22.0.0
Some checks failed
Default (tags) / security (push) Successful in 49s
Default (tags) / test (push) Failing after 1m7s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2025-12-09 09:33:51 +00:00
c4b9d7eb72 BREAKING CHANGE(smart-proxy/utils/route-validator): Consolidate and refactor route validators; move to class-based API and update usages
Replaced legacy route-validators.ts with a unified route-validator.ts that provides a class-based RouteValidator plus the previous functional API (isValidPort, isValidDomain, validateRouteMatch, validateRouteAction, validateRouteConfig, validateRoutes, hasRequiredPropertiesForAction, assertValidRoute) for backwards compatibility. Updated utils exports and all imports/tests to reference the new module. Also switched static file loading in certificate manager to use SmartFileFactory.nodeFs(), and added @push.rocks/smartserve to devDependencies.
2025-12-09 09:33:50 +00:00
be3ac75422 fix some tests and prepare next step of evolution 2025-12-09 09:19:13 +00:00
ad44274075 21.1.7
Some checks failed
Default (tags) / security (push) Successful in 55s
Default (tags) / test (push) Failing after 46m17s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2025-08-19 13:58:22 +00:00
3efd9c72ba fix(route-validator): Relax domain validation to accept localhost, prefix wildcards (e.g. *example.com) and IP literals; add comprehensive domain validation tests 2025-08-19 13:58:22 +00:00
b96e0cd48e 21.1.6
Some checks failed
Default (tags) / security (push) Successful in 57s
Default (tags) / test (push) Failing after 46m14s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2025-08-19 11:38:20 +00:00
c909d3db3e fix(ip-utils): Fix IP wildcard/shorthand handling and add validation test 2025-08-19 11:38:20 +00:00
219 changed files with 34054 additions and 28513 deletions

3
.gitignore vendored
View File

@@ -17,4 +17,5 @@ dist/
dist_*/
#------# custom
.claude/*
.claude/*
rust/target

View File

@@ -1,68 +0,0 @@
# language of the project (csharp, python, rust, java, typescript, go, cpp, or ruby)
# * For C, use cpp
# * For JavaScript, use typescript
# Special requirements:
# * csharp: Requires the presence of a .sln file in the project folder.
language: typescript
# whether to use the project's gitignore file to ignore files
# Added on 2025-04-07
ignore_all_files_in_gitignore: true
# list of additional paths to ignore
# same syntax as gitignore, so you can use * and **
# Was previously called `ignored_dirs`, please update your config if you are using that.
# Added (renamed)on 2025-04-07
ignored_paths: []
# whether the project is in read-only mode
# If set to true, all editing tools will be disabled and attempts to use them will result in an error
# Added on 2025-04-18
read_only: false
# list of tool names to exclude. We recommend not excluding any tools, see the readme for more details.
# Below is the complete list of tools for convenience.
# To make sure you have the latest list of tools, and to view their descriptions,
# execute `uv run scripts/print_tool_overview.py`.
#
# * `activate_project`: Activates a project by name.
# * `check_onboarding_performed`: Checks whether project onboarding was already performed.
# * `create_text_file`: Creates/overwrites a file in the project directory.
# * `delete_lines`: Deletes a range of lines within a file.
# * `delete_memory`: Deletes a memory from Serena's project-specific memory store.
# * `execute_shell_command`: Executes a shell command.
# * `find_referencing_code_snippets`: Finds code snippets in which the symbol at the given location is referenced.
# * `find_referencing_symbols`: Finds symbols that reference the symbol at the given location (optionally filtered by type).
# * `find_symbol`: Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type).
# * `get_current_config`: Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes.
# * `get_symbols_overview`: Gets an overview of the top-level symbols defined in a given file.
# * `initial_instructions`: Gets the initial instructions for the current project.
# Should only be used in settings where the system prompt cannot be set,
# e.g. in clients you have no control over, like Claude Desktop.
# * `insert_after_symbol`: Inserts content after the end of the definition of a given symbol.
# * `insert_at_line`: Inserts content at a given line in a file.
# * `insert_before_symbol`: Inserts content before the beginning of the definition of a given symbol.
# * `list_dir`: Lists files and directories in the given directory (optionally with recursion).
# * `list_memories`: Lists memories in Serena's project-specific memory store.
# * `onboarding`: Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building).
# * `prepare_for_new_conversation`: Provides instructions for preparing for a new conversation (in order to continue with the necessary context).
# * `read_file`: Reads a file within the project directory.
# * `read_memory`: Reads the memory with the given name from Serena's project-specific memory store.
# * `remove_project`: Removes a project from the Serena configuration.
# * `replace_lines`: Replaces a range of lines within a file with new content.
# * `replace_symbol_body`: Replaces the full definition of a symbol.
# * `restart_language_server`: Restarts the language server, may be necessary when edits not through Serena happen.
# * `search_for_pattern`: Performs a search for a pattern in the project.
# * `summarize_changes`: Provides instructions for summarizing the changes made to the codebase.
# * `switch_modes`: Activates modes by providing a list of their names
# * `think_about_collected_information`: Thinking tool for pondering the completeness of collected information.
# * `think_about_task_adherence`: Thinking tool for determining whether the agent is still on track with the current task.
# * `think_about_whether_you_are_done`: Thinking tool for determining whether the task is truly completed.
# * `write_memory`: Writes a named memory (for future reference) to Serena's project-specific memory store.
excluded_tools: []
# initial prompt for the project. It will always be given to the LLM upon activating the project
# (contrary to the memories, which are loaded on demand).
initial_prompt: ""
project_name: "smartproxy"

View File

@@ -1,19 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDCzCCAfOgAwIBAgIUPU4tviz3ZvsMDjCz1NZRT16b0Y4wDQYJKoZIhvcNAQEL
BQAwFTETMBEGA1UEAwwKcHVzaC5yb2NrczAeFw0yNTAyMDMyMzA5MzRaFw0yNjAy
MDMyMzA5MzRaMBUxEzARBgNVBAMMCnB1c2gucm9ja3MwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQCZMkBYD/pYLBv9MiyHTLRT24kQyPeJBtZqryibi1jk
BT1ZgNl3yo5U6kjj/nYBU/oy7M4OFC0xyaJQ4wpvLHu7xzREqwT9N9WcDcxaahUi
P8+PsjGyznPrtXa1ASzGAYMNvXyWWp3351UWZHMEs6eY/Y7i8m4+0NwP5h8RNBCF
KSFS41Ee9rNAMCnQSHZv1vIzKeVYPmYnCVmL7X2kQb+gS6Rvq5sEGLLKMC5QtTwI
rdkPGpx4xZirIyf8KANbt0sShwUDpiCSuOCtpze08jMzoHLG9Nv97cJQjb/BhiES
hLL+YjfAUFjq0rQ38zFKLJ87QB9Jym05mY6IadGQLXVXAgMBAAGjUzBRMB0GA1Ud
DgQWBBQjpowWjrql/Eo2EVjl29xcjuCgkTAfBgNVHSMEGDAWgBQjpowWjrql/Eo2
EVjl29xcjuCgkTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAY
44vqbaf6ewFrZC0f3Kk4A10lC6qjWkcDFfw+JE8nzt+4+xPqp1eWgZKF2rONyAv2
nG41Xygt19ByancXLU44KB24LX8F1GV5Oo7CGBA+xtoSPc0JulXw9fGclZDC6XiR
P/+vhGgCHicbfP2O+N00pOifrTtf2tmOT4iPXRRo4TxmPzuCd+ZJTlBhPlKCmICq
yGdAiEo6HsSiP+M5qVlNx8s57MhQYk5TpgmI6FU4mO7zfDfSatFonlg+aDbrnaqJ
v/+km02M+oB460GmKwsSTnThHZgLNCLiKqD8bdziiCQjx5u0GjLI6468o+Aehb8l
l/x9vWTTk/QKq41X5hFk
MIIDQTCCAimgAwIBAgIUJm+igT1AVSuwNzjvqjSF6cysw6MwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDIxMzIyMzI1MloXDTM2MDIx
MTIyMzI1MlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEAyjitkDd4DdlVk4TfVxKUqdxnJCGj9uyrUPAqR8hB6+bR
+8rW63ryBNYNRizxOGw41E19ascNuyA98mUF4oqjid1K4VqDiKzv1Uq/3NUunCw/
rEddR5hCoVkTsBJjzNgBJqncS606v0hfA00cCkpGR+Te7Q/E8T8lApioz1zFQ05Y
C69oeJHIrJcrIkIFAgmXDgRF0Z4ErUeu+wVOWT267uVAYn5AdFMxCSIBsYtPseqy
cC5EQ6BCBtsIGitlRgzLRg957ZZa+SF38ao+/ijYmOLHpQT0mFaUyLT7BKgxguGs
8CHcIxN5Qo27J3sC5ymnrv2uk5DcAOUcxklXUbVCeQIDAQABo4GKMIGHMB0GA1Ud
DgQWBBShZhz7aX/KhleAfYKvTgyG5ANuDjAfBgNVHSMEGDAWgBShZhz7aX/KhleA
fYKvTgyG5ANuDjAPBgNVHRMBAf8EBTADAQH/MDQGA1UdEQQtMCuCCWxvY2FsaG9z
dIIKcHVzaC5yb2Nrc4IMKi5wdXNoLnJvY2tzhwR/AAABMA0GCSqGSIb3DQEBCwUA
A4IBAQAyUvjUszQp4riqa3CfBFFtjh+7DKNuQPOlYAwSEis4l+YK06Glx4fJBHcx
eCPhQ/0wnPzi6CZe3vVRXd5fX27nVs+lMQD6Oc47F8OmTU6NXnb/1AcvrycDsP8D
9Y9qecekbpegrN1W4D46goBAwvrd6Qy0EHi0Z5z02rfyXAdxm0OmdpuWoIMcEgUQ
YyXIq3zSFE6uoO61WdLvBcXN6iaiSTVy0605WncDe2+UT9MeNq6zi1JD34jsgUrd
xq0WRUk2C6C4Irkf00Q12rXeL+Jv5OwyrUUZRvz0gLgG02UUbB/6Ca5GYNXniEuI
Py/EHTqbtjLIs7HxYjQH86FI9fUj
-----END CERTIFICATE-----

View File

@@ -1,28 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCZMkBYD/pYLBv9
MiyHTLRT24kQyPeJBtZqryibi1jkBT1ZgNl3yo5U6kjj/nYBU/oy7M4OFC0xyaJQ
4wpvLHu7xzREqwT9N9WcDcxaahUiP8+PsjGyznPrtXa1ASzGAYMNvXyWWp3351UW
ZHMEs6eY/Y7i8m4+0NwP5h8RNBCFKSFS41Ee9rNAMCnQSHZv1vIzKeVYPmYnCVmL
7X2kQb+gS6Rvq5sEGLLKMC5QtTwIrdkPGpx4xZirIyf8KANbt0sShwUDpiCSuOCt
pze08jMzoHLG9Nv97cJQjb/BhiEShLL+YjfAUFjq0rQ38zFKLJ87QB9Jym05mY6I
adGQLXVXAgMBAAECggEARGCBBq1PBHbfoUH5TQSIAlvdEEBa9+602lZG7jIioVfT
W7Uem5Ctuan+kcDcY9hbNsqqZ+9KgsvoJmlIGXoF2jjeE/4vUmRO9AHWoc5yk2Be
4NjcxN3QMLdEfiLBnLlFCOd4CdX1ZxZ6TG3WRpV3a1pVIeeqHGB1sKT6Xd/atcwG
RvpiXzu0SutGxVb6WE9r6hovZ4fVERCyCRczUGrUH5ICbxf6E7L4u8xjEYR4uEKK
/8ZkDqrWdRASDAdPPMNqnHUEAho/WnxpNeb6B4lvvv2QWxIS9H1OikF/NzWPgVNS
oPpvtJgjyo5xdgLm3zE4lcSPNVSrh1TBXuAn9TG4WQKBgQDScPFkUNBqjC5iPMof
bqDHlhlptrHmiv9LC0lgjEDPgIEQfjLfdCugwDk32QyAcb5B60upDYeqCFDkfV/C
T536qxevYPjPAjahLPHqMxkWpjvtY6NOTgbbcpVtblU2Fj8R8qbyPNADG31LicU9
GVPtQ4YcVaMWCYbg5107+9dFWQKBgQC6XK+foKK+81RFdrqaNNgebTWTsANnBcZe
xl0bj6oL5yY0IzroxHvgcNS7UMriWCu+K2xfkUBdMmxU773VN5JQ5k15ezjgtrvc
8oAaEsxYP4su12JSTC/zsBANUgrNbFj8++qqKYWt2aQc2O/kbZ4MNfekIVFc8AjM
2X9PxvxKLwKBgHXL7QO3TQLnVyt8VbQEjBFMzwriznB7i+4o8jkOKVU93IEr8zQr
5iQElcLSR3I6uUJTALYvsaoXH5jXKVwujwL69LYiNQRDe+r6qqvrUHbiNJdsd8Rk
XuhGGqj34tD04Pcd+h+MtO+YWqmHBBZwcA9XBeIkebbjPFH2kLT8AwN5AoGAYQy9
hMJxnkE3hIkk+gNE/OtgeE20J+Vw/ZANkrnJEzPHyGUEW41e+W2oyvdzAFZsSTdx
037f5ujIU58Z27x53NliRT4vS4693H0Iyws5EUfeIoGVuUflvODWKymraHjhCrXh
6cV/0R5DAabTnsCbCr7b/MRBC8YQvyUQ0KnOXo8CgYBQYGpvJnSWyvsCjtb6apTP
drjcBhVd0aSBpLGtDdtUCV4oLl9HPy+cLzcGaqckBqCwEq5DKruhMEf7on56bUMd
m/3ItFk1TnhysAeJHb3zLqmJ9CKBitpqLlsOE7MEXVNmbTYeXU10Uo9yOfyt1i7T
su+nT5VtyPkmF/l4wZl5+g==
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDKOK2QN3gN2VWT
hN9XEpSp3GckIaP27KtQ8CpHyEHr5tH7ytbrevIE1g1GLPE4bDjUTX1qxw27ID3y
ZQXiiqOJ3UrhWoOIrO/VSr/c1S6cLD+sR11HmEKhWROwEmPM2AEmqdxLrTq/SF8D
TRwKSkZH5N7tD8TxPyUCmKjPXMVDTlgLr2h4kcislysiQgUCCZcOBEXRngStR677
BU5ZPbru5UBifkB0UzEJIgGxi0+x6rJwLkRDoEIG2wgaK2VGDMtGD3ntllr5IXfx
qj7+KNiY4selBPSYVpTItPsEqDGC4azwIdwjE3lCjbsnewLnKaeu/a6TkNwA5RzG
SVdRtUJ5AgMBAAECggEAEM8piTp9I5yQVxr1RCv+mMiB99BGjHrTij6uawlXxnPJ
Ol574zbU6Yc/8vh/JB8l0arvzQmHCAoX8B9K4BABZE81X1paYujqJi8ImAMN9Owe
LlQ/yhjbWAVbJDiBHHjLjrLRpaY8gwQxZqk5FpdiNG1vROIZzypeKZM2PAdke9HA
PvJtsyfXdEz+jb5EUgaadyn7aquR6y607a8m55y34POLOcssteUOje4GdrTekHK0
62E+iEnawBjIs7gBzJf0j1XjFNq3aAeLrn8gFCEb+yK7X++8FJ8YjwsqS5V1aMsR
1PZguW0jCzYHATc2OcIozlvdBriPxy7eX8Y3MFvNMQKBgQD22ReUyKX5TKA/fg3z
S/QGfYqd4T35jkwK1MaXOuFOBzNyTMT6ZJkbjxPOYPB0uakcfIlva8bI77mE5vRe
PWYlvitp9Zz3v2kt/WgnwG32ZdVedPjEoi9aitUXmiiIoxdPVAUAgLPFFN65Sr2G
2NM/vduZcAPUr0UWnFx4dlpo8QKBgQDRuAV44Y+1396511oW4OR8ofWFECMc5uEV
wQ26EpbilEYhRSBty+1PAK5AcEGybeVtUn9RSmx0Ef1L15wnzP/C886LIzkaig/9
xs0yudXgOFdBAzYQKnK2lZmSKkzcUFJtifat3E+ZMCo/duhzXpzecg/lVNGh6gcx
xbtphJCyCQKBgEO8zvvFE8aVgGPr82gQL6aYTLGGXbtdkQBn4xcc0TbYQwXaizMq
59joKkc30sQ1LnLiudQZfzMklYQi3Gv/7UfuJ3usKqbRn8s+/pXp+ELlLuf8sUdE
OjpeXptbckQMfRkHtVet+abbU0MFf3zBgza6osg4NNToQ80wmy9zStwBAoGAGLeD
jZeoBFt6OJT0/TVMOJQuB5y7RrC/Xnz+TSvbtKCdE1a+V7JtKZ5+6wFP/OOO4q+S
adZHqfZk0Ad9VAOJMUTi1usz07jp4ZMIpC3a0y5QukzSll0qX/KJwvxRSrX8wQQ9
mogYqYlPsWMmSlKgUmdHEFRK0LZwWqFfUTRaiWECgYEA6KR6KMbqnYn5CglHeD42
NmOgFYXljRLIxS1coTiWWQZM/nUyx/tSk+MAS7770USHoZhAfh6lmEa/AeKSoLVl
Su3yzgtKk1/doAtbiWD8TasHAhacwWmzTuZtH5cZUgW3QIVJg6ADi6m8zswqKxIS
qfU/1N4aHp832v4ggRe/Og0=
-----END PRIVATE KEY-----

View File

@@ -1,5 +1,5 @@
{
"expiryDate": "2025-11-12T14:20:10.043Z",
"issueDate": "2025-08-14T14:20:10.043Z",
"savedAt": "2025-08-14T14:20:10.044Z"
"expiryDate": "2026-05-01T01:40:34.253Z",
"issueDate": "2026-01-31T01:40:34.253Z",
"savedAt": "2026-01-31T01:40:34.253Z"
}

View File

@@ -1,5 +1,348 @@
# Changelog
## 2026-02-19 - 25.7.4 - fix(smart-proxy)
include proxy IPs in smart proxy configuration
- Add proxyIps: this.settings.proxyIPs to proxy options in ts/proxies/smart-proxy/smart-proxy.ts
- Ensures proxy IPs from settings are passed into the proxy implementation (enables proxy IP filtering/whitelisting)
## 2026-02-16 - 25.7.3 - fix(metrics)
centralize connection-closed reporting via ConnectionGuard and remove duplicate explicit metrics.connection_closed calls
- Removed numerous explicit metrics.connection_closed calls from rust/crates/rustproxy-http/src/proxy_service.rs so connection teardown and byte counting are handled by the connection guard / counting body instead of ad-hoc calls.
- Simplified ConnectionGuard in rust/crates/rustproxy-passthrough/src/tcp_listener.rs: removed the disarm flag and disarm() method so Drop always reports connection_closed.
- Stopped disarming the TCP-level guard when handing connections off to HTTP proxy paths (HTTP/WebSocket/streaming flows) to avoid missing or double-reporting metrics.
- Fixes incorrect/duplicate connection-closed metric emission and ensures consistent byte/connection accounting during streaming and WebSocket upgrades.
## 2026-02-16 - 25.7.2 - fix(rustproxy-http)
preserve original Host header when proxying and add X-Forwarded-* headers; add TLS WebSocket echo backend helper and integration test for terminate-and-reencrypt websocket
- Preserve the client's original Host header instead of replacing it with backend host:port when proxying requests.
- Add standard reverse-proxy headers: X-Forwarded-For (appends client IP), X-Forwarded-Host, and X-Forwarded-Proto for upstream requests.
- Ensure raw TCP/HTTP upstream requests copy original headers and skip X-Forwarded-* (which are added explicitly).
- Add start_tls_ws_echo_backend test helper to start a TLS WebSocket echo backend for tests.
- Add integration test test_terminate_and_reencrypt_websocket to verify WS upgrade through terminate-and-reencrypt TLS path.
- Rename unused parameter upstream to _upstream in proxy_service functions to avoid warnings.
## 2026-02-16 - 25.7.1 - fix(proxy)
use TLS to backends for terminate-and-reencrypt routes
- Set upstream.use_tls = true when a route's TLS mode is TerminateAndReencrypt so the proxy re-encrypts to backend servers.
- Add start_tls_http_backend test helper and update integration tests to run TLS-enabled backend servers validating re-encryption behavior.
- Make the selected upstream mutable to allow toggling the use_tls flag during request handling.
## 2026-02-16 - 25.7.0 - feat(routes)
add protocol-based route matching and ensure terminate-and-reencrypt routes HTTP through the full HTTP proxy; update docs and tests
- Introduce a new 'protocol' match field for routes (supports 'http' and 'tcp') and preserve it through cloning/merging.
- Add Rust integration test verifying terminate-and-reencrypt decrypts TLS and routes HTTP traffic via the HTTP proxy (per-request Host/path routing) instead of raw tunneling.
- Add TypeScript unit tests covering protocol field validation, preservation, interaction with terminate-and-reencrypt, cloning, merging, and matching behavior.
- Update README with a Protocol-Specific Routing section and clarify terminate-and-reencrypt behavior (HTTP routed via HTTP proxy; non-HTTP uses raw TLS-to-TLS tunnel).
- Example config: include health check thresholds (unhealthyThreshold and healthyThreshold) in the sample healthCheck settings.
## 2026-02-16 - 25.6.0 - feat(rustproxy)
add protocol-based routing and backend TLS re-encryption support
- Introduce optional route_match.protocol ("http" | "tcp") in Rust and TypeScript route types to allow protocol-restricted routing.
- RouteManager: respect protocol field during matching and treat TLS connections without SNI as not matching domain-restricted routes (except wildcard-only routes).
- HTTP proxy: add BackendStream abstraction to unify plain TCP and tokio-rustls TLS backend streams, and support connecting to upstreams over TLS (upstream.use_tls) with an InsecureBackendVerifier for internal/self-signed backends.
- WebSocket and HTTP forwarding updated to use BackendStream so upstream TLS is handled transparently.
- Passthrough listener: perform post-termination protocol detection for TerminateAndReencrypt; route HTTP flows into HttpProxyService and handle non-HTTP as TLS-to-TLS tunnel.
- Add tests for protocol matching, TLS/no-SNI behavior, and other routing edge cases.
- Add rustls and tokio-rustls dependencies (Cargo.toml/Cargo.lock updates).
## 2026-02-16 - 25.5.0 - feat(tls)
add shared TLS acceptor with SNI resolver and session resumption; prefer shared acceptor and fall back to per-connection when routes specify custom TLS versions
- Add CertResolver that pre-parses PEM certs/keys into CertifiedKey instances for SNI-based lookup and cheap runtime resolution
- Introduce build_shared_tls_acceptor to create a shared ServerConfig with session cache (4096) and Ticketer for session ticket resumption
- Add ArcSwap<Option<TlsAcceptor>> shared_tls_acceptor to tcp_listener for hot-reloadable, pre-built acceptor and update accept loop/handlers to use it
- set_tls_configs now attempts to build and store the shared TLS acceptor, falling back to per-connection acceptors on failure; raw PEM configs are still retained for route-level fallbacks
- Add get_tls_acceptor helper: prefer shared acceptor for performance and session resumption, but build per-connection acceptor when a route requests custom TLS versions
## 2026-02-16 - 25.4.0 - feat(rustproxy)
support dynamically loaded TLS certificates via loadCertificate IPC and include them in listener TLS configs for rebuilds and hot-swap
- Adds loaded_certs: HashMap<String, TlsCertConfig> to RustProxy to store certificates loaded at runtime
- Merge loaded_certs into tls_configs in rebuild and listener hot-swap paths so dynamically loaded certs are served immediately
- Persist loaded certificates on loadCertificate so future rebuilds include them
## 2026-02-15 - 25.3.1 - fix(plugins)
remove unused dependencies and simplify plugin exports
- Removed multiple dependencies from package.json to reduce dependency footprint: @push.rocks/lik, @push.rocks/smartacme, @push.rocks/smartdelay, @push.rocks/smartfile, @push.rocks/smartnetwork, @push.rocks/smartpromise, @push.rocks/smartrequest, @push.rocks/smartrx, @push.rocks/smartstring, @push.rocks/taskbuffer, @types/minimatch, @types/ws, pretty-ms, ws
- ts/plugins.ts: stopped importing/exporting node:https and many push.rocks and third-party modules; plugins now only re-export core node modules (without https), tsclass, smartcrypto, smartlog (+ destination-local), smartrust, and minimatch
- Intended effect: trim surface area and remove unused/optional integrations; patch-level change (no feature/API additions)
## 2026-02-14 - 25.3.0 - feat(smart-proxy)
add background concurrent certificate provisioning with per-domain timeouts and concurrency control
- Add ISmartProxyOptions settings: certProvisionTimeout (ms) and certProvisionConcurrency (default 4)
- Run certProvisionFunction as fire-and-forget background tasks (stores promise on start/route-update and awaited on stop)
- Provision certificates in parallel with a concurrency limit using a new ConcurrencySemaphore utility
- Introduce per-domain timeout handling (default 300000ms) via withTimeout and surface timeout errors as certificate-failed events
- Refactor provisioning into provisionSingleDomain to isolate domain handling, ACME fallback preserved
- Run provisioning outside route update mutex so route updates are not blocked by slow provisioning
## 2026-02-14 - 25.2.2 - fix(smart-proxy)
start metrics polling before certificate provisioning to avoid blocking metrics collection
- Start metrics polling immediately after Rust engine startup so metrics are available without waiting for certificate provisioning.
- Run certProvisionFunction after startup because ACME/DNS-01 provisioning can hang or be slow and must not block observability.
- Code change in ts/proxies/smart-proxy/smart-proxy.ts: metricsAdapter.startPolling() moved to run before provisionCertificatesViaCallback().
## 2026-02-14 - 25.2.1 - fix(smartproxy)
no changes detected in git diff
- The provided diff contains no file changes; no code or documentation updates to release.
## 2026-02-14 - 25.2.0 - feat(metrics)
add per-IP and HTTP-request metrics, propagate source IP through proxy paths, and expose new metrics to the TS adapter
- Add per-IP tracking and IpMetrics in MetricsCollector (active/total connections, bytes, throughput).
- Add HTTP request counters and tracking (total_http_requests, http_requests_per_sec, recent counters and tests).
- Include throughput history (ThroughputSample serialization, retention and snapshotting) and expose history in snapshots.
- Propagate source IP through HTTP and passthrough code paths: CountingBody.record_bytes and MetricsCollector methods now accept source_ip; connection_opened/closed calls include source IP.
- Introduce ForwardMetricsCtx to carry metrics context (collector, route_id, source_ip) into passthrough forwarding routines; update ConnectionGuard to include source_ip.
- TypeScript adapter (rust-metrics-adapter.ts) updated to return per-IP counts, top IPs, per-IP throughput, throughput history mapping, and HTTP request rates/total where available.
- Numerous unit tests added for per-IP tracking, HTTP request tracking, throughput history and ThroughputTracker.history behavior.
## 2026-02-13 - 25.1.0 - feat(metrics)
add real-time throughput sampling and byte-counting metrics
- Add CountingBody wrapper to count HTTP request and response bytes and report them to MetricsCollector.
- Implement lock-free hot-path byte recording and a cold-path sampling API (sample_all) in MetricsCollector with throughput history and configurable retention (default 3600s).
- Spawn a background sampling task in RustProxy (configurable sample_interval_ms) and tear it down on stop so throughput trackers are regularly sampled.
- Instrument passthrough TCP forwarding and socket-relay paths to record per-chunk bytes (lock-free) so long-lived connections contribute to throughput measurements.
- Wrap HTTP request/response bodies with CountingBody in proxy_service to capture bytes_in/bytes_out and report on body completion; connection_closed handling updated accordingly.
- Expose recent throughput metrics to the TypeScript adapter (throughputRecentIn/Out) and pass metrics settings from the TS SmartProxy into Rust.
- Add http-body dependency and update Cargo.toml/Cargo.lock entries for the new body wrapper usage.
- Add unit tests for MetricsCollector throughput tracking and a new end-to-end throughput test (test.throughput.ts).
- Update test certificates (assets/certs cert.pem and key.pem) used by TLS tests.
## 2026-02-13 - 25.0.0 - BREAKING CHANGE(certs)
accept a second eventComms argument in certProvisionFunction, add cert provisioning event types, and emit certificate lifecycle events
- Breaking API change: certProvisionFunction signature changed from (domain: string) => Promise<TSmartProxyCertProvisionObject> to (domain: string, eventComms: ICertProvisionEventComms) => Promise<TSmartProxyCertProvisionObject>. Custom provisioners must accept (or safely ignore) the new second argument.
- New types added and exported: ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent.
- smart-proxy now constructs an eventComms channel that allows provisioners to log/warn/error and set expiry date and source for the issued event.
- Emits 'certificate-issued' (domain, expiryDate, source, isRenewal?) on successful provisioning and 'certificate-failed' (domain, error, source) on failures.
- Updated public exports to include the new types so they are available to consumers.
- Removed readme.byte-counting-audit.md (documentation file deleted).
## 2026-02-13 - 24.0.1 - fix(proxy)
improve proxy robustness: add connect timeouts, graceful shutdown, WebSocket watchdog, and metrics guard
- Add tokio-util CancellationToken to HTTP handlers to support graceful shutdown (stop accepting new requests while letting in-flight requests finish).
- Introduce configurable upstream connect timeout (DEFAULT_CONNECT_TIMEOUT) and return 504 Gateway Timeout on connect timeouts to avoid hanging connections.
- Add WebSocket watchdog with inactivity and max-lifetime checks, activity tracking via AtomicU64, and cancellation-driven tunnel aborts.
- Add ConnectionGuard RAII in passthrough listener to ensure metrics.connection_closed() is called on all exit paths and disarm the guard when handing off to the HTTP proxy.
- Expose HttpProxyService::with_connect_timeout and wire connection timeout from ConnectionConfig into listeners.
- Add tokio-util workspace dependency (CancellationToken) and related code changes across rustproxy-http and rustproxy-passthrough.
## 2026-02-13 - 24.0.0 - BREAKING CHANGE(smart-proxy)
move certificate persistence to an in-memory store and introduce consumer-managed certStore API; add default self-signed fallback cert and change ACME account handling
- Cert persistence removed from Rust side: CertStore is now an in-memory cache (no filesystem reads/writes). Rust no longer persists or loads certs from disk.
- ACME account credentials are no longer persisted by the library; AcmeClient uses ephemeral accounts only and account persistence APIs were removed.
- TypeScript API changes: removed certificateStore option and added ISmartProxyCertStore + certStore option for consumer-provided persistence (loadAll, save, optional remove).
- Default self-signed fallback certificate added (generateDefaultCertificate) and loaded as '*' unless disableDefaultCert is set.
- SmartProxy now pre-loads certificates from consumer certStore on startup and persists certificates by calling certStore.save() after provisioning.
- provisionCertificatesViaCallback signature changed to accept preloaded domains (prevents re-provisioning), and ACME fallback behavior adjusted with clearer logging.
- Rust cert manager methods made infallible for cache-only operations (load_static/store no longer return errors for cache insertions); removed store-backed load_all/remove/base_dir APIs.
- TCP listener tls_configs concurrency improved: switched to ArcSwap<HashMap<...>> so accept loops see hot-reloads immediately.
- Removed dependencies related to filesystem cert persistence from the tls crate (serde_json, tempfile) and corresponding Cargo.lock changes and test updates.
## 2026-02-13 - 23.1.6 - fix(smart-proxy)
disable built-in Rust ACME when a certProvisionFunction is provided and improve certificate provisioning flow
- Pass an optional ACME override into buildRustConfig so Rust ACME can be disabled per-run
- Disable Rust ACME when certProvisionFunction is configured to avoid provisioning race conditions
- Normalize routing glob patterns into concrete domain identifiers for certificate provisioning (expand leading-star globs and warn on unsupported patterns)
- Deduplicate domains during provisioning to avoid repeated attempts
- When the callback returns 'http01', explicitly trigger Rust ACME for the route via bridge.provisionCertificate and log success/failure
## 2026-02-13 - 23.1.5 - fix(smart-proxy)
provision certificates for wildcard domains instead of skipping them
- Removed early continue that skipped domains containing '*' in the domain loop
- Now calls provisionFn for wildcard domains so certificate provisioning can proceed for wildcard hosts
- Fixes cases where wildcard domains never had certificates requested
## 2026-02-12 - 23.1.4 - fix(tests)
make tests more robust and bump small dependencies
- Bump dependencies: @push.rocks/smartrust ^1.2.1 and minimatch ^10.2.0
- Replace hardcoded ports with named constants (ECHO_PORT, PROXY_PORT, PROXY_PORT_1/2) to avoid collisions between tests
- Add server 'error' handlers and reject listen promises on server errors to prevent silent hangs
- Reduce test timeouts and intervals (shorter test durations, more frequent pings) to speed up test runs
- Ensure proxy is stopped between tests and remove forced process.exit; export tap.start() consistently
- Adjust assertions to match the new shorter ping/response counts
## 2026-02-12 - 23.1.3 - fix(rustproxy)
install default rustls crypto provider early; detect and skip raw fast-path for HTTP connections and return proper HTTP 502 when no route matches
- Install ring-based rustls crypto provider at startup to prevent panics from instant-acme/hyper-rustls calling ClientConfig::builder() before TLS listeners are initialized
- Add a non-blocking 10ms peek to detect HTTP traffic in the TCP passthrough fast-path to avoid misrouting HTTP and ensure HTTP proxy handles CORS, errors, and request-level routing
- Skip the fast-path and fall back to the HTTP proxy when HTTP is detected (with a debug log)
- When no route matches for detected HTTP connections, send an HTTP 502 Bad Gateway response and close the connection instead of silently dropping it
## 2026-02-11 - 23.1.2 - fix(core)
use node: scoped builtin imports and add route unit tests
- Replaced bare Node built-in imports (events, fs, http, https, net, path, tls, url, http2, buffer, crypto) with 'node:' specifiers for ESM/bundler compatibility (files updated include ts/plugins.ts, ts/core/models/socket-types.ts, ts/core/utils/enhanced-connection-pool.ts, ts/core/utils/socket-tracker.ts, ts/protocols/common/fragment-handler.ts, ts/protocols/tls/sni/client-hello-parser.ts, ts/protocols/tls/sni/sni-extraction.ts, ts/protocols/websocket/utils.ts, ts/tls/sni/sni-handler.ts).
- Added new unit tests (test/test.bun.ts and test/test.deno.ts) covering route helpers, validators, matching, merging and cloning to improve test coverage.
## 2026-02-11 - 23.1.1 - fix(rust-proxy)
increase rust proxy bridge maxPayloadSize to 100 MB and bump dependencies
- Set maxPayloadSize to 100 * 1024 * 1024 (100 MB) in ts/proxies/smart-proxy/rust-proxy-bridge.ts to support large route configs
- Bump devDependency @types/node from ^25.2.2 to ^25.2.3
- Bump dependency @push.rocks/smartrust from ^1.1.1 to ^1.2.0
## 2026-02-10 - 23.1.0 - feat(rust-bridge)
integrate tsrust to build and locate cross-compiled Rust binaries; refactor rust-proxy bridge to use typed IPC and streamline process handling; add @push.rocks/smartrust and update build/dev dependencies
- Add tsrust to the build script and include dist_rust candidates when locating the Rust binary (enables cross-compiled artifacts produced by tsrust).
- Remove the old rust-binary-locator and refactor rust-proxy-bridge to use explicit, typed IPC command definitions and improved process spawn/cleanup logic.
- Introduce @push.rocks/smartrust for type-safe JSON IPC and export it via plugins; update README with expanded metrics documentation and change initialDataTimeout default from 60s to 120s.
- Add rust/.cargo/config.toml with aarch64 linker configuration to support cross-compilation for arm64.
- Bump several devDependencies and runtime dependencies (e.g. @git.zone/tsbuild, @git.zone/tstest, @push.rocks/smartserve, @push.rocks/taskbuffer, ws, minimatch, etc.).
- Update runtime message guiding local builds to use 'pnpm build' (tsrust) instead of direct cargo invocation.
## 2026-02-09 - 23.0.0 - BREAKING CHANGE(proxies/nftables-proxy)
remove nftables-proxy implementation, models, and utilities from the repository
- Deleted nftables-proxy module files under ts/proxies/nftables-proxy (index, models, utils, command executor, validators, etc.)
- Removed nftables-proxy exports from ts/index.ts and ts/proxies/index.ts
- Updated smart-proxy types to drop dependency on nftables proxy models
- Breaking change: any consumers importing nftables-proxy will no longer find those exports; update imports or install/use the extracted/alternative package if applicable
## 2026-02-09 - 22.6.0 - feat(smart-proxy)
add socket-handler relay, fast-path port-only forwarding, metrics and bridge improvements, and various TS/Rust integration fixes
- Add Unix-domain socket relay for socket-handler routes so Rust can hand off matched connections to TypeScript handlers (metadata JSON + initial bytes, relay implementation in Rust and SocketHandlerServer in TS).
- Implement fast-path port-only forwarding in the TCP accept/handler path to forward simple non-TLS, port-only routes immediately without peeking at client data (improves server-speaks-first protocol handling).
- Use ArcSwap for route manager hot-reload visibility in accept loops and share socket_handler_relay via Arc<RwLock> so listeners see relay path updates immediately.
- Enhance SNI/HTTP parsing: add extract_http_path and extract_http_host to aid domain/path matching from initial data.
- Improve RustProxy shutdown/kill handling: remove listeners, reject pending requests, destroy stdio pipes and unref process to avoid leaking handles.
- Enhance Rust <-> TS metrics bridge and adapter: add immediate poll(), map Rust JSON fields to IMetrics (per-route active/throughput/totals), and use safer polling/unref timers.
- SocketHandlerServer enhancements: track active sockets, destroy on stop, pause/resume to prevent data loss, support async socketHandler callbacks and dynamic function-based target forwarding (resolve host/port functions and forward).
- TypeScript smart-proxy lifecycle tweaks: only set bridge relay after Rust starts, guard unexpected-exit emission when intentionally stopping, stop polling and remove listeners on stop, add stopping flag.
- Misc: README and API ergonomics updates (nft proxy option renames and config comments), various test updates to use stable http.request helper, adjust timeouts/metrics sampling and assertions, and multiple small bugfixes in listeners, timeouts and TLS typings.
## 2026-02-09 - 22.5.0 - feat(rustproxy)
introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
- Add Rust workspace and multiple crates: rustproxy, rustproxy-config, rustproxy-routing, rustproxy-tls, rustproxy-passthrough, rustproxy-http, rustproxy-nftables, rustproxy-metrics, rustproxy-security
- Implement ACME integration (instant-acme) and an HTTP-01 challenge server with certificate lifecycle management
- Add TLS management: cert store, cert manager, SNI resolver, TLS acceptor/connector and certificate hot-swap support
- Implement TCP/TLS passthrough engine with ClientHello SNI parsing, PROXY v1 support, connection tracking and bidirectional forwarder
- Add Hyper-based HTTP proxy components: request/response filtering, CORS, auth, header templating and upstream selection with load balancing
- Introduce metrics (throughput tracker, metrics collector) and log deduplication utilities
- Implement nftables manager and rule builder (safe no-op behavior when not running as root)
- Add route types, validation, helpers, route manager and matchers (domain/path/header/ip)
- Provide management IPC (JSON over stdin/stdout) for TypeScript wrapper control (start/stop/add/remove ports, load certificates, etc.)
- Include extensive unit and integration tests, test helpers, and an example Rust config.json
- Update README to document the Rust-powered engine, new features and rustBinaryPath lookup
## 2026-01-31 - 22.4.2 - fix(tests)
shorten long-lived connection test timeouts and update certificate metadata timestamps
- Reduced test timeouts from 6570s to 60s and shortened internal waits from ~6165s to 55s to ensure tests complete within CI runner limits (files changed: test/test.long-lived-connections.ts, test/test.websocket-keepalive.node.ts).
- Updated log message to reflect the new 55s wait.
- Bumped certificate metadata timestamps in certs/static-route/meta.json (issueDate, savedAt, expiryDate).
## 2026-01-30 - 22.4.1 - fix(smartproxy)
improve certificate manager mocking in tests, enhance IPv6 validation, and record initial bytes for connection metrics
- Add createMockCertManager and update tests to fully mock createCertificateManager to avoid real ACME calls and make provisioning deterministic
- Record initial data chunk bytes in route-connection-handler and report them to metricsCollector.recordBytes to improve metrics accuracy
- Improve IPv6 validation regex to accept IPv6-mapped IPv4 addresses (::ffff:x.x.x.x)
- Add/set missing mock methods used in tests (setRoutes, generateConnectionId, trackConnectionByRoute, validateAndTrackIP) and small test adjustments (route names, port changes)
- Make test robustness improvements: wait loops for connection cleanup, increase websocket keepalive timeout, and other minor test fixes/whitespace cleanups
- Update certificate meta timestamps (test fixtures)
## 2026-01-30 - 22.4.0 - feat(smart-proxy)
calculate when SNI is required for TLS routing and allow session tickets for single-target passthrough routes; add tests, docs, and npm metadata updates
- Add calculateSniRequirement() and isWildcardOnly() to determine when SNI is required for routing decisions
- Use the new calculation to allow TLS session tickets for single-route passthrough or wildcard-only domains and block them when SNI is required
- Replace previous heuristic in route-connection-handler with the new SNI-based logic
- Add comprehensive unit tests (test/test.sni-requirement.node.ts) covering multiple SNI scenarios
- Update readme.hints.md with Smart SNI Requirement documentation and adjust troubleshooting guidance
- Update npmextra.json keys, add release registries and adjust tsdoc/CI metadata
## 2026-01-30 - 22.3.0 - feat(docs)
update README with installation, improved feature table, expanded quick-start, ACME/email example, API options interface, and clarified licensing/trademark text
- Added Installation section with npm/pnpm commands
- Reformatted features into a markdown table for clarity
- Expanded Quick Start example and updated ACME email placeholder
- Added an ISmartProxyOptions interface example showing acme/defaults/behavior options
- Clarified license file path and expanded trademark/legal wording
- Minor editorial and formatting improvements throughout the README
## 2026-01-30 - 22.2.0 - feat(proxies)
introduce nftables command executor and utilities, default certificate provider, expanded route/socket helper modules, and security improvements
- Added NftCommandExecutor with retry, temp-file support, sync execution, availability and conntrack checks.
- Refactored NfTablesProxy to use executor/utils (normalizePortSpec, validators, port normalizer, IP family filtering) and removed inline command/validation code.
- Introduced DefaultCertificateProvider to replace the deprecated CertificateManager; HttpProxy now uses DefaultCertificateProvider (CertificateManager exported as deprecated alias for compatibility).
- Added extensive route helper modules (http, https, api, load-balancer, nftables, dynamic, websocket, security, socket handlers) to simplify route creation and provide reusable patterns.
- Enhanced SecurityManagers: centralized security utilities (normalizeIP, isIPAuthorized, parseBasicAuthHeader, cleanup helpers), added validateAndTrackIP and JWT token verification, better IP normalization and rate tracking.
- Added many utility modules under ts/proxies/nftables-proxy/utils (command executor, port spec normalizer, rule validator) and exposed them via barrel export.
## 2025-12-09 - 22.1.1 - fix(tests)
Normalize route configurations in tests to use name (remove id) and standardize route names
- Removed deprecated id properties from route configurations in multiple tests and rely on the name property instead
- Standardized route.name values to kebab-case / lowercase (examples: 'tcp-forward', 'tls-passthrough', 'domain-a', 'domain-b', 'test-forward', 'nftables-test', 'regular-test', 'forward-test', 'test-forward', 'tls-test')
- Added explicit names for inner and outer proxies in proxy-chain-cleanup test ('inner-backend', 'outer-frontend')
- Updated certificate metadata timestamps in certs/static-route/meta.json
## 2025-12-09 - 22.1.0 - feat(smart-proxy)
Improve connection/rate-limit atomicity, SNI parsing, HttpProxy & ACME orchestration, and routing utilities
- Fix race conditions for per-IP connection limits by introducing atomic validate-and-track flow (SecurityManager.validateAndTrackIP) and propagating connectionId for atomic tracking.
- Add connection-manager createConnection options (connectionId, skipIpTracking) and avoid double-tracking IPs when validated atomically.
- RouteConnectionHandler now generates connection IDs earlier and uses atomic IP validation to prevent concurrent connection bypasses; cleans up IP tracking on global-limit rejects.
- Enhanced TLS SNI extraction and ClientHello parsing: robust fragmented ClientHello handling, PSK-based SNI extraction for TLS 1.3 resumption, tab-reactivation heuristics and improved logging (new client-hello-parser and sni-extraction modules).
- HttpProxy integration improvements: HttpProxyBridge initialized/synced from SmartProxy, forwardToHttpProxy forwards initial data and preserves client IP via CLIENT_IP header, robust handling of client disconnects during setup.
- Certificate manager (SmartCertManager) improvements: better ACME initialization sequence (deferred provisioning until ports are bound), improved challenge route add/remove handling, custom certificate provisioning hook, expiry handling fallback behavior and safer error messages for port conflicts.
- Route/port orchestration refactor (RouteOrchestrator): port usage mapping, safer add/remove port sequences, NFTables route lifecycle updates and certificate manager recreation on route changes.
- PortManager now refcounts ports and reuses existing listeners instead of rebinding; provides helpers to add/remove/update multiple ports and improved error handling for EADDRINUSE.
- Connection cleanup, inactivity and zombie detection hardened: batched cleanup queue, optimized inactivity checks, half-zombie detection and safer shutdown workflows.
- Metrics, routing helpers and validators: SharedRouteManager exposes expandPortRange/getListeningPorts, route helpers add convenience HTTPS/redirect/loadbalancer builders, route-validator domain rules relaxed to allow 'localhost', '*' and IPs, and tests updated accordingly.
- Tests updated to reflect behavioral changes (connection limit checks adapted to detect closed/ reset connections, HttpProxy integration test skipped in unit suite to avoid complex TLS setup).
## 2025-12-09 - 22.0.0 - BREAKING CHANGE(smart-proxy/utils/route-validator)
Consolidate and refactor route validators; move to class-based API and update usages
Replaced legacy route-validators.ts with a unified route-validator.ts that provides a class-based RouteValidator plus the previous functional API (isValidPort, isValidDomain, validateRouteMatch, validateRouteAction, validateRouteConfig, validateRoutes, hasRequiredPropertiesForAction, assertValidRoute) for backwards compatibility. Updated utils exports and all imports/tests to reference the new module. Also switched static file loading in certificate manager to use SmartFileFactory.nodeFs(), and added @push.rocks/smartserve to devDependencies.
- Rename and consolidate validator module: route-validators.ts removed; route-validator.ts added with RouteValidator class and duplicated functional API for compatibility.
- Updated exports in ts/proxies/smart-proxy/utils/index.ts and all internal imports/tests to reference './route-validator.js' instead of './route-validators.js'.
- Certificate manager now uses plugins.smartfile.SmartFileFactory.nodeFs() to load key/cert files (safer factory usage instead of direct static calls).
- Added @push.rocks/smartserve to devDependencies in package.json.
- Because the validator filename and some import paths changed, this is a breaking change for consumers importing the old module path.
## 2025-08-19 - 21.1.7 - fix(route-validator)
Relax domain validation to accept 'localhost', prefix wildcards (e.g. *example.com) and IP literals; add comprehensive domain validation tests
- Allow 'localhost' as a valid domain pattern in route validation
- Support prefix wildcard patterns like '*example.com' in addition to '*.example.com'
- Accept IPv4 and IPv6 literal addresses in domain validation
- Add test coverage: new test/test.domain-validation.ts with many real-world and edge-case patterns
## 2025-08-19 - 21.1.6 - fix(ip-utils)
Fix IP wildcard/shorthand handling and add validation test
- Support shorthand IPv4 wildcard patterns (e.g. '10.*', '192.168.*') by expanding them to full 4-octet patterns before matching
- Normalize and expand patterns in IpUtils.isGlobIPMatch and SharedSecurityManager IP checks to ensure consistent minimatch comparisons
- Relax route validator wildcard checks to accept 1-4 octet wildcard specifications for IPv4 patterns
- Add test harness test-ip-validation.ts to exercise common wildcard/shorthand IP patterns
## 2025-08-19 - 21.1.5 - fix(core)
Prepare patch release: documentation, tests and stability fixes (metrics, ACME, connection cleanup)

7069
deno.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
{
"gitzone": {
"@git.zone/cli": {
"projectType": "npm",
"module": {
"githost": "code.foss.global",
@@ -26,13 +26,22 @@
"server",
"network security"
]
},
"release": {
"registries": [
"https://verdaccio.lossless.digital",
"https://registry.npmjs.org"
],
"accessLevel": "public"
}
},
"npmci": {
"npmGlobalTools": [],
"npmAccessLevel": "public"
},
"tsdoc": {
"@git.zone/tsdoc": {
"legal": "\n## License and Legal Information\n\nThis repository contains open-source code that is licensed under the MIT License. A copy of the MIT License can be found in the [license](license) file within this repository. \n\n**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.\n\n### Trademarks\n\nThis project is owned and maintained by Task Venture Capital GmbH. The names and logos associated with Task Venture Capital GmbH and any related products or services are trademarks of Task Venture Capital GmbH and are not included within the scope of the MIT license granted herein. Use of these trademarks must comply with Task Venture Capital GmbH's Trademark Guidelines, and any usage must be approved in writing by Task Venture Capital GmbH.\n\n### Company Information\n\nTask Venture Capital GmbH \nRegistered at District court Bremen HRB 35230 HB, Germany\n\nFor any legal inquiries or if you require further information, please contact us via email at hello@task.vc.\n\nBy using this repository, you acknowledge that you have read this section, agree to comply with its terms, and understand that the licensing of the code does not imply endorsement by Task Venture Capital GmbH of any derivative works.\n"
},
"@ship.zone/szci": {
"npmGlobalTools": []
},
"@git.zone/tsrust": {
"targets": ["linux_amd64", "linux_arm64"]
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartproxy",
"version": "21.1.5",
"version": "25.7.4",
"private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js",
@@ -10,37 +10,26 @@
"license": "MIT",
"scripts": {
"test": "(tstest test/**/test*.ts --verbose --timeout 60 --logfile)",
"build": "(tsbuild tsfolders --allowimplicitany)",
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
"format": "(gitzone format)",
"buildDocs": "tsdoc"
},
"devDependencies": {
"@git.zone/tsbuild": "^2.6.4",
"@git.zone/tsrun": "^1.2.44",
"@git.zone/tstest": "^2.3.1",
"@types/node": "^22.15.29",
"typescript": "^5.8.3",
"@git.zone/tsbuild": "^4.1.2",
"@git.zone/tsrun": "^2.0.1",
"@git.zone/tsrust": "^1.3.0",
"@git.zone/tstest": "^3.1.8",
"@push.rocks/smartserve": "^2.0.1",
"@types/node": "^25.2.3",
"typescript": "^5.9.3",
"why-is-node-running": "^3.2.2"
},
"dependencies": {
"@push.rocks/lik": "^6.2.2",
"@push.rocks/smartacme": "^8.0.0",
"@push.rocks/smartcrypto": "^2.0.4",
"@push.rocks/smartdelay": "^3.0.5",
"@push.rocks/smartfile": "^11.2.5",
"@push.rocks/smartlog": "^3.1.8",
"@push.rocks/smartnetwork": "^4.0.2",
"@push.rocks/smartpromise": "^4.2.3",
"@push.rocks/smartrequest": "^2.1.0",
"@push.rocks/smartrx": "^3.0.10",
"@push.rocks/smartstring": "^4.0.15",
"@push.rocks/taskbuffer": "^3.1.7",
"@tsclass/tsclass": "^9.2.0",
"@types/minimatch": "^5.1.2",
"@types/ws": "^8.18.1",
"minimatch": "^10.0.1",
"pretty-ms": "^9.2.0",
"ws": "^8.18.2"
"@push.rocks/smartlog": "^3.1.10",
"@push.rocks/smartrust": "^1.2.1",
"@tsclass/tsclass": "^9.3.0",
"minimatch": "^10.2.0"
},
"files": [
"ts/**/*",

7763
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,169 +0,0 @@
# SmartProxy Byte Counting Audit Report
## Executive Summary
After a comprehensive audit of the SmartProxy codebase, I can confirm that **byte counting is implemented correctly** with no instances of double counting. Each byte transferred through the proxy is counted exactly once in each direction.
## Byte Counting Implementation
### 1. Core Tracking Mechanisms
SmartProxy uses two complementary tracking systems:
1. **Connection Records** (`IConnectionRecord`):
- `bytesReceived`: Total bytes received from client
- `bytesSent`: Total bytes sent to client
2. **MetricsCollector**:
- Global throughput tracking via `ThroughputTracker`
- Per-connection byte tracking for route/IP metrics
- Called via `recordBytes(connectionId, bytesIn, bytesOut)`
### 2. Where Bytes Are Counted
Bytes are counted in only two files:
#### a) `route-connection-handler.ts`
- **Line 351**: TLS alert bytes when no SNI is provided
- **Lines 1286-1301**: Data forwarding callbacks in `setupBidirectionalForwarding()`
#### b) `http-proxy-bridge.ts`
- **Line 127**: Initial TLS chunk for HttpProxy connections
- **Lines 142-154**: Data forwarding callbacks in `setupBidirectionalForwarding()`
## Connection Flow Analysis
### 1. Direct TCP Connection (No TLS)
```
Client → SmartProxy → Target Server
```
1. Connection arrives at `RouteConnectionHandler.handleConnection()`
2. For non-TLS ports, immediately routes via `routeConnection()`
3. `setupDirectConnection()` creates target connection
4. `setupBidirectionalForwarding()` handles all data transfer:
- `onClientData`: `bytesReceived += chunk.length` + `recordBytes(chunk.length, 0)`
- `onServerData`: `bytesSent += chunk.length` + `recordBytes(0, chunk.length)`
**Result**: ✅ Each byte counted exactly once
### 2. TLS Passthrough Connection
```
Client (TLS) → SmartProxy → Target Server (TLS)
```
1. Connection waits for initial data to detect TLS
2. TLS handshake detected, SNI extracted
3. Route matched, `setupDirectConnection()` called
4. Initial chunk stored in `pendingData` (NOT counted yet)
5. On target connect, `pendingData` written to target (still not counted)
6. `setupBidirectionalForwarding()` counts ALL bytes including initial chunk
**Result**: ✅ Each byte counted exactly once
### 3. TLS Termination via HttpProxy
```
Client (TLS) → SmartProxy → HttpProxy (localhost) → Target Server
```
1. TLS connection detected with `tls.mode = "terminate"`
2. `forwardToHttpProxy()` called:
- Initial chunk: `bytesReceived += chunk.length` + `recordBytes(chunk.length, 0)`
3. Proxy connection created to HttpProxy on localhost
4. `setupBidirectionalForwarding()` handles subsequent data
**Result**: ✅ Each byte counted exactly once
### 4. HTTP Connection via HttpProxy
```
Client (HTTP) → SmartProxy → HttpProxy (localhost) → Target Server
```
1. Connection on configured HTTP port (`useHttpProxy` ports)
2. Same flow as TLS termination
3. All byte counting identical to TLS termination
**Result**: ✅ Each byte counted exactly once
### 5. NFTables Forwarding
```
Client → [Kernel NFTables] → Target Server
```
1. Connection detected, route matched with `forwardingEngine: 'nftables'`
2. Connection marked as `usingNetworkProxy = true`
3. NO application-level forwarding (kernel handles packet routing)
4. NO byte counting in application layer
**Result**: ✅ No counting (correct - kernel handles everything)
## Special Cases
### PROXY Protocol
- PROXY protocol headers sent to backend servers are NOT counted in client metrics
- Only actual client data is counted
- **Correct behavior**: Protocol overhead is not client data
### TLS Alerts
- TLS alerts (e.g., for missing SNI) are counted as sent bytes
- **Correct behavior**: Alerts are actual data sent to the client
### Initial Chunks
- **Direct connections**: Stored in `pendingData`, counted when forwarded
- **HttpProxy connections**: Counted immediately upon receipt
- **Both approaches**: Count each byte exactly once
## Verification Methodology
1. **Code Analysis**: Searched for all instances of:
- `bytesReceived +=` and `bytesSent +=`
- `recordBytes()` calls
- Data forwarding implementations
2. **Flow Tracing**: Followed data path for each connection type from entry to exit
3. **Handler Review**: Examined all forwarding handlers to ensure no additional counting
## Findings
### ✅ No Double Counting Detected
- Each byte is counted exactly once in the direction it flows
- Connection records and metrics are updated consistently
- No overlapping or duplicate counting logic found
### Areas of Excellence
1. **Centralized Counting**: All byte counting happens in just two files
2. **Consistent Pattern**: Uses `setupBidirectionalForwarding()` with callbacks
3. **Clear Separation**: Forwarding handlers don't interfere with proxy metrics
## Recommendations
1. **Debug Logging**: Add optional debug logging to verify byte counts in production:
```typescript
if (settings.debugByteCount) {
logger.log('debug', `Bytes counted: ${connectionId} +${bytes} (total: ${record.bytesReceived})`);
}
```
2. **Unit Tests**: Create specific tests to ensure byte counting accuracy:
- Test initial chunk handling
- Test PROXY protocol overhead exclusion
- Test HttpProxy forwarding accuracy
3. **Protocol Overhead Tracking**: Consider separately tracking:
- PROXY protocol headers
- TLS handshake bytes
- HTTP headers vs body
4. **NFTables Documentation**: Clearly document that NFTables-forwarded connections are not included in application metrics
## Conclusion
SmartProxy's byte counting implementation is **robust and accurate**. The design ensures that each byte is counted exactly once, with clear separation between connection tracking and metrics collection. No remediation is required.

View File

@@ -345,4 +345,187 @@ new SmartProxy({
1. Implement proper certificate expiry date extraction using X.509 parsing
2. Add support for returning expiry date with custom certificates
3. Consider adding validation for custom certificate format
4. Add events/hooks for certificate provisioning lifecycle
4. Add events/hooks for certificate provisioning lifecycle
## HTTPS/TLS Configuration Guide
SmartProxy supports three TLS modes for handling HTTPS traffic. Understanding when to use each mode is crucial for correct configuration.
### TLS Mode: Passthrough (SNI Routing)
**When to use**: Backend server handles its own TLS certificates.
**How it works**:
1. Client connects with TLS ClientHello containing SNI (Server Name Indication)
2. SmartProxy extracts the SNI hostname without decrypting
3. Connection is forwarded to backend as-is (still encrypted)
4. Backend server terminates TLS with its own certificate
**Configuration**:
```typescript
{
match: { ports: 443, domains: 'backend.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend-server', port: 443 }],
tls: { mode: 'passthrough' }
}
}
```
**Requirements**:
- Backend must have valid TLS certificate for the domain
- Client's SNI must be present (session tickets without SNI will be rejected)
- No HTTP-level inspection possible (encrypted end-to-end)
### TLS Mode: Terminate
**When to use**: SmartProxy handles TLS, backend receives plain HTTP.
**How it works**:
1. Client connects with TLS ClientHello
2. SmartProxy terminates TLS (decrypts traffic)
3. Decrypted HTTP is forwarded to backend on plain HTTP port
4. Backend receives unencrypted traffic
**Configuration**:
```typescript
{
match: { ports: 443, domains: 'api.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }], // HTTP backend
tls: {
mode: 'terminate',
certificate: 'auto' // Let's Encrypt, or provide { key, cert }
}
}
}
```
**Requirements**:
- ACME email configured for auto certificates: `acme: { email: 'admin@example.com' }`
- Port 80 available for HTTP-01 challenges (or use DNS-01)
- Backend accessible on HTTP port
### TLS Mode: Terminate and Re-encrypt
**When to use**: SmartProxy handles client TLS, but backend also requires TLS.
**How it works**:
1. Client connects with TLS ClientHello
2. SmartProxy terminates client TLS (decrypts)
3. SmartProxy creates new TLS connection to backend
4. Traffic is re-encrypted for the backend connection
**Configuration**:
```typescript
{
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend-tls', port: 443 }], // HTTPS backend
tls: {
mode: 'terminate-and-reencrypt',
certificate: 'auto'
}
}
}
```
**Requirements**:
- Same as 'terminate' mode
- Backend must have valid TLS (can be self-signed for internal use)
### HttpProxy Integration
For TLS termination modes (`terminate` and `terminate-and-reencrypt`), SmartProxy uses an internal HttpProxy component:
- HttpProxy listens on an internal port (default: 8443)
- SmartProxy forwards TLS connections to HttpProxy for termination
- Client IP is preserved via `CLIENT_IP:` header protocol
- HTTP/2 and WebSocket are supported after TLS termination
**Configuration**:
```typescript
{
useHttpProxy: [443], // Ports that use HttpProxy for TLS termination
httpProxyPort: 8443, // Internal HttpProxy port
acme: {
email: 'admin@example.com',
useProduction: true // false for Let's Encrypt staging
}
}
```
### Common Configuration Patterns
**HTTP to HTTPS Redirect**:
```typescript
import { createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
const redirectRoute = createHttpToHttpsRedirect(['example.com', 'www.example.com']);
```
**Complete HTTPS Server (with redirect)**:
```typescript
import { createCompleteHttpsServer } from '@push.rocks/smartproxy';
const routes = createCompleteHttpsServer(
'example.com',
{ host: 'localhost', port: 8080 },
{ certificate: 'auto' }
);
```
**Load Balancer with Health Checks**:
```typescript
import { createLoadBalancerRoute } from '@push.rocks/smartproxy';
const lbRoute = createLoadBalancerRoute(
'api.example.com',
[
{ host: 'backend1', port: 8080 },
{ host: 'backend2', port: 8080 },
{ host: 'backend3', port: 8080 }
],
{ tls: { mode: 'terminate', certificate: 'auto' } }
);
```
### Smart SNI Requirement (v22.3+)
SmartProxy automatically determines when SNI is required for routing. Session tickets (TLS resumption without SNI) are now allowed in more scenarios:
**SNI NOT required (session tickets allowed):**
- Single passthrough route with static target(s) and no domain restriction
- Single passthrough route with wildcard-only domain (`*` or `['*']`)
- TLS termination routes (`terminate` or `terminate-and-reencrypt`)
- Mixed terminate + passthrough routes (termination takes precedence)
**SNI IS required (session tickets blocked):**
- Multiple passthrough routes on the same port (need SNI to pick correct route)
- Route has dynamic host function (e.g., `host: (ctx) => ctx.domain === 'api.example.com' ? 'api-backend' : 'web-backend'`)
- Route has specific domain restriction (e.g., `domains: 'api.example.com'` or `domains: '*.example.com'`)
This allows simple single-target passthrough setups to work with TLS session resumption, improving performance for clients that reuse connections.
### Troubleshooting
**"No SNI detected" errors**:
- Client is using TLS session resumption without SNI
- Solution: Configure route for TLS termination (allows session resumption), or ensure you have a single-target passthrough route with no domain restrictions
**"HttpProxy not available" errors**:
- `useHttpProxy` not configured for the port
- Solution: Add port to `useHttpProxy` array in settings
**Certificate provisioning failures**:
- Port 80 not accessible for HTTP-01 challenges
- ACME email not configured
- Solution: Ensure port 80 is available and `acme.email` is set
**Connection timeouts to HttpProxy**:
- CLIENT_IP header parsing timeout (default: 2000ms)
- Network congestion between SmartProxy and HttpProxy
- Solution: Check localhost connectivity, increase timeout if needed

969
readme.md

File diff suppressed because it is too large Load Diff

2
rust/.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"

1724
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

99
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,99 @@
[workspace]
resolver = "2"
members = [
"crates/rustproxy",
"crates/rustproxy-config",
"crates/rustproxy-routing",
"crates/rustproxy-tls",
"crates/rustproxy-passthrough",
"crates/rustproxy-http",
"crates/rustproxy-nftables",
"crates/rustproxy-metrics",
"crates/rustproxy-security",
]
[workspace.package]
version = "0.1.0"
edition = "2021"
license = "MIT"
authors = ["Lossless GmbH <hello@lossless.com>"]
[workspace.dependencies]
# Async runtime
tokio = { version = "1", features = ["full"] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# HTTP proxy engine (hyper-based)
hyper = { version = "1", features = ["http1", "http2", "server", "client"] }
hyper-util = { version = "0.1", features = ["tokio", "http1", "http2", "client-legacy", "server-auto"] }
http-body = "1"
http-body-util = "0.1"
bytes = "1"
# ACME / Let's Encrypt
instant-acme = { version = "0.7", features = ["hyper-rustls"] }
# TLS for passthrough SNI
rustls = { version = "0.23", features = ["ring"] }
tokio-rustls = "0.26"
rustls-pemfile = "2"
# Self-signed cert generation for tests
rcgen = "0.13"
# Temp directories for tests
tempfile = "3"
# Lock-free atomics
arc-swap = "1"
# Concurrent maps
dashmap = "6"
# Domain wildcard matching
glob-match = "0.2"
# IP/CIDR parsing
ipnet = "2"
# JWT authentication
jsonwebtoken = "9"
# Structured logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
thiserror = "2"
anyhow = "1"
# CLI
clap = { version = "4", features = ["derive"] }
# Regex for URL rewriting
regex = "1"
# Base64 for basic auth
base64 = "0.22"
# Cancellation / utility
tokio-util = "0.7"
# Async traits
async-trait = "0.1"
# libc for uid checks
libc = "0.2"
# Internal crates
rustproxy-config = { path = "crates/rustproxy-config" }
rustproxy-routing = { path = "crates/rustproxy-routing" }
rustproxy-tls = { path = "crates/rustproxy-tls" }
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
rustproxy-http = { path = "crates/rustproxy-http" }
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
rustproxy-security = { path = "crates/rustproxy-security" }

145
rust/config/example.json Normal file
View File

@@ -0,0 +1,145 @@
{
"routes": [
{
"id": "https-passthrough",
"name": "HTTPS Passthrough to Backend",
"match": {
"ports": 443,
"domains": "backend.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "10.0.0.1",
"port": 443
}
],
"tls": {
"mode": "passthrough"
}
},
"priority": 10,
"enabled": true
},
{
"id": "https-terminate",
"name": "HTTPS Terminate for API",
"match": {
"ports": 443,
"domains": "api.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "localhost",
"port": 8080
}
],
"tls": {
"mode": "terminate",
"certificate": "auto"
}
},
"priority": 20,
"enabled": true
},
{
"id": "http-redirect",
"name": "HTTP to HTTPS Redirect",
"match": {
"ports": 80,
"domains": ["api.example.com", "www.example.com"]
},
"action": {
"type": "forward",
"targets": [
{
"host": "localhost",
"port": 8080
}
]
},
"priority": 0
},
{
"id": "load-balanced",
"name": "Load Balanced Backend",
"match": {
"ports": 443,
"domains": "*.example.com"
},
"action": {
"type": "forward",
"targets": [
{
"host": "backend1.internal",
"port": 8080
},
{
"host": "backend2.internal",
"port": 8080
},
{
"host": "backend3.internal",
"port": 8080
}
],
"tls": {
"mode": "terminate",
"certificate": "auto"
},
"loadBalancing": {
"algorithm": "round-robin",
"healthCheck": {
"path": "/health",
"interval": 30,
"timeout": 5,
"unhealthyThreshold": 3,
"healthyThreshold": 2
}
}
},
"security": {
"ipAllowList": ["10.0.0.0/8", "192.168.0.0/16"],
"maxConnections": 1000,
"rateLimit": {
"enabled": true,
"maxRequests": 100,
"window": 60,
"keyBy": "ip"
}
},
"headers": {
"request": {
"X-Forwarded-For": "{clientIp}",
"X-Real-IP": "{clientIp}"
},
"response": {
"X-Powered-By": "RustProxy"
},
"cors": {
"enabled": true,
"allowOrigin": "*",
"allowMethods": "GET,POST,PUT,DELETE,OPTIONS",
"allowHeaders": "Content-Type,Authorization",
"allowCredentials": false,
"maxAge": 86400
}
},
"priority": 5
}
],
"acme": {
"email": "admin@example.com",
"useProduction": false,
"port": 80
},
"connectionTimeout": 30000,
"socketTimeout": 3600000,
"maxConnectionsPerIp": 100,
"connectionRateLimitPerMinute": 300,
"keepAliveTreatment": "extended",
"enableDetailedLogging": false
}

View File

@@ -0,0 +1,13 @@
[package]
name = "rustproxy-config"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Configuration types for RustProxy, compatible with SmartProxy JSON schema"
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
ipnet = { workspace = true }

View File

@@ -0,0 +1,337 @@
use crate::route_types::*;
use crate::tls_types::*;
/// Create a simple HTTP forwarding route.
/// Equivalent to SmartProxy's `createHttpRoute()`.
pub fn create_http_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains.into()),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single(target_host.into()),
port: PortSpec::Fixed(target_port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create an HTTPS termination route.
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
pub fn create_https_terminate_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Terminate,
certificate: Some(CertificateSpec::Auto("auto".to_string())),
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create a TLS passthrough route.
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
pub fn create_https_passthrough_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create an HTTP-to-HTTPS redirect route.
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
pub fn create_http_to_https_redirect(
domains: impl Into<DomainSpec>,
) -> RouteConfig {
let domains = domains.into();
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: Some(RouteAdvanced {
timeout: None,
headers: None,
keep_alive: None,
static_files: None,
test_response: Some(RouteTestResponse {
status: 301,
headers: {
let mut h = std::collections::HashMap::new();
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
h
},
body: String::new(),
}),
url_rewrite: None,
}),
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: Some("HTTP to HTTPS Redirect".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create a complete HTTPS server with HTTP redirect.
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
pub fn create_complete_https_server(
domain: impl Into<String>,
target_host: impl Into<String>,
target_port: u16,
) -> Vec<RouteConfig> {
let domain = domain.into();
let target_host = target_host.into();
vec![
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
create_https_terminate_route(
DomainSpec::Single(domain),
target_host,
target_port,
),
]
}
/// Create a load balancer route.
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
pub fn create_load_balancer_route(
domains: impl Into<DomainSpec>,
targets: Vec<(String, u16)>,
tls: Option<RouteTls>,
) -> RouteConfig {
let route_targets: Vec<RouteTarget> = targets
.into_iter()
.map(|(host, port)| RouteTarget {
target_match: None,
host: HostSpec::Single(host),
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
})
.collect();
let port = if tls.is_some() { 443 } else { 80 };
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
domains: Some(domains.into()),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(route_targets),
tls,
websocket: None,
load_balancing: Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::RoundRobin,
health_check: None,
}),
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: Some("Load Balancer".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls_types::TlsMode;
#[test]
fn test_create_http_route() {
let route = create_http_route("example.com", "localhost", 8080);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
assert_eq!(domains, vec!["example.com"]);
let target = &route.action.targets.as_ref().unwrap()[0];
assert_eq!(target.host.first(), "localhost");
assert_eq!(target.port.resolve(80), 8080);
assert!(route.action.tls.is_none());
}
#[test]
fn test_create_https_terminate_route() {
let route = create_https_terminate_route("api.example.com", "backend", 3000);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Terminate);
assert!(tls.certificate.as_ref().unwrap().is_auto());
}
#[test]
fn test_create_https_passthrough_route() {
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Passthrough);
assert!(tls.certificate.is_none());
}
#[test]
fn test_create_http_to_https_redirect() {
let route = create_http_to_https_redirect("example.com");
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert!(route.action.targets.is_none());
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
assert_eq!(test_response.status, 301);
assert!(test_response.headers.contains_key("Location"));
}
#[test]
fn test_create_complete_https_server() {
let routes = create_complete_https_server("example.com", "backend", 8080);
assert_eq!(routes.len(), 2);
// First route is HTTP redirect
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
// Second route is HTTPS terminate
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
}
#[test]
fn test_create_load_balancer_route() {
let targets = vec![
("backend1".to_string(), 8080),
("backend2".to_string(), 8080),
("backend3".to_string(), 8080),
];
let route = create_load_balancer_route("*.example.com", targets, None);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
let lb = route.action.load_balancing.as_ref().unwrap();
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
}
#[test]
fn test_domain_spec_from_str() {
let spec: DomainSpec = "example.com".into();
assert_eq!(spec.to_vec(), vec!["example.com"]);
}
#[test]
fn test_domain_spec_from_vec() {
let spec: DomainSpec = vec!["a.com", "b.com"].into();
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
}
}

View File

@@ -0,0 +1,19 @@
//! # rustproxy-config
//!
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;
pub mod proxy_options;
pub mod tls_types;
pub mod security_types;
pub mod validation;
pub mod helpers;
// Re-export all primary types
pub use route_types::*;
pub use proxy_options::*;
pub use tls_types::*;
pub use security_types::*;
pub use validation::*;
pub use helpers::*;

View File

@@ -0,0 +1,435 @@
use serde::{Deserialize, Serialize};
use crate::route_types::RouteConfig;
/// Global ACME configuration options.
/// Matches TypeScript: `IAcmeOptions`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AcmeOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
/// Required when any route uses certificate: 'auto'
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub environment: Option<AcmeEnvironment>,
/// Alias for email
#[serde(skip_serializing_if = "Option::is_none")]
pub account_email: Option<String>,
/// Port for HTTP-01 challenges (default: 80)
#[serde(skip_serializing_if = "Option::is_none")]
pub port: Option<u16>,
/// Use Let's Encrypt production (default: false)
#[serde(skip_serializing_if = "Option::is_none")]
pub use_production: Option<bool>,
/// Days before expiry to renew (default: 30)
#[serde(skip_serializing_if = "Option::is_none")]
pub renew_threshold_days: Option<u32>,
/// Enable automatic renewal (default: true)
#[serde(skip_serializing_if = "Option::is_none")]
pub auto_renew: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_configured_certs: Option<bool>,
/// How often to check for renewals (default: 24)
#[serde(skip_serializing_if = "Option::is_none")]
pub renew_check_interval_hours: Option<u32>,
}
/// ACME environment.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AcmeEnvironment {
Production,
Staging,
}
/// Default target configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultTarget {
pub host: String,
pub port: u16,
}
/// Default security configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultSecurity {
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_allow_list: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_block_list: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_connections: Option<u64>,
}
/// Default configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<DefaultTarget>,
#[serde(skip_serializing_if = "Option::is_none")]
pub security: Option<DefaultSecurity>,
#[serde(skip_serializing_if = "Option::is_none")]
pub preserve_source_ip: Option<bool>,
}
/// Keep-alive treatment.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeepAliveTreatment {
Standard,
Extended,
Immortal,
}
/// Metrics configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MetricsConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sample_interval_ms: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub retention_seconds: Option<u64>,
}
/// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions`
///
/// This is the top-level configuration that can be loaded from a JSON file
/// or constructed programmatically.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RustProxyOptions {
/// The unified configuration array (required)
pub routes: Vec<RouteConfig>,
/// Preserve client IP when forwarding
#[serde(skip_serializing_if = "Option::is_none")]
pub preserve_source_ip: Option<bool>,
/// List of trusted proxy IPs that can send PROXY protocol
#[serde(skip_serializing_if = "Option::is_none")]
pub proxy_ips: Option<Vec<String>>,
/// Global option to accept PROXY protocol
#[serde(skip_serializing_if = "Option::is_none")]
pub accept_proxy_protocol: Option<bool>,
/// Global option to send PROXY protocol to all targets
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,
/// Global/default settings
#[serde(skip_serializing_if = "Option::is_none")]
pub defaults: Option<DefaultConfig>,
// ─── Timeout Settings ────────────────────────────────────────────
/// Timeout for establishing connection to backend (ms), default: 30000
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_timeout: Option<u64>,
/// Timeout for initial data/SNI (ms), default: 60000
#[serde(skip_serializing_if = "Option::is_none")]
pub initial_data_timeout: Option<u64>,
/// Socket inactivity timeout (ms), default: 3600000
#[serde(skip_serializing_if = "Option::is_none")]
pub socket_timeout: Option<u64>,
/// How often to check for inactive connections (ms), default: 60000
#[serde(skip_serializing_if = "Option::is_none")]
pub inactivity_check_interval: Option<u64>,
/// Default max connection lifetime (ms), default: 86400000
#[serde(skip_serializing_if = "Option::is_none")]
pub max_connection_lifetime: Option<u64>,
/// Inactivity timeout (ms), default: 14400000
#[serde(skip_serializing_if = "Option::is_none")]
pub inactivity_timeout: Option<u64>,
/// Maximum time to wait for connections to close during shutdown (ms)
#[serde(skip_serializing_if = "Option::is_none")]
pub graceful_shutdown_timeout: Option<u64>,
// ─── Socket Optimization ─────────────────────────────────────────
/// Disable Nagle's algorithm (default: true)
#[serde(skip_serializing_if = "Option::is_none")]
pub no_delay: Option<bool>,
/// Enable TCP keepalive (default: true)
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive: Option<bool>,
/// Initial delay before sending keepalive probes (ms)
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive_initial_delay: Option<u64>,
/// Maximum bytes to buffer during connection setup
#[serde(skip_serializing_if = "Option::is_none")]
pub max_pending_data_size: Option<u64>,
// ─── Enhanced Features ───────────────────────────────────────────
/// Disable inactivity checking entirely
#[serde(skip_serializing_if = "Option::is_none")]
pub disable_inactivity_check: Option<bool>,
/// Enable TCP keep-alive probes
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_keep_alive_probes: Option<bool>,
/// Enable detailed connection logging
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_detailed_logging: Option<bool>,
/// Enable TLS handshake debug logging
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_tls_debug_logging: Option<bool>,
/// Randomize timeouts to prevent thundering herd
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_randomized_timeouts: Option<bool>,
// ─── Rate Limiting ───────────────────────────────────────────────
/// Maximum simultaneous connections from a single IP
#[serde(skip_serializing_if = "Option::is_none")]
pub max_connections_per_ip: Option<u64>,
/// Max new connections per minute from a single IP
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_rate_limit_per_minute: Option<u64>,
// ─── Keep-Alive Settings ─────────────────────────────────────────
/// How to treat keep-alive connections
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive_treatment: Option<KeepAliveTreatment>,
/// Multiplier for inactivity timeout for keep-alive connections
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive_inactivity_multiplier: Option<f64>,
/// Extended lifetime for keep-alive connections (ms)
#[serde(skip_serializing_if = "Option::is_none")]
pub extended_keep_alive_lifetime: Option<u64>,
// ─── HttpProxy Integration ───────────────────────────────────────
/// Array of ports to forward to HttpProxy
#[serde(skip_serializing_if = "Option::is_none")]
pub use_http_proxy: Option<Vec<u16>>,
/// Port where HttpProxy is listening (default: 8443)
#[serde(skip_serializing_if = "Option::is_none")]
pub http_proxy_port: Option<u16>,
// ─── Metrics ─────────────────────────────────────────────────────
/// Metrics configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub metrics: Option<MetricsConfig>,
// ─── ACME ────────────────────────────────────────────────────────
/// Global ACME configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub acme: Option<AcmeOptions>,
}
impl Default for RustProxyOptions {
fn default() -> Self {
Self {
routes: Vec::new(),
preserve_source_ip: None,
proxy_ips: None,
accept_proxy_protocol: None,
send_proxy_protocol: None,
defaults: None,
connection_timeout: None,
initial_data_timeout: None,
socket_timeout: None,
inactivity_check_interval: None,
max_connection_lifetime: None,
inactivity_timeout: None,
graceful_shutdown_timeout: None,
no_delay: None,
keep_alive: None,
keep_alive_initial_delay: None,
max_pending_data_size: None,
disable_inactivity_check: None,
enable_keep_alive_probes: None,
enable_detailed_logging: None,
enable_tls_debug_logging: None,
enable_randomized_timeouts: None,
max_connections_per_ip: None,
connection_rate_limit_per_minute: None,
keep_alive_treatment: None,
keep_alive_inactivity_multiplier: None,
extended_keep_alive_lifetime: None,
use_http_proxy: None,
http_proxy_port: None,
metrics: None,
acme: None,
}
}
}
impl RustProxyOptions {
/// Load configuration from a JSON file.
pub fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
let content = std::fs::read_to_string(path)?;
let options: Self = serde_json::from_str(&content)?;
Ok(options)
}
/// Get the effective connection timeout in milliseconds.
pub fn effective_connection_timeout(&self) -> u64 {
self.connection_timeout.unwrap_or(30_000)
}
/// Get the effective initial data timeout in milliseconds.
pub fn effective_initial_data_timeout(&self) -> u64 {
self.initial_data_timeout.unwrap_or(60_000)
}
/// Get the effective socket timeout in milliseconds.
pub fn effective_socket_timeout(&self) -> u64 {
self.socket_timeout.unwrap_or(3_600_000)
}
/// Get the effective max connection lifetime in milliseconds.
pub fn effective_max_connection_lifetime(&self) -> u64 {
self.max_connection_lifetime.unwrap_or(86_400_000)
}
/// Get all unique ports that routes listen on.
pub fn all_listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.routes
.iter()
.flat_map(|r| r.listening_ports())
.collect();
ports.sort();
ports.dedup();
ports
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::helpers::*;
#[test]
fn test_serde_roundtrip_minimal() {
let options = RustProxyOptions {
routes: vec![create_http_route("example.com", "localhost", 8080)],
..Default::default()
};
let json = serde_json::to_string(&options).unwrap();
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.routes.len(), 1);
}
#[test]
fn test_serde_roundtrip_full() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend1", 8080),
create_https_passthrough_route("b.com", "backend2", 443),
],
connection_timeout: Some(5000),
socket_timeout: Some(60000),
max_connections_per_ip: Some(100),
acme: Some(AcmeOptions {
enabled: Some(true),
email: Some("admin@example.com".to_string()),
environment: Some(AcmeEnvironment::Staging),
account_email: None,
port: None,
use_production: None,
renew_threshold_days: None,
auto_renew: None,
skip_configured_certs: None,
renew_check_interval_hours: None,
}),
..Default::default()
};
let json = serde_json::to_string_pretty(&options).unwrap();
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.routes.len(), 2);
assert_eq!(parsed.connection_timeout, Some(5000));
}
#[test]
fn test_default_timeouts() {
let options = RustProxyOptions::default();
assert_eq!(options.effective_connection_timeout(), 30_000);
assert_eq!(options.effective_initial_data_timeout(), 60_000);
assert_eq!(options.effective_socket_timeout(), 3_600_000);
assert_eq!(options.effective_max_connection_lifetime(), 86_400_000);
}
#[test]
fn test_custom_timeouts() {
let options = RustProxyOptions {
connection_timeout: Some(5000),
initial_data_timeout: Some(10000),
socket_timeout: Some(30000),
max_connection_lifetime: Some(60000),
..Default::default()
};
assert_eq!(options.effective_connection_timeout(), 5000);
assert_eq!(options.effective_initial_data_timeout(), 10000);
assert_eq!(options.effective_socket_timeout(), 30000);
assert_eq!(options.effective_max_connection_lifetime(), 60000);
}
#[test]
fn test_all_listening_ports() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend", 8080), // port 80
create_https_passthrough_route("b.com", "backend", 443), // port 443
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
],
..Default::default()
};
let ports = options.all_listening_ports();
assert_eq!(ports, vec![80, 443]);
}
#[test]
fn test_camel_case_field_names() {
let options = RustProxyOptions {
connection_timeout: Some(5000),
max_connections_per_ip: Some(100),
keep_alive_treatment: Some(KeepAliveTreatment::Extended),
..Default::default()
};
let json = serde_json::to_string(&options).unwrap();
assert!(json.contains("connectionTimeout"));
assert!(json.contains("maxConnectionsPerIp"));
assert!(json.contains("keepAliveTreatment"));
}
#[test]
fn test_deserialize_example_json() {
let content = std::fs::read_to_string(
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
).unwrap();
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
assert_eq!(options.routes.len(), 4);
let ports = options.all_listening_ports();
assert!(ports.contains(&80));
assert!(ports.contains(&443));
}
}

View File

@@ -0,0 +1,607 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::tls_types::RouteTls;
use crate::security_types::RouteSecurity;
// ─── Port Range ──────────────────────────────────────────────────────
/// Port range specification format.
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRange {
/// Single port number
Single(u16),
/// Array of port numbers
List(Vec<u16>),
/// Array of port ranges
Ranges(Vec<PortRangeSpec>),
}
impl PortRange {
/// Expand the port range into a flat list of ports.
pub fn to_ports(&self) -> Vec<u16> {
match self {
PortRange::Single(p) => vec![*p],
PortRange::List(ports) => ports.clone(),
PortRange::Ranges(ranges) => {
ranges.iter().flat_map(|r| r.from..=r.to).collect()
}
}
}
}
/// A from-to port range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortRangeSpec {
pub from: u16,
pub to: u16,
}
// ─── Route Action Type ───────────────────────────────────────────────
/// Supported action types for route configurations.
/// Matches TypeScript: `type TRouteActionType = 'forward' | 'socket-handler'`
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum RouteActionType {
Forward,
SocketHandler,
}
// ─── Forwarding Engine ───────────────────────────────────────────────
/// Forwarding engine specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ForwardingEngine {
Node,
Nftables,
}
// ─── Route Match ─────────────────────────────────────────────────────
/// Domain specification: single string or array.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DomainSpec {
Single(String),
List(Vec<String>),
}
impl DomainSpec {
pub fn to_vec(&self) -> Vec<&str> {
match self {
DomainSpec::Single(s) => vec![s.as_str()],
DomainSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
}
}
}
/// Header match value: either exact string or regex pattern.
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HeaderMatchValue {
Exact(String),
}
/// Route match criteria for incoming requests.
/// Matches TypeScript: `IRouteMatch`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteMatch {
/// Listen on these ports (required)
pub ports: PortRange,
/// Optional domain patterns to match (default: all domains)
#[serde(skip_serializing_if = "Option::is_none")]
pub domains: Option<DomainSpec>,
/// Match specific paths
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
/// Match specific client IPs
#[serde(skip_serializing_if = "Option::is_none")]
pub client_ip: Option<Vec<String>>,
/// Match specific TLS versions
#[serde(skip_serializing_if = "Option::is_none")]
pub tls_version: Option<Vec<String>>,
/// Match specific HTTP headers
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
/// Match specific protocol: "http" (includes h2 + websocket) or "tcp"
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<String>,
}
// ─── Target Match ────────────────────────────────────────────────────
/// Target-specific match criteria for sub-routing within a route.
/// Matches TypeScript: `ITargetMatch`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TargetMatch {
/// Match specific ports from the route
#[serde(skip_serializing_if = "Option::is_none")]
pub ports: Option<Vec<u16>>,
/// Match specific paths (supports wildcards like /api/*)
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
/// Match specific HTTP headers
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
/// Match specific HTTP methods
#[serde(skip_serializing_if = "Option::is_none")]
pub method: Option<Vec<String>>,
}
// ─── WebSocket Config ────────────────────────────────────────────────
/// WebSocket configuration.
/// Matches TypeScript: `IRouteWebSocket`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteWebSocket {
pub enabled: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub ping_interval: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ping_timeout: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_payload_size: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_headers: Option<HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub subprotocols: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub rewrite_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_origins: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub authenticate_request: Option<bool>,
}
// ─── Load Balancing ──────────────────────────────────────────────────
/// Load balancing algorithm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum LoadBalancingAlgorithm {
RoundRobin,
LeastConnections,
IpHash,
}
/// Health check configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HealthCheck {
pub path: String,
pub interval: u64,
pub timeout: u64,
pub unhealthy_threshold: u32,
pub healthy_threshold: u32,
}
/// Load balancing configuration.
/// Matches TypeScript: `IRouteLoadBalancing`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteLoadBalancing {
pub algorithm: LoadBalancingAlgorithm,
#[serde(skip_serializing_if = "Option::is_none")]
pub health_check: Option<HealthCheck>,
}
// ─── CORS ────────────────────────────────────────────────────────────
/// Allowed origin specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AllowOrigin {
Single(String),
List(Vec<String>),
}
/// CORS configuration for a route.
/// Matches TypeScript: `IRouteCors`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteCors {
pub enabled: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_origin: Option<AllowOrigin>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_methods: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_headers: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_credentials: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expose_headers: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_age: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub preflight: Option<bool>,
}
// ─── Headers ─────────────────────────────────────────────────────────
/// Headers configuration.
/// Matches TypeScript: `IRouteHeaders`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteHeaders {
/// Headers to add/modify for requests to backend
#[serde(skip_serializing_if = "Option::is_none")]
pub request: Option<HashMap<String, String>>,
/// Headers to add/modify for responses to client
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<HashMap<String, String>>,
/// CORS configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub cors: Option<RouteCors>,
}
// ─── Static Files ────────────────────────────────────────────────────
/// Static file server configuration.
/// Matches TypeScript: `IRouteStaticFiles`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteStaticFiles {
pub root: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub index: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub directory: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub index_files: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub follow_symlinks: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub disable_directory_listing: Option<bool>,
}
// ─── Test Response ───────────────────────────────────────────────────
/// Test route response configuration.
/// Matches TypeScript: `IRouteTestResponse`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTestResponse {
pub status: u16,
pub headers: HashMap<String, String>,
pub body: String,
}
// ─── URL Rewriting ───────────────────────────────────────────────────
/// URL rewriting configuration.
/// Matches TypeScript: `IRouteUrlRewrite`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteUrlRewrite {
/// RegExp pattern to match in URL
pub pattern: String,
/// Replacement pattern
pub target: String,
/// RegExp flags
#[serde(skip_serializing_if = "Option::is_none")]
pub flags: Option<String>,
/// Only apply to path, not query string
#[serde(skip_serializing_if = "Option::is_none")]
pub only_rewrite_path: Option<bool>,
}
// ─── Advanced Options ────────────────────────────────────────────────
/// Advanced options for route actions.
/// Matches TypeScript: `IRouteAdvanced`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAdvanced {
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub static_files: Option<RouteStaticFiles>,
#[serde(skip_serializing_if = "Option::is_none")]
pub test_response: Option<RouteTestResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub url_rewrite: Option<RouteUrlRewrite>,
}
// ─── NFTables Options ────────────────────────────────────────────────
/// NFTables protocol type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NfTablesProtocol {
Tcp,
Udp,
All,
}
/// NFTables-specific configuration options.
/// Matches TypeScript: `INfTablesOptions`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NfTablesOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub preserve_source_ip: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<NfTablesProtocol>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_rate: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub table_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_ip_sets: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_advanced_nat: Option<bool>,
}
// ─── Backend Protocol ────────────────────────────────────────────────
/// Backend protocol.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BackendProtocol {
Http1,
Http2,
}
/// Action options.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ActionOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub backend_protocol: Option<BackendProtocol>,
/// Catch-all for additional options
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
// ─── Route Target ────────────────────────────────────────────────────
/// Host specification: single string or array of strings.
/// Note: Dynamic host functions are only available via programmatic API, not JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HostSpec {
Single(String),
List(Vec<String>),
}
impl HostSpec {
pub fn to_vec(&self) -> Vec<&str> {
match self {
HostSpec::Single(s) => vec![s.as_str()],
HostSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
}
}
pub fn first(&self) -> &str {
match self {
HostSpec::Single(s) => s.as_str(),
HostSpec::List(v) => v.first().map(|s| s.as_str()).unwrap_or(""),
}
}
}
/// Port specification: number or "preserve".
/// Note: Dynamic port functions are only available via programmatic API, not JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortSpec {
/// Fixed port number
Fixed(u16),
/// Special string value like "preserve"
Special(String),
}
impl PortSpec {
/// Resolve the port, using incoming_port when "preserve" is specified.
pub fn resolve(&self, incoming_port: u16) -> u16 {
match self {
PortSpec::Fixed(p) => *p,
PortSpec::Special(s) if s == "preserve" => incoming_port,
PortSpec::Special(_) => incoming_port, // fallback
}
}
}
/// Target configuration for forwarding with sub-matching and overrides.
/// Matches TypeScript: `IRouteTarget`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTarget {
/// Optional sub-matching criteria within the route
#[serde(rename = "match")]
#[serde(skip_serializing_if = "Option::is_none")]
pub target_match: Option<TargetMatch>,
/// Target host(s)
pub host: HostSpec,
/// Target port
pub port: PortSpec,
/// Override route-level TLS settings
#[serde(skip_serializing_if = "Option::is_none")]
pub tls: Option<RouteTls>,
/// Override route-level WebSocket settings
#[serde(skip_serializing_if = "Option::is_none")]
pub websocket: Option<RouteWebSocket>,
/// Override route-level load balancing
#[serde(skip_serializing_if = "Option::is_none")]
pub load_balancing: Option<RouteLoadBalancing>,
/// Override route-level proxy protocol setting
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,
/// Override route-level headers
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<RouteHeaders>,
/// Override route-level advanced settings
#[serde(skip_serializing_if = "Option::is_none")]
pub advanced: Option<RouteAdvanced>,
/// Priority for matching (higher values checked first, default: 0)
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
}
// ─── Route Action ────────────────────────────────────────────────────
/// Action configuration for route handling.
/// Matches TypeScript: `IRouteAction`
///
/// Note: `socketHandler` is not serializable in JSON. Use the programmatic API
/// for socket handler routes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAction {
/// Basic routing type
#[serde(rename = "type")]
pub action_type: RouteActionType,
/// Targets for forwarding (array supports multiple targets with sub-matching)
#[serde(skip_serializing_if = "Option::is_none")]
pub targets: Option<Vec<RouteTarget>>,
/// TLS handling (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub tls: Option<RouteTls>,
/// WebSocket support (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub websocket: Option<RouteWebSocket>,
/// Load balancing options (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub load_balancing: Option<RouteLoadBalancing>,
/// Advanced options (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub advanced: Option<RouteAdvanced>,
/// Additional options
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<ActionOptions>,
/// Forwarding engine specification
#[serde(skip_serializing_if = "Option::is_none")]
pub forwarding_engine: Option<ForwardingEngine>,
/// NFTables-specific options
#[serde(skip_serializing_if = "Option::is_none")]
pub nftables: Option<NfTablesOptions>,
/// PROXY protocol support (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,
}
// ─── Route Config ────────────────────────────────────────────────────
/// The core unified configuration interface.
/// Matches TypeScript: `IRouteConfig`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteConfig {
/// Unique identifier
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// What to match
#[serde(rename = "match")]
pub route_match: RouteMatch,
/// What to do with matched traffic
pub action: RouteAction,
/// Custom headers
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<RouteHeaders>,
/// Security features
#[serde(skip_serializing_if = "Option::is_none")]
pub security: Option<RouteSecurity>,
/// Human-readable name for this route
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Description of the route's purpose
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Controls matching order (higher = matched first)
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// Arbitrary tags for categorization
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
/// Whether the route is active (default: true)
#[serde(skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
}
impl RouteConfig {
/// Check if this route is enabled (defaults to true).
pub fn is_enabled(&self) -> bool {
self.enabled.unwrap_or(true)
}
/// Get the effective priority (defaults to 0).
pub fn effective_priority(&self) -> i32 {
self.priority.unwrap_or(0)
}
/// Get all ports this route listens on.
pub fn listening_ports(&self) -> Vec<u16> {
self.route_match.ports.to_ports()
}
/// Get the TLS mode for this route (from action-level or first target).
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
// Check action-level TLS first
if let Some(tls) = &self.action.tls {
return Some(&tls.mode);
}
// Check first target's TLS
if let Some(targets) = &self.action.targets {
if let Some(first) = targets.first() {
if let Some(tls) = &first.tls {
return Some(&tls.mode);
}
}
}
None
}
}

View File

@@ -0,0 +1,132 @@
use serde::{Deserialize, Serialize};
/// Rate limiting configuration.
/// Matches TypeScript: `IRouteRateLimit`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteRateLimit {
pub enabled: bool,
pub max_requests: u64,
/// Time window in seconds
pub window: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub key_by: Option<RateLimitKeyBy>,
#[serde(skip_serializing_if = "Option::is_none")]
pub header_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error_message: Option<String>,
}
/// Rate limit key selection.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RateLimitKeyBy {
Ip,
Path,
Header,
}
/// Authentication type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AuthenticationType {
Basic,
Digest,
Oauth,
Jwt,
}
/// Authentication credentials.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthCredentials {
pub username: String,
pub password: String,
}
/// Authentication options.
/// Matches TypeScript: `IRouteAuthentication`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAuthentication {
#[serde(rename = "type")]
pub auth_type: AuthenticationType,
#[serde(skip_serializing_if = "Option::is_none")]
pub credentials: Option<Vec<AuthCredentials>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub realm: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jwt_secret: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub jwt_issuer: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_provider: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_client_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_client_secret: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_redirect_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<serde_json::Value>,
}
/// Basic auth configuration.
/// Matches TypeScript: `IRouteSecurity.basicAuth`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BasicAuthConfig {
pub enabled: bool,
pub users: Vec<AuthCredentials>,
#[serde(skip_serializing_if = "Option::is_none")]
pub realm: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exclude_paths: Option<Vec<String>>,
}
/// JWT auth configuration.
/// Matches TypeScript: `IRouteSecurity.jwtAuth`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JwtAuthConfig {
pub enabled: bool,
pub secret: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub algorithm: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub issuer: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub audience: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_in: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub exclude_paths: Option<Vec<String>>,
}
/// Security options for routes.
/// Matches TypeScript: `IRouteSecurity`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteSecurity {
/// IP addresses that are allowed to connect
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_allow_list: Option<Vec<String>>,
/// IP addresses that are blocked from connecting
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_block_list: Option<Vec<String>>,
/// Maximum concurrent connections
#[serde(skip_serializing_if = "Option::is_none")]
pub max_connections: Option<u64>,
/// Authentication configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub authentication: Option<RouteAuthentication>,
/// Rate limiting
#[serde(skip_serializing_if = "Option::is_none")]
pub rate_limit: Option<RouteRateLimit>,
/// Basic auth
#[serde(skip_serializing_if = "Option::is_none")]
pub basic_auth: Option<BasicAuthConfig>,
/// JWT auth
#[serde(skip_serializing_if = "Option::is_none")]
pub jwt_auth: Option<JwtAuthConfig>,
}

View File

@@ -0,0 +1,93 @@
use serde::{Deserialize, Serialize};
/// TLS handling modes for route configurations.
/// Matches TypeScript: `type TTlsMode = 'passthrough' | 'terminate' | 'terminate-and-reencrypt'`
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
Passthrough,
Terminate,
TerminateAndReencrypt,
}
/// Static certificate configuration (PEM-encoded).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CertificateConfig {
/// PEM-encoded private key
pub key: String,
/// PEM-encoded certificate
pub cert: String,
/// PEM-encoded CA chain
#[serde(skip_serializing_if = "Option::is_none")]
pub ca: Option<String>,
/// Path to key file (overrides key)
#[serde(skip_serializing_if = "Option::is_none")]
pub key_file: Option<String>,
/// Path to cert file (overrides cert)
#[serde(skip_serializing_if = "Option::is_none")]
pub cert_file: Option<String>,
}
/// Certificate specification: either automatic (ACME) or static.
/// Matches TypeScript: `certificate?: 'auto' | { key, cert, ca?, keyFile?, certFile? }`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CertificateSpec {
/// Use ACME (Let's Encrypt) for automatic provisioning
Auto(String), // "auto"
/// Static certificate configuration
Static(CertificateConfig),
}
impl CertificateSpec {
/// Check if this is an auto (ACME) certificate
pub fn is_auto(&self) -> bool {
matches!(self, CertificateSpec::Auto(s) if s == "auto")
}
}
/// ACME configuration for automatic certificate provisioning.
/// Matches TypeScript: `IRouteAcme`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteAcme {
/// Contact email for ACME account
pub email: String,
/// Use production ACME servers (default: false)
#[serde(skip_serializing_if = "Option::is_none")]
pub use_production: Option<bool>,
/// Port for HTTP-01 challenges (default: 80)
#[serde(skip_serializing_if = "Option::is_none")]
pub challenge_port: Option<u16>,
/// Days before expiry to renew (default: 30)
#[serde(skip_serializing_if = "Option::is_none")]
pub renew_before_days: Option<u32>,
}
/// TLS configuration for route actions.
/// Matches TypeScript: `IRouteTls`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteTls {
/// TLS mode (passthrough, terminate, terminate-and-reencrypt)
pub mode: TlsMode,
/// Certificate configuration (auto or static)
#[serde(skip_serializing_if = "Option::is_none")]
pub certificate: Option<CertificateSpec>,
/// ACME options when certificate is 'auto'
#[serde(skip_serializing_if = "Option::is_none")]
pub acme: Option<RouteAcme>,
/// Allowed TLS versions
#[serde(skip_serializing_if = "Option::is_none")]
pub versions: Option<Vec<String>>,
/// OpenSSL cipher string
#[serde(skip_serializing_if = "Option::is_none")]
pub ciphers: Option<String>,
/// Use server's cipher preferences
#[serde(skip_serializing_if = "Option::is_none")]
pub honor_cipher_order: Option<bool>,
/// TLS session timeout in seconds
#[serde(skip_serializing_if = "Option::is_none")]
pub session_timeout: Option<u64>,
}

View File

@@ -0,0 +1,158 @@
use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType};
/// Validation errors for route configurations.
#[derive(Debug, Error)]
pub enum ValidationError {
#[error("Route '{name}' has no targets but action type is 'forward'")]
MissingTargets { name: String },
#[error("Route '{name}' has empty targets list")]
EmptyTargets { name: String },
#[error("Route '{name}' has no ports specified")]
NoPorts { name: String },
#[error("Route '{name}' port {port} is invalid (must be 1-65535)")]
InvalidPort { name: String, port: u16 },
#[error("Route '{name}': socket-handler action type is not supported in JSON config")]
SocketHandlerInJson { name: String },
#[error("Route '{name}': duplicate route ID '{id}'")]
DuplicateId { name: String, id: String },
#[error("Route '{name}': {message}")]
Custom { name: String, message: String },
}
/// Validate a single route configuration.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new();
let name = route.name.clone().unwrap_or_else(|| {
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
});
// Check ports
let ports = route.listening_ports();
if ports.is_empty() {
errors.push(ValidationError::NoPorts { name: name.clone() });
}
for &port in &ports {
if port == 0 {
errors.push(ValidationError::InvalidPort {
name: name.clone(),
port,
});
}
}
// Check forward action has targets
if route.action.action_type == RouteActionType::Forward {
match &route.action.targets {
None => {
errors.push(ValidationError::MissingTargets { name: name.clone() });
}
Some(targets) if targets.is_empty() => {
errors.push(ValidationError::EmptyTargets { name: name.clone() });
}
_ => {}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
/// Validate an entire list of routes.
pub fn validate_routes(routes: &[RouteConfig]) -> Result<(), Vec<ValidationError>> {
let mut all_errors = Vec::new();
let mut seen_ids = std::collections::HashSet::new();
for route in routes {
// Check for duplicate IDs
if let Some(id) = &route.id {
if !seen_ids.insert(id.clone()) {
let name = route.name.clone().unwrap_or_else(|| id.clone());
all_errors.push(ValidationError::DuplicateId {
name,
id: id.clone(),
});
}
}
// Validate individual route
if let Err(errors) = validate_route(route) {
all_errors.extend(errors);
}
}
if all_errors.is_empty() {
Ok(())
} else {
Err(all_errors)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::route_types::*;
fn make_valid_route() -> RouteConfig {
crate::helpers::create_http_route("example.com", "localhost", 8080)
}
#[test]
fn test_valid_route_passes() {
let route = make_valid_route();
assert!(validate_route(&route).is_ok());
}
#[test]
fn test_missing_targets() {
let mut route = make_valid_route();
route.action.targets = None;
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
}
#[test]
fn test_empty_targets() {
let mut route = make_valid_route();
route.action.targets = Some(vec![]);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
}
#[test]
fn test_invalid_port_zero() {
let mut route = make_valid_route();
route.route_match.ports = PortRange::Single(0);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
}
#[test]
fn test_duplicate_ids() {
let mut r1 = make_valid_route();
r1.id = Some("route-1".to_string());
let mut r2 = make_valid_route();
r2.id = Some("route-1".to_string());
let errors = validate_routes(&[r1, r2]).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
}
#[test]
fn test_multiple_errors_collected() {
let mut r1 = make_valid_route();
r1.action.targets = None; // MissingTargets
r1.route_match.ports = PortRange::Single(0); // InvalidPort
let errors = validate_route(&r1).unwrap_err();
assert!(errors.len() >= 2);
}
}

View File

@@ -0,0 +1,28 @@
[package]
name = "rustproxy-http"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Hyper-based HTTP proxy service for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-security = { workspace = true }
rustproxy-metrics = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
regex = { workspace = true }
http-body = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
tokio = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
dashmap = { workspace = true }
tokio-util = { workspace = true }

View File

@@ -0,0 +1,126 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use bytes::Bytes;
use http_body::Frame;
use rustproxy_metrics::MetricsCollector;
/// Wraps any `http_body::Body` and counts data bytes passing through.
///
/// When the body is fully consumed or dropped, accumulated byte counts
/// are reported to the `MetricsCollector`.
///
/// The inner body is pinned on the heap to support `!Unpin` types like `hyper::body::Incoming`.
pub struct CountingBody<B> {
inner: Pin<Box<B>>,
counted_bytes: AtomicU64,
metrics: Arc<MetricsCollector>,
route_id: Option<String>,
source_ip: Option<String>,
/// Whether we count bytes as "in" (request body) or "out" (response body).
direction: Direction,
/// Whether we've already reported the bytes (to avoid double-reporting on drop).
reported: bool,
}
/// Which direction the bytes flow.
#[derive(Clone, Copy)]
pub enum Direction {
/// Request body: bytes flowing from client → upstream (counted as bytes_in)
In,
/// Response body: bytes flowing from upstream → client (counted as bytes_out)
Out,
}
impl<B> CountingBody<B> {
/// Create a new CountingBody wrapping an inner body.
pub fn new(
inner: B,
metrics: Arc<MetricsCollector>,
route_id: Option<String>,
source_ip: Option<String>,
direction: Direction,
) -> Self {
Self {
inner: Box::pin(inner),
counted_bytes: AtomicU64::new(0),
metrics,
route_id,
source_ip,
direction,
reported: false,
}
}
/// Report accumulated bytes to the metrics collector.
fn report(&mut self) {
if self.reported {
return;
}
self.reported = true;
let bytes = self.counted_bytes.load(Ordering::Relaxed);
if bytes == 0 {
return;
}
let route_id = self.route_id.as_deref();
let source_ip = self.source_ip.as_deref();
match self.direction {
Direction::In => self.metrics.record_bytes(bytes, 0, route_id, source_ip),
Direction::Out => self.metrics.record_bytes(0, bytes, route_id, source_ip),
}
}
}
impl<B> Drop for CountingBody<B> {
fn drop(&mut self) {
self.report();
}
}
// CountingBody is Unpin because inner is Pin<Box<B>> (always Unpin).
impl<B> Unpin for CountingBody<B> {}
impl<B> http_body::Body for CountingBody<B>
where
B: http_body::Body<Data = Bytes>,
{
type Data = Bytes;
type Error = B::Error;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.get_mut();
match this.inner.as_mut().poll_frame(cx) {
Poll::Ready(Some(Ok(frame))) => {
if let Some(data) = frame.data_ref() {
this.counted_bytes.fetch_add(data.len() as u64, Ordering::Relaxed);
}
Poll::Ready(Some(Ok(frame)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
// Body is fully consumed — report now
this.report();
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}
}

View File

@@ -0,0 +1,16 @@
//! # rustproxy-http
//!
//! Hyper-based HTTP proxy service for RustProxy.
//! Handles HTTP request parsing, route-based forwarding, and response filtering.
pub mod counting_body;
pub mod proxy_service;
pub mod request_filter;
pub mod response_filter;
pub mod template;
pub mod upstream_selector;
pub use counting_body::*;
pub use proxy_service::*;
pub use template::*;
pub use upstream_selector::*;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,263 @@
//! Request filtering: security checks, auth, CORS preflight.
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
pub struct RequestFilter;
impl RequestFilter {
/// Apply security filters. Returns Some(response) if the request should be blocked.
pub fn apply(
security: &RouteSecurity,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
Self::apply_with_rate_limiter(security, req, peer_addr, None)
}
/// Apply security filters with an optional shared rate limiter.
/// Returns Some(response) if the request should be blocked.
pub fn apply_with_rate_limiter(
security: &RouteSecurity,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
rate_limiter: Option<&Arc<RateLimiter>>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
let client_ip = peer_addr.ip();
let request_path = req.uri().path();
// IP filter
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip);
if !filter.is_allowed(&normalized) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
}
}
// Rate limiting
if let Some(ref rate_limit_config) = security.rate_limit {
if rate_limit_config.enabled {
// Use shared rate limiter if provided, otherwise create ephemeral one
let should_block = if let Some(limiter) = rate_limiter {
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
} else {
// Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new(
rate_limit_config.max_requests,
rate_limit_config.window,
);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
};
if should_block {
let message = rate_limit_config.error_message
.as_deref()
.unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
}
}
}
// Check exclude paths before auth
let should_skip_auth = Self::path_matches_exclude_list(request_path, security);
if !should_skip_auth {
// Basic auth
if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled {
// Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter()
.map(|c| (c.username.clone(), c.password.clone()))
.collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header {
Some(header) => {
if validator.validate(header).is_none() {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Invalid credentials"))
.unwrap());
}
}
None => {
return Some(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap());
}
}
}
}
}
// JWT auth
if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled {
// Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_jwt {
let validator = JwtValidator::new(
&jwt_auth.secret,
jwt_auth.algorithm.as_deref(),
jwt_auth.issuer.as_deref(),
jwt_auth.audience.as_deref(),
);
let auth_header = req.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => {
if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
}
}
None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
}
}
}
}
}
}
None
}
/// Check if a request path matches any pattern in the exclude list.
fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
// No global exclude paths on RouteSecurity currently,
// but we check per-auth exclude paths above.
// This can be extended if a global exclude_paths is added.
false
}
/// Check if a path matches any pattern in the list.
/// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
fn path_matches_any(path: &str, patterns: &[String]) -> bool {
for pattern in patterns {
if pattern.ends_with('*') {
let prefix = &pattern[..pattern.len() - 1];
if path.starts_with(prefix) {
return true;
}
} else if path == pattern {
return true;
}
}
false
}
/// Determine the rate limit key based on configuration.
fn rate_limit_key(
config: &rustproxy_config::RouteRateLimit,
req: &Request<Incoming>,
peer_addr: &SocketAddr,
) -> String {
use rustproxy_config::RateLimitKeyBy;
match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
RateLimitKeyBy::Path => req.uri().path().to_string(),
RateLimitKeyBy::Header => {
if let Some(ref header_name) = config.header_name {
req.headers()
.get(header_name.as_str())
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string()
} else {
peer_addr.ip().to_string()
}
}
}
}
/// Check IP-based security (for use in passthrough / TCP-level connections).
/// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(client_ip);
filter.is_allowed(&normalized)
} else {
true
}
}
/// Handle CORS preflight (OPTIONS) requests.
/// Returns Some(response) if this is a CORS preflight that should be handled.
pub fn handle_cors_preflight(
req: &Request<Incoming>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
if req.method() != hyper::Method::OPTIONS {
return None;
}
// Check for CORS preflight indicators
let has_origin = req.headers().contains_key("origin");
let has_request_method = req.headers().contains_key("access-control-request-method");
if !has_origin || !has_request_method {
return None;
}
let origin = req.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");
Some(Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap())
}
}
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status)
.header("Content-Type", "text/plain")
.body(boxed_body(message))
.unwrap()
}
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}

View File

@@ -0,0 +1,92 @@
//! Response filtering: CORS headers, custom headers, security headers.
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
pub struct ResponseFilter;
impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
// Apply custom response headers from route config
if let Some(ref route_headers) = route.headers {
if let Some(ref response_headers) = route_headers.response {
for (key, value) in response_headers {
if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) {
let expanded = match req_ctx {
Some(ctx) => expand_template(value, ctx),
None => value.clone(),
};
if let Ok(val) = HeaderValue::from_str(&expanded) {
headers.insert(name, val);
}
}
}
}
// Apply CORS headers if configured
if let Some(ref cors) = route_headers.cors {
if cors.enabled {
Self::apply_cors_headers(cors, headers);
}
}
}
}
fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
// Allow-Origin
if let Some(ref origin) = cors.allow_origin {
let origin_str = match origin {
rustproxy_config::AllowOrigin::Single(s) => s.clone(),
rustproxy_config::AllowOrigin::List(list) => list.join(", "),
};
if let Ok(val) = HeaderValue::from_str(&origin_str) {
headers.insert("access-control-allow-origin", val);
}
} else {
headers.insert(
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
}
// Allow-Methods
if let Some(ref methods) = cors.allow_methods {
if let Ok(val) = HeaderValue::from_str(methods) {
headers.insert("access-control-allow-methods", val);
}
}
// Allow-Headers
if let Some(ref allow_headers) = cors.allow_headers {
if let Ok(val) = HeaderValue::from_str(allow_headers) {
headers.insert("access-control-allow-headers", val);
}
}
// Allow-Credentials
if cors.allow_credentials == Some(true) {
headers.insert(
"access-control-allow-credentials",
HeaderValue::from_static("true"),
);
}
// Expose-Headers
if let Some(ref expose) = cors.expose_headers {
if let Ok(val) = HeaderValue::from_str(expose) {
headers.insert("access-control-expose-headers", val);
}
}
// Max-Age
if let Some(max_age) = cors.max_age {
if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
headers.insert("access-control-max-age", val);
}
}
}
}

View File

@@ -0,0 +1,162 @@
//! Header template variable expansion.
//!
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
//! in header values before they are applied to requests or responses.
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Context for template variable expansion.
pub struct RequestContext {
pub client_ip: String,
pub domain: String,
pub port: u16,
pub path: String,
pub route_name: String,
pub connection_id: u64,
}
/// Expand template variables in a header value.
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
template
.replace("{clientIp}", &ctx.client_ip)
.replace("{domain}", &ctx.domain)
.replace("{port}", &ctx.port.to_string())
.replace("{path}", &ctx.path)
.replace("{routeName}", &ctx.route_name)
.replace("{connectionId}", &ctx.connection_id.to_string())
.replace("{timestamp}", &timestamp.to_string())
}
/// Expand templates in a map of header key-value pairs.
pub fn expand_headers(
headers: &HashMap<String, String>,
ctx: &RequestContext,
) -> HashMap<String, String> {
headers.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn test_context() -> RequestContext {
RequestContext {
client_ip: "192.168.1.100".to_string(),
domain: "example.com".to_string(),
port: 443,
path: "/api/v1/users".to_string(),
route_name: "api-route".to_string(),
connection_id: 42,
}
}
#[test]
fn test_expand_client_ip() {
let ctx = test_context();
assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
}
#[test]
fn test_expand_domain() {
let ctx = test_context();
assert_eq!(expand_template("{domain}", &ctx), "example.com");
}
#[test]
fn test_expand_port() {
let ctx = test_context();
assert_eq!(expand_template("{port}", &ctx), "443");
}
#[test]
fn test_expand_path() {
let ctx = test_context();
assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
}
#[test]
fn test_expand_route_name() {
let ctx = test_context();
assert_eq!(expand_template("{routeName}", &ctx), "api-route");
}
#[test]
fn test_expand_connection_id() {
let ctx = test_context();
assert_eq!(expand_template("{connectionId}", &ctx), "42");
}
#[test]
fn test_expand_timestamp() {
let ctx = test_context();
let result = expand_template("{timestamp}", &ctx);
// Timestamp should be a valid number
let ts: u64 = result.parse().expect("timestamp should be a number");
// Should be a reasonable Unix timestamp (after 2020)
assert!(ts > 1_577_836_800);
}
#[test]
fn test_expand_mixed_template() {
let ctx = test_context();
let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
assert_eq!(result, "client=192.168.1.100, host=example.com:443");
}
#[test]
fn test_expand_no_variables() {
let ctx = test_context();
assert_eq!(expand_template("plain-value", &ctx), "plain-value");
}
#[test]
fn test_expand_empty_string() {
let ctx = test_context();
assert_eq!(expand_template("", &ctx), "");
}
#[test]
fn test_expand_multiple_same_variable() {
let ctx = test_context();
let result = expand_template("{clientIp}-{clientIp}", &ctx);
assert_eq!(result, "192.168.1.100-192.168.1.100");
}
#[test]
fn test_expand_headers_map() {
let ctx = test_context();
let mut headers = HashMap::new();
headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
headers.insert("X-Route".to_string(), "{routeName}".to_string());
headers.insert("X-Static".to_string(), "no-template".to_string());
let result = expand_headers(&headers, &ctx);
assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
assert_eq!(result.get("X-Route").unwrap(), "api-route");
assert_eq!(result.get("X-Static").unwrap(), "no-template");
}
#[test]
fn test_expand_all_variables_in_one() {
let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
}
#[test]
fn test_expand_unknown_variable_left_as_is() {
let ctx = test_context();
let result = expand_template("{unknownVar}", &ctx);
assert_eq!(result, "{unknownVar}");
}
}

View File

@@ -0,0 +1,222 @@
//! Route-aware upstream selection with load balancing.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
/// Upstream selection result.
pub struct UpstreamSelection {
pub host: String,
pub port: u16,
pub use_tls: bool,
}
/// Selects upstream backends with load balancing support.
pub struct UpstreamSelector {
/// Round-robin counters per route (keyed by first target host:port)
round_robin: Mutex<HashMap<String, AtomicUsize>>,
/// Active connection counts per host (keyed by "host:port")
active_connections: Arc<DashMap<String, AtomicU64>>,
}
impl UpstreamSelector {
pub fn new() -> Self {
Self {
round_robin: Mutex::new(HashMap::new()),
active_connections: Arc::new(DashMap::new()),
}
}
/// Select an upstream target based on the route target config and load balancing.
pub fn select(
&self,
target: &RouteTarget,
client_addr: &SocketAddr,
incoming_port: u16,
) -> UpstreamSelection {
let hosts = target.host.to_vec();
let port = target.port.resolve(incoming_port);
if hosts.len() <= 1 {
return UpstreamSelection {
host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
port,
use_tls: target.tls.is_some(),
};
}
// Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref()
.map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr);
hash % hosts.len()
}
LoadBalancingAlgorithm::LeastConnections => {
self.least_connections_select(&hosts, port)
}
};
UpstreamSelection {
host: hosts[idx].to_string(),
port,
use_tls: target.tls.is_some(),
}
}
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap();
let counter = counters
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len()
}
fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
let mut min_conns = u64::MAX;
let mut min_idx = 0;
for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port);
let conns = self.active_connections
.get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0);
if conns < min_conns {
min_conns = conns;
min_idx = i;
}
}
min_idx
}
/// Record that a connection to the given host has started.
pub fn connection_started(&self, host: &str) {
self.active_connections
.entry(host.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
}
/// Record that a connection to the given host has ended.
pub fn connection_ended(&self, host: &str) {
if let Some(counter) = self.active_connections.get(host) {
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
// Guard against underflow (shouldn't happen, but be safe)
if prev == 0 {
counter.value().store(0, Ordering::Relaxed);
}
}
}
fn ip_hash(addr: &SocketAddr) -> usize {
let ip_str = addr.ip().to_string();
let mut hash: usize = 5381;
for byte in ip_str.bytes() {
hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
}
hash
}
}
impl Default for UpstreamSelector {
fn default() -> Self {
Self::new()
}
}
impl Clone for UpstreamSelector {
fn clone(&self) -> Self {
Self {
round_robin: Mutex::new(HashMap::new()),
active_connections: Arc::clone(&self.active_connections),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use rustproxy_config::*;
fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
RouteTarget {
target_match: None,
host: if hosts.len() == 1 {
HostSpec::Single(hosts[0].to_string())
} else {
HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
},
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}
}
#[test]
fn test_single_host() {
let selector = UpstreamSelector::new();
let target = make_target(vec!["backend"], 8080);
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
let result = selector.select(&target, &addr, 80);
assert_eq!(result.host, "backend");
assert_eq!(result.port, 8080);
}
#[test]
fn test_round_robin() {
let selector = UpstreamSelector::new();
let mut target = make_target(vec!["a", "b", "c"], 8080);
target.load_balancing = Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::RoundRobin,
health_check: None,
});
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
let r1 = selector.select(&target, &addr, 80);
let r2 = selector.select(&target, &addr, 80);
let r3 = selector.select(&target, &addr, 80);
let r4 = selector.select(&target, &addr, 80);
// Should cycle through a, b, c, a
assert_eq!(r1.host, "a");
assert_eq!(r2.host, "b");
assert_eq!(r3.host, "c");
assert_eq!(r4.host, "a");
}
#[test]
fn test_ip_hash_consistent() {
let selector = UpstreamSelector::new();
let mut target = make_target(vec!["a", "b", "c"], 8080);
target.load_balancing = Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::IpHash,
health_check: None,
});
let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();
let r1 = selector.select(&target, &addr, 80);
let r2 = selector.select(&target, &addr, 80);
// Same IP should always get same backend
assert_eq!(r1.host, r2.host);
}
}

View File

@@ -0,0 +1,15 @@
[package]
name = "rustproxy-metrics"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Metrics and throughput tracking for RustProxy"
[dependencies]
dashmap = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }

View File

@@ -0,0 +1,668 @@
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use crate::throughput::{ThroughputSample, ThroughputTracker};
/// Aggregated metrics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Metrics {
pub active_connections: u64,
pub total_connections: u64,
pub bytes_in: u64,
pub bytes_out: u64,
pub throughput_in_bytes_per_sec: u64,
pub throughput_out_bytes_per_sec: u64,
pub throughput_recent_in_bytes_per_sec: u64,
pub throughput_recent_out_bytes_per_sec: u64,
pub routes: std::collections::HashMap<String, RouteMetrics>,
pub ips: std::collections::HashMap<String, IpMetrics>,
pub throughput_history: Vec<ThroughputSample>,
pub total_http_requests: u64,
pub http_requests_per_sec: u64,
pub http_requests_per_sec_recent: u64,
}
/// Per-route metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteMetrics {
pub active_connections: u64,
pub total_connections: u64,
pub bytes_in: u64,
pub bytes_out: u64,
pub throughput_in_bytes_per_sec: u64,
pub throughput_out_bytes_per_sec: u64,
pub throughput_recent_in_bytes_per_sec: u64,
pub throughput_recent_out_bytes_per_sec: u64,
}
/// Per-IP metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IpMetrics {
pub active_connections: u64,
pub total_connections: u64,
pub bytes_in: u64,
pub bytes_out: u64,
pub throughput_in_bytes_per_sec: u64,
pub throughput_out_bytes_per_sec: u64,
}
/// Statistics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Statistics {
pub active_connections: u64,
pub total_connections: u64,
pub routes_count: u64,
pub listening_ports: Vec<u16>,
pub uptime_seconds: u64,
}
/// Default retention for throughput samples (1 hour).
const DEFAULT_RETENTION_SECONDS: usize = 3600;
/// Maximum number of IPs to include in a snapshot (top by active connections).
const MAX_IPS_IN_SNAPSHOT: usize = 100;
/// Metrics collector tracking connections and throughput.
///
/// Design: The hot path (`record_bytes`) is entirely lock-free — it only touches
/// `AtomicU64` counters. The cold path (`sample_all`, called at 1Hz) drains
/// those atomics and feeds the throughput trackers under a Mutex. This avoids
/// contention when `record_bytes` is called per-chunk in the TCP copy loop.
pub struct MetricsCollector {
active_connections: AtomicU64,
total_connections: AtomicU64,
total_bytes_in: AtomicU64,
total_bytes_out: AtomicU64,
/// Per-route active connection counts
route_connections: DashMap<String, AtomicU64>,
/// Per-route total connection counts
route_total_connections: DashMap<String, AtomicU64>,
/// Per-route byte counters
route_bytes_in: DashMap<String, AtomicU64>,
route_bytes_out: DashMap<String, AtomicU64>,
// ── Per-IP tracking ──
ip_connections: DashMap<String, AtomicU64>,
ip_total_connections: DashMap<String, AtomicU64>,
ip_bytes_in: DashMap<String, AtomicU64>,
ip_bytes_out: DashMap<String, AtomicU64>,
ip_pending_tp: DashMap<String, (AtomicU64, AtomicU64)>,
ip_throughput: DashMap<String, Mutex<ThroughputTracker>>,
// ── HTTP request tracking ──
total_http_requests: AtomicU64,
pending_http_requests: AtomicU64,
http_request_throughput: Mutex<ThroughputTracker>,
// ── Lock-free pending throughput counters (hot path) ──
global_pending_tp_in: AtomicU64,
global_pending_tp_out: AtomicU64,
route_pending_tp: DashMap<String, (AtomicU64, AtomicU64)>,
// ── Throughput history — only locked during sampling (cold path) ──
global_throughput: Mutex<ThroughputTracker>,
route_throughput: DashMap<String, Mutex<ThroughputTracker>>,
retention_seconds: usize,
}
impl MetricsCollector {
pub fn new() -> Self {
Self::with_retention(DEFAULT_RETENTION_SECONDS)
}
/// Create a MetricsCollector with a custom retention period for throughput history.
pub fn with_retention(retention_seconds: usize) -> Self {
Self {
active_connections: AtomicU64::new(0),
total_connections: AtomicU64::new(0),
total_bytes_in: AtomicU64::new(0),
total_bytes_out: AtomicU64::new(0),
route_connections: DashMap::new(),
route_total_connections: DashMap::new(),
route_bytes_in: DashMap::new(),
route_bytes_out: DashMap::new(),
ip_connections: DashMap::new(),
ip_total_connections: DashMap::new(),
ip_bytes_in: DashMap::new(),
ip_bytes_out: DashMap::new(),
ip_pending_tp: DashMap::new(),
ip_throughput: DashMap::new(),
total_http_requests: AtomicU64::new(0),
pending_http_requests: AtomicU64::new(0),
http_request_throughput: Mutex::new(ThroughputTracker::new(retention_seconds)),
global_pending_tp_in: AtomicU64::new(0),
global_pending_tp_out: AtomicU64::new(0),
route_pending_tp: DashMap::new(),
global_throughput: Mutex::new(ThroughputTracker::new(retention_seconds)),
route_throughput: DashMap::new(),
retention_seconds,
}
}
/// Record a new connection.
pub fn connection_opened(&self, route_id: Option<&str>, source_ip: Option<&str>) {
self.active_connections.fetch_add(1, Ordering::Relaxed);
self.total_connections.fetch_add(1, Ordering::Relaxed);
if let Some(route_id) = route_id {
self.route_connections
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
self.route_total_connections
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
}
if let Some(ip) = source_ip {
self.ip_connections
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
self.ip_total_connections
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
}
}
/// Record a connection closing.
pub fn connection_closed(&self, route_id: Option<&str>, source_ip: Option<&str>) {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
if let Some(route_id) = route_id {
if let Some(counter) = self.route_connections.get(route_id) {
let val = counter.load(Ordering::Relaxed);
if val > 0 {
counter.fetch_sub(1, Ordering::Relaxed);
}
}
}
if let Some(ip) = source_ip {
if let Some(counter) = self.ip_connections.get(ip) {
let val = counter.load(Ordering::Relaxed);
if val > 0 {
counter.fetch_sub(1, Ordering::Relaxed);
}
// Clean up zero-count entries to prevent memory growth
if val <= 1 {
drop(counter);
self.ip_connections.remove(ip);
}
}
}
}
/// Record bytes transferred (lock-free hot path).
///
/// Called per-chunk in the TCP copy loop. Only touches AtomicU64 counters —
/// no Mutex is taken. The throughput trackers are fed during `sample_all()`.
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>, source_ip: Option<&str>) {
self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
// Accumulate into lock-free pending throughput counters
self.global_pending_tp_in.fetch_add(bytes_in, Ordering::Relaxed);
self.global_pending_tp_out.fetch_add(bytes_out, Ordering::Relaxed);
if let Some(route_id) = route_id {
self.route_bytes_in
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_in, Ordering::Relaxed);
self.route_bytes_out
.entry(route_id.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_out, Ordering::Relaxed);
// Accumulate into per-route pending throughput counters (lock-free)
let entry = self.route_pending_tp
.entry(route_id.to_string())
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
if let Some(ip) = source_ip {
self.ip_bytes_in
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_in, Ordering::Relaxed);
self.ip_bytes_out
.entry(ip.to_string())
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(bytes_out, Ordering::Relaxed);
// Accumulate into per-IP pending throughput counters (lock-free)
let entry = self.ip_pending_tp
.entry(ip.to_string())
.or_insert_with(|| (AtomicU64::new(0), AtomicU64::new(0)));
entry.0.fetch_add(bytes_in, Ordering::Relaxed);
entry.1.fetch_add(bytes_out, Ordering::Relaxed);
}
}
/// Record an HTTP request (called once per request in the HTTP proxy).
pub fn record_http_request(&self) {
self.total_http_requests.fetch_add(1, Ordering::Relaxed);
self.pending_http_requests.fetch_add(1, Ordering::Relaxed);
}
/// Take a throughput sample on all trackers (cold path, call at 1Hz or configured interval).
///
/// Drains the lock-free pending counters and feeds the accumulated bytes
/// into the throughput trackers (under Mutex). This is the only place
/// the Mutex is locked.
pub fn sample_all(&self) {
// Drain global pending bytes and feed into the tracker
let global_in = self.global_pending_tp_in.swap(0, Ordering::Relaxed);
let global_out = self.global_pending_tp_out.swap(0, Ordering::Relaxed);
if let Ok(mut tracker) = self.global_throughput.lock() {
tracker.record_bytes(global_in, global_out);
tracker.sample();
}
// Drain per-route pending bytes; collect into a Vec to avoid holding DashMap shards
let mut route_samples: Vec<(String, u64, u64)> = Vec::new();
for entry in self.route_pending_tp.iter() {
let route_id = entry.key().clone();
let pending_in = entry.value().0.swap(0, Ordering::Relaxed);
let pending_out = entry.value().1.swap(0, Ordering::Relaxed);
route_samples.push((route_id, pending_in, pending_out));
}
// Feed pending bytes into route trackers and sample
let retention = self.retention_seconds;
for (route_id, pending_in, pending_out) in &route_samples {
// Ensure the tracker exists
self.route_throughput
.entry(route_id.clone())
.or_insert_with(|| Mutex::new(ThroughputTracker::new(retention)));
// Now get a separate ref and lock it
if let Some(tracker_ref) = self.route_throughput.get(route_id) {
if let Ok(mut tracker) = tracker_ref.value().lock() {
tracker.record_bytes(*pending_in, *pending_out);
tracker.sample();
}
}
}
// Also sample any route trackers that had no new pending bytes
// (to keep their sample window advancing)
for entry in self.route_throughput.iter() {
if !self.route_pending_tp.contains_key(entry.key()) {
if let Ok(mut tracker) = entry.value().lock() {
tracker.sample();
}
}
}
// Drain per-IP pending bytes and feed into IP throughput trackers
let mut ip_samples: Vec<(String, u64, u64)> = Vec::new();
for entry in self.ip_pending_tp.iter() {
let ip = entry.key().clone();
let pending_in = entry.value().0.swap(0, Ordering::Relaxed);
let pending_out = entry.value().1.swap(0, Ordering::Relaxed);
ip_samples.push((ip, pending_in, pending_out));
}
for (ip, pending_in, pending_out) in &ip_samples {
self.ip_throughput
.entry(ip.clone())
.or_insert_with(|| Mutex::new(ThroughputTracker::new(retention)));
if let Some(tracker_ref) = self.ip_throughput.get(ip) {
if let Ok(mut tracker) = tracker_ref.value().lock() {
tracker.record_bytes(*pending_in, *pending_out);
tracker.sample();
}
}
}
// Sample idle IP trackers
for entry in self.ip_throughput.iter() {
if !self.ip_pending_tp.contains_key(entry.key()) {
if let Ok(mut tracker) = entry.value().lock() {
tracker.sample();
}
}
}
// Drain pending HTTP request count and feed into HTTP throughput tracker
let pending_reqs = self.pending_http_requests.swap(0, Ordering::Relaxed);
if let Ok(mut tracker) = self.http_request_throughput.lock() {
// Use bytes_in field to track request count (each request = 1 "byte")
tracker.record_bytes(pending_reqs, 0);
tracker.sample();
}
}
/// Get current active connection count.
pub fn active_connections(&self) -> u64 {
self.active_connections.load(Ordering::Relaxed)
}
/// Get total connection count.
pub fn total_connections(&self) -> u64 {
self.total_connections.load(Ordering::Relaxed)
}
/// Get total bytes received.
pub fn total_bytes_in(&self) -> u64 {
self.total_bytes_in.load(Ordering::Relaxed)
}
/// Get total bytes sent.
pub fn total_bytes_out(&self) -> u64 {
self.total_bytes_out.load(Ordering::Relaxed)
}
/// Get a full metrics snapshot including per-route and per-IP data.
pub fn snapshot(&self) -> Metrics {
let mut routes = std::collections::HashMap::new();
// Get global throughput (instant = last 1 sample, recent = last 10 samples)
let (global_tp_in, global_tp_out, global_recent_in, global_recent_out, throughput_history) =
self.global_throughput
.lock()
.map(|t| {
let (i_in, i_out) = t.instant();
let (r_in, r_out) = t.recent();
let history = t.history(60);
(i_in, i_out, r_in, r_out, history)
})
.unwrap_or((0, 0, 0, 0, Vec::new()));
// Collect per-route metrics
for entry in self.route_total_connections.iter() {
let route_id = entry.key().clone();
let total = entry.value().load(Ordering::Relaxed);
let active = self.route_connections
.get(&route_id)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let bytes_in = self.route_bytes_in
.get(&route_id)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let bytes_out = self.route_bytes_out
.get(&route_id)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let (route_tp_in, route_tp_out, route_recent_in, route_recent_out) = self.route_throughput
.get(&route_id)
.and_then(|entry| entry.value().lock().ok().map(|t| {
let (i_in, i_out) = t.instant();
let (r_in, r_out) = t.recent();
(i_in, i_out, r_in, r_out)
}))
.unwrap_or((0, 0, 0, 0));
routes.insert(route_id, RouteMetrics {
active_connections: active,
total_connections: total,
bytes_in,
bytes_out,
throughput_in_bytes_per_sec: route_tp_in,
throughput_out_bytes_per_sec: route_tp_out,
throughput_recent_in_bytes_per_sec: route_recent_in,
throughput_recent_out_bytes_per_sec: route_recent_out,
});
}
// Collect per-IP metrics — only IPs with active connections or total > 0,
// capped at top MAX_IPS_IN_SNAPSHOT sorted by active count
let mut ip_entries: Vec<(String, u64, u64, u64, u64, u64, u64)> = Vec::new();
for entry in self.ip_total_connections.iter() {
let ip = entry.key().clone();
let total = entry.value().load(Ordering::Relaxed);
let active = self.ip_connections
.get(&ip)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let bytes_in = self.ip_bytes_in
.get(&ip)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let bytes_out = self.ip_bytes_out
.get(&ip)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0);
let (tp_in, tp_out) = self.ip_throughput
.get(&ip)
.and_then(|entry| entry.value().lock().ok().map(|t| t.instant()))
.unwrap_or((0, 0));
ip_entries.push((ip, active, total, bytes_in, bytes_out, tp_in, tp_out));
}
// Sort by active connections descending, then cap
ip_entries.sort_by(|a, b| b.1.cmp(&a.1));
ip_entries.truncate(MAX_IPS_IN_SNAPSHOT);
let mut ips = std::collections::HashMap::new();
for (ip, active, total, bytes_in, bytes_out, tp_in, tp_out) in ip_entries {
ips.insert(ip, IpMetrics {
active_connections: active,
total_connections: total,
bytes_in,
bytes_out,
throughput_in_bytes_per_sec: tp_in,
throughput_out_bytes_per_sec: tp_out,
});
}
// HTTP request rates
let (http_rps, http_rps_recent) = self.http_request_throughput
.lock()
.map(|t| {
let (instant, _) = t.instant();
let (recent, _) = t.recent();
(instant, recent)
})
.unwrap_or((0, 0));
Metrics {
active_connections: self.active_connections(),
total_connections: self.total_connections(),
bytes_in: self.total_bytes_in(),
bytes_out: self.total_bytes_out(),
throughput_in_bytes_per_sec: global_tp_in,
throughput_out_bytes_per_sec: global_tp_out,
throughput_recent_in_bytes_per_sec: global_recent_in,
throughput_recent_out_bytes_per_sec: global_recent_out,
routes,
ips,
throughput_history,
total_http_requests: self.total_http_requests.load(Ordering::Relaxed),
http_requests_per_sec: http_rps,
http_requests_per_sec_recent: http_rps_recent,
}
}
}
impl Default for MetricsCollector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_initial_state_zeros() {
let collector = MetricsCollector::new();
assert_eq!(collector.active_connections(), 0);
assert_eq!(collector.total_connections(), 0);
}
#[test]
fn test_connection_opened_increments() {
let collector = MetricsCollector::new();
collector.connection_opened(None, None);
assert_eq!(collector.active_connections(), 1);
assert_eq!(collector.total_connections(), 1);
collector.connection_opened(None, None);
assert_eq!(collector.active_connections(), 2);
assert_eq!(collector.total_connections(), 2);
}
#[test]
fn test_connection_closed_decrements() {
let collector = MetricsCollector::new();
collector.connection_opened(None, None);
collector.connection_opened(None, None);
assert_eq!(collector.active_connections(), 2);
collector.connection_closed(None, None);
assert_eq!(collector.active_connections(), 1);
// total_connections should stay at 2
assert_eq!(collector.total_connections(), 2);
}
#[test]
fn test_route_specific_tracking() {
let collector = MetricsCollector::new();
collector.connection_opened(Some("route-a"), None);
collector.connection_opened(Some("route-a"), None);
collector.connection_opened(Some("route-b"), None);
assert_eq!(collector.active_connections(), 3);
assert_eq!(collector.total_connections(), 3);
collector.connection_closed(Some("route-a"), None);
assert_eq!(collector.active_connections(), 2);
}
#[test]
fn test_record_bytes() {
let collector = MetricsCollector::new();
collector.record_bytes(100, 200, Some("route-a"), None);
collector.record_bytes(50, 75, Some("route-a"), None);
collector.record_bytes(25, 30, None, None);
let total_in = collector.total_bytes_in.load(Ordering::Relaxed);
let total_out = collector.total_bytes_out.load(Ordering::Relaxed);
assert_eq!(total_in, 175);
assert_eq!(total_out, 305);
// Route-specific bytes
let route_in = collector.route_bytes_in.get("route-a").unwrap();
assert_eq!(route_in.load(Ordering::Relaxed), 150);
}
#[test]
fn test_throughput_tracking() {
let collector = MetricsCollector::with_retention(60);
// Open a connection so the route appears in the snapshot
collector.connection_opened(Some("route-a"), None);
// Record some bytes
collector.record_bytes(1000, 2000, Some("route-a"), None);
collector.record_bytes(500, 750, None, None);
// Take a sample (simulates the 1Hz tick)
collector.sample_all();
// Check global throughput
let snapshot = collector.snapshot();
assert_eq!(snapshot.throughput_in_bytes_per_sec, 1500);
assert_eq!(snapshot.throughput_out_bytes_per_sec, 2750);
// Check per-route throughput
let route_a = snapshot.routes.get("route-a").unwrap();
assert_eq!(route_a.throughput_in_bytes_per_sec, 1000);
assert_eq!(route_a.throughput_out_bytes_per_sec, 2000);
}
#[test]
fn test_throughput_zero_before_sampling() {
let collector = MetricsCollector::with_retention(60);
collector.record_bytes(1000, 2000, None, None);
// Without sampling, throughput should be 0
let snapshot = collector.snapshot();
assert_eq!(snapshot.throughput_in_bytes_per_sec, 0);
assert_eq!(snapshot.throughput_out_bytes_per_sec, 0);
}
#[test]
fn test_per_ip_tracking() {
let collector = MetricsCollector::with_retention(60);
collector.connection_opened(Some("route-a"), Some("1.2.3.4"));
collector.connection_opened(Some("route-a"), Some("1.2.3.4"));
collector.connection_opened(Some("route-b"), Some("5.6.7.8"));
// Check IP active connections (drop DashMap refs immediately to avoid deadlock)
assert_eq!(
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed),
2
);
assert_eq!(
collector.ip_connections.get("5.6.7.8").unwrap().load(Ordering::Relaxed),
1
);
// Record bytes per IP
collector.record_bytes(100, 200, Some("route-a"), Some("1.2.3.4"));
collector.record_bytes(300, 400, Some("route-b"), Some("5.6.7.8"));
collector.sample_all();
let snapshot = collector.snapshot();
assert_eq!(snapshot.ips.len(), 2);
let ip1_metrics = snapshot.ips.get("1.2.3.4").unwrap();
assert_eq!(ip1_metrics.active_connections, 2);
assert_eq!(ip1_metrics.bytes_in, 100);
// Close connections
collector.connection_closed(Some("route-a"), Some("1.2.3.4"));
assert_eq!(
collector.ip_connections.get("1.2.3.4").unwrap().load(Ordering::Relaxed),
1
);
// Close last connection for IP — should be cleaned up
collector.connection_closed(Some("route-a"), Some("1.2.3.4"));
assert!(collector.ip_connections.get("1.2.3.4").is_none());
}
#[test]
fn test_http_request_tracking() {
let collector = MetricsCollector::with_retention(60);
collector.record_http_request();
collector.record_http_request();
collector.record_http_request();
assert_eq!(collector.total_http_requests.load(Ordering::Relaxed), 3);
collector.sample_all();
let snapshot = collector.snapshot();
assert_eq!(snapshot.total_http_requests, 3);
assert_eq!(snapshot.http_requests_per_sec, 3);
}
#[test]
fn test_throughput_history_in_snapshot() {
let collector = MetricsCollector::with_retention(60);
for i in 1..=5 {
collector.record_bytes(i * 100, i * 200, None, None);
collector.sample_all();
}
let snapshot = collector.snapshot();
assert_eq!(snapshot.throughput_history.len(), 5);
// History should be chronological (oldest first)
assert_eq!(snapshot.throughput_history[0].bytes_in, 100);
assert_eq!(snapshot.throughput_history[4].bytes_in, 500);
}
}

View File

@@ -0,0 +1,11 @@
//! # rustproxy-metrics
//!
//! Metrics and throughput tracking for RustProxy.
pub mod throughput;
pub mod collector;
pub mod log_dedup;
pub use throughput::*;
pub use collector::*;
pub use log_dedup::*;

View File

@@ -0,0 +1,219 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::info;
/// An aggregated event during the deduplication window.
struct AggregatedEvent {
category: String,
first_message: String,
count: AtomicU64,
first_seen: Instant,
#[allow(dead_code)]
last_seen: Instant,
}
/// Log deduplicator that batches similar events over a time window.
///
/// Events are grouped by a composite key of `category:key`. Within each
/// deduplication window (`flush_interval`) identical events are counted
/// instead of being emitted individually. When the window expires (or the
/// batch reaches `max_batch_size`) a single summary line is written via
/// `tracing::info!`.
pub struct LogDeduplicator {
events: DashMap<String, AggregatedEvent>,
flush_interval: Duration,
max_batch_size: u64,
#[allow(dead_code)]
rapid_threshold: u64, // events/sec that triggers immediate flush
}
impl LogDeduplicator {
pub fn new() -> Self {
Self {
events: DashMap::new(),
flush_interval: Duration::from_secs(5),
max_batch_size: 100,
rapid_threshold: 50,
}
}
/// Log an event, deduplicating by `category` + `key`.
///
/// If the batch for this composite key reaches `max_batch_size` the
/// accumulated events are flushed immediately.
pub fn log(&self, category: &str, key: &str, message: &str) {
let map_key = format!("{}:{}", category, key);
let now = Instant::now();
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
category: category.to_string(),
first_message: message.to_string(),
count: AtomicU64::new(0),
first_seen: now,
last_seen: now,
});
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if we should flush (batch size exceeded)
if count >= self.max_batch_size {
drop(entry);
self.flush();
}
}
/// Flush all accumulated events, emitting summary log lines.
pub fn flush(&self) {
// Collect and remove all events
self.events.retain(|_key, event| {
let count = event.count.load(Ordering::Relaxed);
if count > 0 {
let elapsed = event.first_seen.elapsed();
if count == 1 {
info!("[{}] {}", event.category, event.first_message);
} else {
info!(
"[SUMMARY] {} {} events in {:.1}s: {}",
count,
event.category,
elapsed.as_secs_f64(),
event.first_message
);
}
}
false // remove all entries after flushing
});
}
/// Start a background flush task that periodically drains accumulated
/// events. The task runs until the supplied `CancellationToken` is
/// cancelled, at which point it performs one final flush before exiting.
pub fn start_flush_task(self: &Arc<Self>, cancel: tokio_util::sync::CancellationToken) {
let dedup = Arc::clone(self);
let interval = self.flush_interval;
tokio::spawn(async move {
loop {
tokio::select! {
_ = cancel.cancelled() => {
dedup.flush();
break;
}
_ = tokio::time::sleep(interval) => {
dedup.flush();
}
}
}
});
}
}
impl Default for LogDeduplicator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_single_event_emitted_as_is() {
let dedup = LogDeduplicator::new();
dedup.log("conn", "open", "connection opened from 1.2.3.4");
// One event should exist
assert_eq!(dedup.events.len(), 1);
let entry = dedup.events.get("conn:open").unwrap();
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
assert_eq!(entry.first_message, "connection opened from 1.2.3.4");
drop(entry);
dedup.flush();
// After flush, map should be empty
assert_eq!(dedup.events.len(), 0);
}
#[test]
fn test_duplicate_events_aggregated() {
let dedup = LogDeduplicator::new();
for _ in 0..10 {
dedup.log("conn", "timeout", "connection timed out");
}
assert_eq!(dedup.events.len(), 1);
let entry = dedup.events.get("conn:timeout").unwrap();
assert_eq!(entry.count.load(Ordering::Relaxed), 10);
drop(entry);
dedup.flush();
assert_eq!(dedup.events.len(), 0);
}
#[test]
fn test_different_keys_separate() {
let dedup = LogDeduplicator::new();
dedup.log("conn", "open", "opened");
dedup.log("conn", "close", "closed");
dedup.log("tls", "handshake", "TLS handshake");
assert_eq!(dedup.events.len(), 3);
dedup.flush();
assert_eq!(dedup.events.len(), 0);
}
#[test]
fn test_flush_clears_events() {
let dedup = LogDeduplicator::new();
dedup.log("a", "b", "msg1");
dedup.log("a", "b", "msg2");
dedup.flush();
assert_eq!(dedup.events.len(), 0);
// Logging after flush creates a new entry
dedup.log("a", "b", "msg3");
assert_eq!(dedup.events.len(), 1);
let entry = dedup.events.get("a:b").unwrap();
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
assert_eq!(entry.first_message, "msg3");
}
#[test]
fn test_max_batch_triggers_flush() {
let dedup = LogDeduplicator::new();
// max_batch_size defaults to 100
for i in 0..100 {
dedup.log("flood", "key", &format!("event {}", i));
}
// After hitting max_batch_size the events map should have been flushed
assert_eq!(dedup.events.len(), 0);
}
#[test]
fn test_default_trait() {
let dedup = LogDeduplicator::default();
assert_eq!(dedup.flush_interval, Duration::from_secs(5));
assert_eq!(dedup.max_batch_size, 100);
}
#[tokio::test]
async fn test_background_flush_task() {
let dedup = Arc::new(LogDeduplicator {
events: DashMap::new(),
flush_interval: Duration::from_millis(50),
max_batch_size: 100,
rapid_threshold: 50,
});
let cancel = tokio_util::sync::CancellationToken::new();
dedup.start_flush_task(cancel.clone());
// Log some events
dedup.log("bg", "test", "background flush test");
assert_eq!(dedup.events.len(), 1);
// Wait for the background task to flush
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(dedup.events.len(), 0);
// Cancel the task
cancel.cancel();
tokio::time::sleep(Duration::from_millis(20)).await;
}
}

View File

@@ -0,0 +1,232 @@
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// A single throughput sample.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThroughputSample {
pub timestamp_ms: u64,
pub bytes_in: u64,
pub bytes_out: u64,
}
/// Circular buffer for 1Hz throughput sampling.
/// Matches smartproxy's ThroughputTracker.
pub struct ThroughputTracker {
/// Circular buffer of samples
samples: Vec<ThroughputSample>,
/// Current write index
write_index: usize,
/// Number of valid samples
count: usize,
/// Maximum number of samples to retain
capacity: usize,
/// Accumulated bytes since last sample
pending_bytes_in: AtomicU64,
pending_bytes_out: AtomicU64,
/// When the tracker was created
created_at: Instant,
}
impl ThroughputTracker {
/// Create a new tracker with the given capacity (seconds of retention).
pub fn new(retention_seconds: usize) -> Self {
Self {
samples: Vec::with_capacity(retention_seconds),
write_index: 0,
count: 0,
capacity: retention_seconds,
pending_bytes_in: AtomicU64::new(0),
pending_bytes_out: AtomicU64::new(0),
created_at: Instant::now(),
}
}
/// Record bytes (called from data flow callbacks).
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
}
/// Take a sample (called at 1Hz).
pub fn sample(&mut self) {
let bytes_in = self.pending_bytes_in.swap(0, Ordering::Relaxed);
let bytes_out = self.pending_bytes_out.swap(0, Ordering::Relaxed);
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let sample = ThroughputSample {
timestamp_ms,
bytes_in,
bytes_out,
};
if self.samples.len() < self.capacity {
self.samples.push(sample);
} else {
self.samples[self.write_index] = sample;
}
self.write_index = (self.write_index + 1) % self.capacity;
self.count = (self.count + 1).min(self.capacity);
}
/// Get throughput over the last N seconds.
pub fn throughput(&self, window_seconds: usize) -> (u64, u64) {
let window = window_seconds.min(self.count);
if window == 0 {
return (0, 0);
}
let mut total_in = 0u64;
let mut total_out = 0u64;
for i in 0..window {
let idx = if self.write_index >= i + 1 {
self.write_index - i - 1
} else {
self.capacity - (i + 1 - self.write_index)
};
if idx < self.samples.len() {
total_in += self.samples[idx].bytes_in;
total_out += self.samples[idx].bytes_out;
}
}
(total_in / window as u64, total_out / window as u64)
}
/// Get instant throughput (last 1 second).
pub fn instant(&self) -> (u64, u64) {
self.throughput(1)
}
/// Get recent throughput (last 10 seconds).
pub fn recent(&self) -> (u64, u64) {
self.throughput(10)
}
/// Return the last N samples in chronological order (oldest first).
pub fn history(&self, window_seconds: usize) -> Vec<ThroughputSample> {
let window = window_seconds.min(self.count);
if window == 0 {
return Vec::new();
}
let mut result = Vec::with_capacity(window);
for i in 0..window {
let idx = if self.write_index >= i + 1 {
self.write_index - i - 1
} else {
self.capacity - (i + 1 - self.write_index)
};
if idx < self.samples.len() {
result.push(self.samples[idx]);
}
}
result.reverse(); // Return oldest-first (chronological)
result
}
/// How long this tracker has been alive.
pub fn uptime(&self) -> std::time::Duration {
self.created_at.elapsed()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_throughput() {
let tracker = ThroughputTracker::new(60);
let (bytes_in, bytes_out) = tracker.throughput(10);
assert_eq!(bytes_in, 0);
assert_eq!(bytes_out, 0);
}
#[test]
fn test_single_sample() {
let mut tracker = ThroughputTracker::new(60);
tracker.record_bytes(1000, 2000);
tracker.sample();
let (bytes_in, bytes_out) = tracker.instant();
assert_eq!(bytes_in, 1000);
assert_eq!(bytes_out, 2000);
}
#[test]
fn test_circular_buffer_wrap() {
let mut tracker = ThroughputTracker::new(3); // Small capacity
for i in 0..5 {
tracker.record_bytes(i * 100, i * 200);
tracker.sample();
}
// Should still work after wrapping
let (bytes_in, bytes_out) = tracker.throughput(3);
assert!(bytes_in > 0);
assert!(bytes_out > 0);
}
#[test]
fn test_window_averaging() {
let mut tracker = ThroughputTracker::new(60);
// Record 3 samples of different sizes
tracker.record_bytes(100, 200);
tracker.sample();
tracker.record_bytes(200, 400);
tracker.sample();
tracker.record_bytes(300, 600);
tracker.sample();
// Average over 3 samples: (100+200+300)/3 = 200, (200+400+600)/3 = 400
let (avg_in, avg_out) = tracker.throughput(3);
assert_eq!(avg_in, 200);
assert_eq!(avg_out, 400);
}
#[test]
fn test_uptime_positive() {
let tracker = ThroughputTracker::new(60);
std::thread::sleep(std::time::Duration::from_millis(10));
assert!(tracker.uptime().as_millis() >= 10);
}
#[test]
fn test_history_returns_chronological() {
let mut tracker = ThroughputTracker::new(60);
for i in 1..=5 {
tracker.record_bytes(i * 100, i * 200);
tracker.sample();
}
let history = tracker.history(5);
assert_eq!(history.len(), 5);
// First sample should have 100 bytes_in, last should have 500
assert_eq!(history[0].bytes_in, 100);
assert_eq!(history[4].bytes_in, 500);
}
#[test]
fn test_history_wraps_around() {
let mut tracker = ThroughputTracker::new(3); // Small capacity
for i in 1..=5 {
tracker.record_bytes(i * 100, i * 200);
tracker.sample();
}
// Only last 3 should be retained
let history = tracker.history(10); // Ask for more than available
assert_eq!(history.len(), 3);
assert_eq!(history[0].bytes_in, 300);
assert_eq!(history[1].bytes_in, 400);
assert_eq!(history[2].bytes_in, 500);
}
#[test]
fn test_history_empty() {
let tracker = ThroughputTracker::new(60);
let history = tracker.history(10);
assert!(history.is_empty());
}
}

View File

@@ -0,0 +1,17 @@
[package]
name = "rustproxy-nftables"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "NFTables kernel-level forwarding for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
libc = { workspace = true }

View File

@@ -0,0 +1,10 @@
//! # rustproxy-nftables
//!
//! NFTables kernel-level forwarding for RustProxy.
//! Generates and manages nft CLI rules for DNAT/SNAT.
pub mod nft_manager;
pub mod rule_builder;
pub use nft_manager::*;
pub use rule_builder::*;

View File

@@ -0,0 +1,238 @@
use thiserror::Error;
use std::collections::HashMap;
use tracing::{debug, info, warn};
#[derive(Debug, Error)]
pub enum NftError {
#[error("nft command failed: {0}")]
CommandFailed(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Not running as root")]
NotRoot,
}
/// Manager for nftables rules.
///
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
/// Requires root privileges; operations are skipped gracefully if not root.
pub struct NftManager {
table_name: String,
/// Active rules indexed by route ID
active_rules: HashMap<String, Vec<String>>,
/// Whether the table has been initialized
table_initialized: bool,
}
impl NftManager {
pub fn new(table_name: Option<String>) -> Self {
Self {
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
active_rules: HashMap::new(),
table_initialized: false,
}
}
/// Check if we are running as root.
fn is_root() -> bool {
unsafe { libc::geteuid() == 0 }
}
/// Execute a single nft command via the CLI.
async fn exec_nft(command: &str) -> Result<String, NftError> {
// The command starts with "nft ", strip it to get the args
let args = if command.starts_with("nft ") {
&command[4..]
} else {
command
};
let output = tokio::process::Command::new("nft")
.args(args.split_whitespace())
.output()
.await
.map_err(NftError::Io)?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
Err(NftError::CommandFailed(format!(
"Command '{}' failed: {}",
command, stderr
)))
}
}
/// Ensure the nftables table and chains are set up.
async fn ensure_table(&mut self) -> Result<(), NftError> {
if self.table_initialized {
return Ok(());
}
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
for cmd in &setup_commands {
Self::exec_nft(cmd).await?;
}
self.table_initialized = true;
info!("NFTables table '{}' initialized", self.table_name);
Ok(())
}
/// Apply rules for a route.
///
/// Executes the nft commands via the CLI. If not running as root,
/// the rules are stored locally but not applied to the kernel.
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, nftables rules will not be applied to kernel");
self.active_rules.insert(route_id.to_string(), rules);
return Ok(());
}
self.ensure_table().await?;
for cmd in &rules {
Self::exec_nft(cmd).await?;
debug!("Applied nft rule: {}", cmd);
}
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
self.active_rules.insert(route_id.to_string(), rules);
Ok(())
}
/// Remove rules for a route.
///
/// Currently removes the route from tracking. To fully remove specific
/// rules would require handle-based tracking; for now, cleanup() removes
/// the entire table.
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
if let Some(rules) = self.active_rules.remove(route_id) {
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
}
Ok(())
}
/// Clean up all managed rules by deleting the entire nftables table.
pub async fn cleanup(&mut self) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, skipping nftables cleanup");
self.active_rules.clear();
self.table_initialized = false;
return Ok(());
}
if self.table_initialized {
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
for cmd in &cleanup_commands {
match Self::exec_nft(cmd).await {
Ok(_) => debug!("Cleanup: {}", cmd),
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
}
}
info!("NFTables table '{}' cleaned up", self.table_name);
}
self.active_rules.clear();
self.table_initialized = false;
Ok(())
}
/// Get the table name.
pub fn table_name(&self) -> &str {
&self.table_name
}
/// Whether the table has been initialized in the kernel.
pub fn is_initialized(&self) -> bool {
self.table_initialized
}
/// Get the number of active route rule sets.
pub fn active_route_count(&self) -> usize {
self.active_rules.len()
}
/// Get the status of all active rules.
pub fn status(&self) -> HashMap<String, serde_json::Value> {
let mut status = HashMap::new();
for (route_id, rules) in &self.active_rules {
status.insert(
route_id.clone(),
serde_json::json!({
"ruleCount": rules.len(),
"rules": rules,
}),
);
}
status
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_default_table_name() {
let mgr = NftManager::new(None);
assert_eq!(mgr.table_name(), "rustproxy");
assert!(!mgr.is_initialized());
}
#[test]
fn test_new_custom_table_name() {
let mgr = NftManager::new(Some("custom".to_string()));
assert_eq!(mgr.table_name(), "custom");
}
#[tokio::test]
async fn test_apply_rules_non_root() {
let mut mgr = NftManager::new(None);
// When not root, rules are stored but not applied to kernel
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
let status = mgr.status();
assert!(status.contains_key("route-1"));
assert_eq!(status["route-1"]["ruleCount"], 1);
}
#[tokio::test]
async fn test_remove_rules() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
mgr.remove_rules("route-1").await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
}
#[tokio::test]
async fn test_cleanup_non_root() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
mgr.cleanup().await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
assert!(!mgr.is_initialized());
}
#[tokio::test]
async fn test_status_multiple_routes() {
let mut mgr = NftManager::new(None);
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
let status = mgr.status();
assert_eq!(status.len(), 2);
assert_eq!(status["web"]["ruleCount"], 2);
assert_eq!(status["api"]["ruleCount"], 1);
}
}

View File

@@ -0,0 +1,123 @@
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
/// Build nftables DNAT rule for port forwarding.
pub fn build_dnat_rule(
table_name: &str,
chain_name: &str,
source_port: u16,
target_host: &str,
target_port: u16,
options: &NfTablesOptions,
) -> Vec<String> {
let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
NfTablesProtocol::Tcp => "tcp",
NfTablesProtocol::Udp => "udp",
NfTablesProtocol::All => "tcp", // TODO: handle "all"
};
let mut rules = Vec::new();
// DNAT rule
rules.push(format!(
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
table_name, chain_name, protocol, source_port, target_host, target_port,
));
// SNAT rule if preserving source IP is not enabled
if !options.preserve_source_ip.unwrap_or(false) {
rules.push(format!(
"nft add rule ip {} postrouting {} dport {} masquerade",
table_name, protocol, target_port,
));
}
// Rate limiting
if let Some(max_rate) = &options.max_rate {
rules.push(format!(
"nft add rule ip {} {} {} dport {} limit rate {} accept",
table_name, chain_name, protocol, source_port, max_rate,
));
}
rules
}
/// Build the initial table and chain setup commands.
pub fn build_table_setup(table_name: &str) -> Vec<String> {
vec![
format!("nft add table ip {}", table_name),
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
]
}
/// Build cleanup commands to remove the table.
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
vec![format!("nft delete table ip {}", table_name)]
}
#[cfg(test)]
mod tests {
use super::*;
fn make_options() -> NfTablesOptions {
NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
}
}
#[test]
fn test_basic_dnat_rule() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.len() >= 1);
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
assert!(rules[0].contains("dport 443"));
}
#[test]
fn test_preserve_source_ip() {
let mut options = make_options();
options.preserve_source_ip = Some(true);
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
// When preserving source IP, no masquerade rule
assert!(rules.iter().all(|r| !r.contains("masquerade")));
}
#[test]
fn test_without_preserve_source_ip() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.iter().any(|r| r.contains("masquerade")));
}
#[test]
fn test_rate_limited_rule() {
let mut options = make_options();
options.max_rate = Some("100/second".to_string());
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
}
#[test]
fn test_table_setup_commands() {
let commands = build_table_setup("rustproxy");
assert_eq!(commands.len(), 3);
assert!(commands[0].contains("add table ip rustproxy"));
assert!(commands[1].contains("prerouting"));
assert!(commands[2].contains("postrouting"));
}
#[test]
fn test_table_cleanup() {
let commands = build_table_cleanup("rustproxy");
assert_eq!(commands.len(), 1);
assert!(commands[0].contains("delete table ip rustproxy"));
}
}

View File

@@ -0,0 +1,25 @@
[package]
name = "rustproxy-passthrough"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Raw TCP/SNI passthrough engine for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
dashmap = { workspace = true }
arc-swap = { workspace = true }
rustproxy-http = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tokio-util = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@@ -0,0 +1,155 @@
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Per-connection tracking record with atomics for lock-free updates.
///
/// Each field uses atomics so that the forwarding tasks can update
/// bytes_received / bytes_sent / last_activity without holding any lock,
/// while the zombie scanner reads them concurrently.
pub struct ConnectionRecord {
/// Unique connection ID assigned by the ConnectionTracker.
pub id: u64,
/// Wall-clock instant when this connection was created.
pub created_at: Instant,
/// Milliseconds since `created_at` when the last activity occurred.
/// Updated atomically by the forwarding loops.
pub last_activity: AtomicU64,
/// Total bytes received from the client (inbound).
pub bytes_received: AtomicU64,
/// Total bytes sent to the client (outbound / from backend).
pub bytes_sent: AtomicU64,
/// True once the client side of the connection has closed.
pub client_closed: AtomicBool,
/// True once the backend side of the connection has closed.
pub backend_closed: AtomicBool,
/// Whether this connection uses TLS (affects zombie thresholds).
pub is_tls: AtomicBool,
/// Whether this connection has keep-alive semantics.
pub has_keep_alive: AtomicBool,
}
impl ConnectionRecord {
/// Create a new connection record with the given ID.
/// All counters start at zero, all flags start as false.
pub fn new(id: u64) -> Self {
Self {
id,
created_at: Instant::now(),
last_activity: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
client_closed: AtomicBool::new(false),
backend_closed: AtomicBool::new(false),
is_tls: AtomicBool::new(false),
has_keep_alive: AtomicBool::new(false),
}
}
/// Update `last_activity` to reflect the current elapsed time.
pub fn touch(&self) {
let elapsed_ms = self.created_at.elapsed().as_millis() as u64;
self.last_activity.store(elapsed_ms, Ordering::Relaxed);
}
/// Record `n` bytes received from the client (inbound).
pub fn record_bytes_in(&self, n: u64) {
self.bytes_received.fetch_add(n, Ordering::Relaxed);
self.touch();
}
/// Record `n` bytes sent to the client (outbound / from backend).
pub fn record_bytes_out(&self, n: u64) {
self.bytes_sent.fetch_add(n, Ordering::Relaxed);
self.touch();
}
/// How long since the last activity on this connection.
pub fn idle_duration(&self) -> Duration {
let last_ms = self.last_activity.load(Ordering::Relaxed);
let age_ms = self.created_at.elapsed().as_millis() as u64;
Duration::from_millis(age_ms.saturating_sub(last_ms))
}
/// Total age of this connection (time since creation).
pub fn age(&self) -> Duration {
self.created_at.elapsed()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_new_record() {
let record = ConnectionRecord::new(42);
assert_eq!(record.id, 42);
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 0);
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 0);
assert!(!record.client_closed.load(Ordering::Relaxed));
assert!(!record.backend_closed.load(Ordering::Relaxed));
assert!(!record.is_tls.load(Ordering::Relaxed));
assert!(!record.has_keep_alive.load(Ordering::Relaxed));
}
#[test]
fn test_record_bytes() {
let record = ConnectionRecord::new(1);
record.record_bytes_in(100);
record.record_bytes_in(200);
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 300);
record.record_bytes_out(50);
record.record_bytes_out(75);
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 125);
}
#[test]
fn test_touch_updates_activity() {
let record = ConnectionRecord::new(1);
assert_eq!(record.last_activity.load(Ordering::Relaxed), 0);
// Sleep briefly so elapsed time is nonzero
thread::sleep(Duration::from_millis(10));
record.touch();
let activity = record.last_activity.load(Ordering::Relaxed);
assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity);
}
#[test]
fn test_idle_duration() {
let record = ConnectionRecord::new(1);
// Initially idle_duration ~ age since last_activity is 0
thread::sleep(Duration::from_millis(20));
let idle = record.idle_duration();
assert!(idle >= Duration::from_millis(20));
// After touch, idle should be near zero
record.touch();
let idle = record.idle_duration();
assert!(idle < Duration::from_millis(10));
}
#[test]
fn test_age() {
let record = ConnectionRecord::new(1);
thread::sleep(Duration::from_millis(20));
let age = record.age();
assert!(age >= Duration::from_millis(20));
}
#[test]
fn test_flags() {
let record = ConnectionRecord::new(1);
record.client_closed.store(true, Ordering::Relaxed);
record.is_tls.store(true, Ordering::Relaxed);
record.has_keep_alive.store(true, Ordering::Relaxed);
assert!(record.client_closed.load(Ordering::Relaxed));
assert!(!record.backend_closed.load(Ordering::Relaxed));
assert!(record.is_tls.load(Ordering::Relaxed));
assert!(record.has_keep_alive.load(Ordering::Relaxed));
}
}

View File

@@ -0,0 +1,402 @@
use dashmap::DashMap;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use super::connection_record::ConnectionRecord;
/// Thresholds for zombie detection (non-TLS connections).
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);
/// Thresholds for zombie detection (TLS connections).
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300);
/// Stuck connection timeout (non-TLS): received data but never sent any.
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);
/// Stuck connection timeout (TLS): received data but never sent any.
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300);
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
/// Also maintains per-connection records for zombie detection.
pub struct ConnectionTracker {
/// Active connection counts per IP
active: DashMap<IpAddr, AtomicU64>,
/// Connection timestamps per IP for rate limiting
timestamps: DashMap<IpAddr, VecDeque<Instant>>,
/// Maximum concurrent connections per IP (None = unlimited)
max_per_ip: Option<u64>,
/// Maximum new connections per minute per IP (None = unlimited)
rate_limit_per_minute: Option<u64>,
/// Per-connection tracking records for zombie detection
connections: DashMap<u64, Arc<ConnectionRecord>>,
/// Monotonically increasing connection ID counter
next_id: AtomicU64,
}
impl ConnectionTracker {
pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
Self {
active: DashMap::new(),
timestamps: DashMap::new(),
max_per_ip,
rate_limit_per_minute,
connections: DashMap::new(),
next_id: AtomicU64::new(1),
}
}
/// Try to accept a new connection from the given IP.
/// Returns true if allowed, false if over limit.
pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit
if let Some(max) = self.max_per_ip {
let count = self.active
.get(ip)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
if count >= max {
return false;
}
}
// Check rate limit
if let Some(rate_limit) = self.rate_limit_per_minute {
let now = Instant::now();
let one_minute = std::time::Duration::from_secs(60);
let mut entry = self.timestamps.entry(*ip).or_default();
let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
timestamps.pop_front();
}
if timestamps.len() as u64 >= rate_limit {
return false;
}
timestamps.push_back(now);
}
true
}
/// Record that a connection was opened from the given IP.
pub fn connection_opened(&self, ip: &IpAddr) {
self.active
.entry(*ip)
.or_insert_with(|| AtomicU64::new(0))
.value()
.fetch_add(1, Ordering::Relaxed);
}
/// Record that a connection was closed from the given IP.
pub fn connection_closed(&self, ip: &IpAddr) {
if let Some(counter) = self.active.get(ip) {
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
// Clean up zero entries
if prev <= 1 {
drop(counter);
self.active.remove(ip);
}
}
}
/// Get the current number of active connections for an IP.
pub fn active_connections(&self, ip: &IpAddr) -> u64 {
self.active
.get(ip)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0)
}
/// Get the total number of tracked IPs.
pub fn tracked_ips(&self) -> usize {
self.active.len()
}
/// Register a new connection and return its tracking record.
///
/// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
/// loop so it can update bytes / activity atomics in real time.
pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let record = Arc::new(ConnectionRecord::new(id));
record.is_tls.store(is_tls, Ordering::Relaxed);
self.connections.insert(id, Arc::clone(&record));
record
}
/// Remove a connection record when the connection is fully closed.
pub fn unregister_connection(&self, id: u64) {
self.connections.remove(&id);
}
/// Scan all tracked connections and return IDs of zombie connections.
///
/// A connection is considered a zombie in any of these cases:
/// - **Full zombie**: both `client_closed` and `backend_closed` are true.
/// - **Half zombie**: one side closed for longer than the threshold
/// (5 min for TLS, 30s for non-TLS).
/// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
/// than the stuck threshold (5 min for TLS, 60s for non-TLS).
pub fn scan_zombies(&self) -> Vec<u64> {
let mut zombies = Vec::new();
for entry in self.connections.iter() {
let record = entry.value();
let id = *entry.key();
let is_tls = record.is_tls.load(Ordering::Relaxed);
let client_closed = record.client_closed.load(Ordering::Relaxed);
let backend_closed = record.backend_closed.load(Ordering::Relaxed);
let idle = record.idle_duration();
let bytes_in = record.bytes_received.load(Ordering::Relaxed);
let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
// Full zombie: both sides closed
if client_closed && backend_closed {
zombies.push(id);
continue;
}
// Half zombie: one side closed for too long
let half_timeout = if is_tls {
HALF_ZOMBIE_TIMEOUT_TLS
} else {
HALF_ZOMBIE_TIMEOUT_PLAIN
};
if (client_closed || backend_closed) && idle >= half_timeout {
zombies.push(id);
continue;
}
// Stuck: received data but never sent anything for too long
let stuck_timeout = if is_tls {
STUCK_TIMEOUT_TLS
} else {
STUCK_TIMEOUT_PLAIN
};
if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
zombies.push(id);
}
}
zombies
}
/// Start a background task that periodically scans for zombie connections.
///
/// The scanner runs every 10 seconds and logs any zombies it finds.
/// It stops when the provided `CancellationToken` is cancelled.
pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
let tracker = Arc::clone(self);
tokio::spawn(async move {
let interval = Duration::from_secs(10);
loop {
tokio::select! {
_ = cancel.cancelled() => {
debug!("Zombie scanner shutting down");
break;
}
_ = tokio::time::sleep(interval) => {
let zombies = tracker.scan_zombies();
if !zombies.is_empty() {
warn!(
"Detected {} zombie connection(s): {:?}",
zombies.len(),
zombies
);
}
}
}
}
});
}
/// Get the total number of tracked connections (with records).
pub fn total_connections(&self) -> usize {
self.connections.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_tracking() {
let tracker = ConnectionTracker::new(None, None);
let ip: IpAddr = "127.0.0.1".parse().unwrap();
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
assert_eq!(tracker.active_connections(&ip), 1);
tracker.connection_opened(&ip);
assert_eq!(tracker.active_connections(&ip), 2);
tracker.connection_closed(&ip);
assert_eq!(tracker.active_connections(&ip), 1);
tracker.connection_closed(&ip);
assert_eq!(tracker.active_connections(&ip), 0);
}
#[test]
fn test_per_ip_limit() {
let tracker = ConnectionTracker::new(Some(2), None);
let ip: IpAddr = "10.0.0.1".parse().unwrap();
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
// Third connection should be rejected
assert!(!tracker.try_accept(&ip));
// Different IP should still be allowed
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
assert!(tracker.try_accept(&ip2));
}
#[test]
fn test_rate_limit() {
let tracker = ConnectionTracker::new(None, Some(3));
let ip: IpAddr = "10.0.0.1".parse().unwrap();
assert!(tracker.try_accept(&ip));
assert!(tracker.try_accept(&ip));
assert!(tracker.try_accept(&ip));
// 4th attempt within the minute should be rejected
assert!(!tracker.try_accept(&ip));
}
#[test]
fn test_no_limits() {
let tracker = ConnectionTracker::new(None, None);
let ip: IpAddr = "10.0.0.1".parse().unwrap();
for _ in 0..1000 {
assert!(tracker.try_accept(&ip));
tracker.connection_opened(&ip);
}
assert_eq!(tracker.active_connections(&ip), 1000);
}
#[test]
fn test_tracked_ips() {
let tracker = ConnectionTracker::new(None, None);
assert_eq!(tracker.tracked_ips(), 0);
let ip1: IpAddr = "10.0.0.1".parse().unwrap();
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
tracker.connection_opened(&ip1);
tracker.connection_opened(&ip2);
assert_eq!(tracker.tracked_ips(), 2);
tracker.connection_closed(&ip1);
assert_eq!(tracker.tracked_ips(), 1);
}
#[test]
fn test_register_unregister_connection() {
let tracker = ConnectionTracker::new(None, None);
assert_eq!(tracker.total_connections(), 0);
let record1 = tracker.register_connection(false);
assert_eq!(tracker.total_connections(), 1);
assert!(!record1.is_tls.load(Ordering::Relaxed));
let record2 = tracker.register_connection(true);
assert_eq!(tracker.total_connections(), 2);
assert!(record2.is_tls.load(Ordering::Relaxed));
// IDs should be unique
assert_ne!(record1.id, record2.id);
tracker.unregister_connection(record1.id);
assert_eq!(tracker.total_connections(), 1);
tracker.unregister_connection(record2.id);
assert_eq!(tracker.total_connections(), 0);
}
#[test]
fn test_full_zombie_detection() {
let tracker = ConnectionTracker::new(None, None);
let record = tracker.register_connection(false);
// Not a zombie initially
assert!(tracker.scan_zombies().is_empty());
// Set both sides closed -> full zombie
record.client_closed.store(true, Ordering::Relaxed);
record.backend_closed.store(true, Ordering::Relaxed);
let zombies = tracker.scan_zombies();
assert_eq!(zombies.len(), 1);
assert_eq!(zombies[0], record.id);
}
#[test]
fn test_half_zombie_not_triggered_immediately() {
let tracker = ConnectionTracker::new(None, None);
let record = tracker.register_connection(false);
record.touch(); // mark activity now
// Only one side closed, but just now -> not a zombie yet
record.client_closed.store(true, Ordering::Relaxed);
assert!(tracker.scan_zombies().is_empty());
}
#[test]
fn test_stuck_connection_not_triggered_immediately() {
let tracker = ConnectionTracker::new(None, None);
let record = tracker.register_connection(false);
record.touch(); // mark activity now
// Has received data but sent nothing -> but just started, not stuck yet
record.bytes_received.store(1000, Ordering::Relaxed);
assert!(tracker.scan_zombies().is_empty());
}
#[test]
fn test_unregister_removes_from_zombie_scan() {
let tracker = ConnectionTracker::new(None, None);
let record = tracker.register_connection(false);
let id = record.id;
// Make it a full zombie
record.client_closed.store(true, Ordering::Relaxed);
record.backend_closed.store(true, Ordering::Relaxed);
assert_eq!(tracker.scan_zombies().len(), 1);
// Unregister should remove it
tracker.unregister_connection(id);
assert!(tracker.scan_zombies().is_empty());
}
#[test]
fn test_total_connections() {
let tracker = ConnectionTracker::new(None, None);
assert_eq!(tracker.total_connections(), 0);
let r1 = tracker.register_connection(false);
let r2 = tracker.register_connection(true);
let r3 = tracker.register_connection(false);
assert_eq!(tracker.total_connections(), 3);
tracker.unregister_connection(r2.id);
assert_eq!(tracker.total_connections(), 2);
tracker.unregister_connection(r1.id);
tracker.unregister_connection(r3.id);
assert_eq!(tracker.total_connections(), 0);
}
}

View File

@@ -0,0 +1,192 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
/// Context for forwarding metrics, replacing the growing tuple pattern.
#[derive(Clone)]
pub struct ForwardMetricsCtx {
pub collector: Arc<MetricsCollector>,
pub route_id: Option<String>,
pub source_ip: Option<String>,
}
/// Perform bidirectional TCP forwarding between client and backend.
///
/// This is the core data path for passthrough connections.
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
pub async fn forward_bidirectional(
mut client: TcpStream,
mut backend: TcpStream,
initial_data: Option<&[u8]>,
) -> std::io::Result<(u64, u64)> {
// Send initial data (peeked bytes) to backend
if let Some(data) = initial_data {
backend.write_all(data).await?;
}
let (mut client_read, mut client_write) = client.split();
let (mut backend_read, mut backend_write) = backend.split();
let client_to_backend = async {
let mut buf = vec![0u8; 65536];
let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
loop {
let n = client_read.read(&mut buf).await?;
if n == 0 {
break;
}
backend_write.write_all(&buf[..n]).await?;
total += n as u64;
}
backend_write.shutdown().await?;
Ok::<u64, std::io::Error>(total)
};
let backend_to_client = async {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = backend_read.read(&mut buf).await?;
if n == 0 {
break;
}
client_write.write_all(&buf[..n]).await?;
total += n as u64;
}
client_write.shutdown().await?;
Ok::<u64, std::io::Error>(total)
};
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
}
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
///
/// When `metrics` is provided, bytes are reported to the MetricsCollector
/// per-chunk (lock-free) as they flow through the copy loops, enabling
/// real-time throughput sampling for long-lived connections.
///
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
pub async fn forward_bidirectional_with_timeouts(
client: TcpStream,
mut backend: TcpStream,
initial_data: Option<&[u8]>,
inactivity_timeout: std::time::Duration,
max_lifetime: std::time::Duration,
cancel: CancellationToken,
metrics: Option<ForwardMetricsCtx>,
) -> std::io::Result<(u64, u64)> {
// Send initial data (peeked bytes) to backend
if let Some(data) = initial_data {
backend.write_all(data).await?;
if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
}
}
let (mut client_read, mut client_write) = client.into_split();
let (mut backend_read, mut backend_write) = backend.into_split();
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let la1 = Arc::clone(&last_activity);
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
let metrics_c2b = metrics.clone();
let c2b = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = initial_len;
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if backend_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
}
}
let _ = backend_write.shutdown().await;
total
});
let la2 = Arc::clone(&last_activity);
let metrics_b2c = metrics;
let b2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = match backend_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
}
}
let _ = client_write.shutdown().await;
total
});
// Watchdog: inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2b_handle = c2b.abort_handle();
let b2c_handle = b2c.abort_handle();
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = cancel.cancelled() => {
debug!("Connection cancelled by shutdown");
c2b_handle.abort();
b2c_handle.abort();
break;
}
_ = tokio::time::sleep(check_interval) => {
// Check max lifetime
if start.elapsed() >= max_lifetime {
debug!("Connection exceeded max lifetime, closing");
c2b_handle.abort();
b2c_handle.abort();
break;
}
// Check inactivity
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
c2b_handle.abort();
b2c_handle.abort();
break;
}
}
last_seen = current;
}
}
}
});
let bytes_in = c2b.await.unwrap_or(0);
let bytes_out = b2c.await.unwrap_or(0);
watchdog.abort();
Ok((bytes_in, bytes_out))
}

View File

@@ -0,0 +1,22 @@
//! # rustproxy-passthrough
//!
//! Raw TCP/SNI passthrough engine for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
pub mod tcp_listener;
pub mod sni_parser;
pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_record;
pub mod connection_tracker;
pub mod socket_relay;
pub use tcp_listener::*;
pub use sni_parser::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_record::*;
pub use connection_tracker::*;
pub use socket_relay::*;

View File

@@ -0,0 +1,129 @@
use std::net::SocketAddr;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ProxyProtocolError {
#[error("Invalid PROXY protocol header")]
InvalidHeader,
#[error("Unsupported PROXY protocol version")]
UnsupportedVersion,
#[error("Parse error: {0}")]
Parse(String),
}
/// Parsed PROXY protocol v1 header.
#[derive(Debug, Clone)]
pub struct ProxyProtocolHeader {
pub source_addr: SocketAddr,
pub dest_addr: SocketAddr,
pub protocol: ProxyProtocol,
}
/// Protocol in PROXY header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocol {
Tcp4,
Tcp6,
Unknown,
}
/// Parse a PROXY protocol v1 header from data.
///
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
// Find the end of the header line
let line_end = data
.windows(2)
.position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end])
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader);
}
let parts: Vec<&str> = line.split(' ').collect();
if parts.len() != 6 {
return Err(ProxyProtocolError::InvalidHeader);
}
let protocol = match parts[1] {
"TCP4" => ProxyProtocol::Tcp4,
"TCP6" => ProxyProtocol::Tcp6,
"UNKNOWN" => ProxyProtocol::Unknown,
_ => return Err(ProxyProtocolError::UnsupportedVersion),
};
let src_ip: std::net::IpAddr = parts[2]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
let dst_ip: std::net::IpAddr = parts[3]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
let src_port: u16 = parts[4]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
let dst_port: u16 = parts[5]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
let header = ProxyProtocolHeader {
source_addr: SocketAddr::new(src_ip, src_port),
dest_addr: SocketAddr::new(dst_ip, dst_port),
protocol,
};
// Consumed bytes = line + \r\n
Ok((header, line_end + 2))
}
/// Generate a PROXY protocol v1 header string.
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
let proto = if source.is_ipv4() { "TCP4" } else { "TCP6" };
format!(
"PROXY {} {} {} {} {}\r\n",
proto,
source.ip(),
dest.ip(),
source.port(),
dest.port()
)
}
/// Check if data starts with a PROXY protocol v1 header.
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
data.starts_with(b"PROXY ")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_v1_tcp4() {
let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
let (parsed, consumed) = parse_v1(header).unwrap();
assert_eq!(consumed, header.len());
assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100");
assert_eq!(parsed.source_addr.port(), 12345);
assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1");
assert_eq!(parsed.dest_addr.port(), 443);
}
#[test]
fn test_generate_v1() {
let source: SocketAddr = "192.168.1.100:12345".parse().unwrap();
let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
let header = generate_v1(&source, &dest);
assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
}
#[test]
fn test_is_proxy_protocol() {
assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
}
}

View File

@@ -0,0 +1,322 @@
//! ClientHello SNI extraction via manual byte parsing.
//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI.
/// Result of SNI extraction.
#[derive(Debug)]
pub enum SniResult {
/// Successfully extracted SNI hostname.
Found(String),
/// TLS ClientHello detected but no SNI extension present.
NoSni,
/// Not a TLS ClientHello (plain HTTP or other protocol).
NotTls,
/// Need more data to determine.
NeedMoreData,
}
/// Extract the SNI hostname from a TLS ClientHello message.
///
/// This parses just enough of the TLS record to find the SNI extension,
/// without performing any actual TLS operations.
pub fn extract_sni(data: &[u8]) -> SniResult {
// Minimum TLS record header is 5 bytes
if data.len() < 5 {
return SniResult::NeedMoreData;
}
// Check for TLS record: content_type=22 (Handshake)
if data[0] != 0x16 {
return SniResult::NotTls;
}
// TLS version (major.minor) - accept any
// data[1..2] = version
// Record length
let record_len = ((data[3] as usize) << 8) | (data[4] as usize);
let _total_len = 5 + record_len;
// We need at least the handshake header (5 TLS + 4 handshake = 9)
if data.len() < 9 {
return SniResult::NeedMoreData;
}
// Handshake type = 1 (ClientHello)
if data[5] != 0x01 {
return SniResult::NotTls;
}
// Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16)
| ((data[7] as usize) << 8)
| (data[8] as usize);
let hello = &data[9..];
// ClientHello structure:
// 2 bytes: client version
// 32 bytes: random
// 1 byte: session_id length + session_id
let mut pos = 2 + 32; // skip version + random
if pos >= hello.len() {
return SniResult::NeedMoreData;
}
// Session ID
let session_id_len = hello[pos] as usize;
pos += 1 + session_id_len;
if pos + 2 > hello.len() {
return SniResult::NeedMoreData;
}
// Cipher suites
let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
pos += 2 + cipher_suites_len;
if pos + 1 > hello.len() {
return SniResult::NeedMoreData;
}
// Compression methods
let compression_len = hello[pos] as usize;
pos += 1 + compression_len;
if pos + 2 > hello.len() {
// No extensions
return SniResult::NoSni;
}
// Extensions length
let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
pos += 2;
let extensions_end = pos + extensions_len;
if extensions_end > hello.len() {
// Partial extensions, try to parse what we have
}
// Parse extensions looking for SNI (type 0x0000)
while pos + 4 <= hello.len() && pos < extensions_end {
let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16);
let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize);
pos += 4;
if ext_type == 0x0000 {
// SNI extension
return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len);
}
pos += ext_len;
}
SniResult::NoSni
}
/// Parse the SNI extension data.
fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult {
if data.len() < 5 {
return SniResult::NeedMoreData;
}
// Server name list length
let _list_len = ((data[0] as usize) << 8) | (data[1] as usize);
// Server name type (0 = hostname)
if data[2] != 0x00 {
return SniResult::NoSni;
}
// Hostname length
let name_len = ((data[3] as usize) << 8) | (data[4] as usize);
if data.len() < 5 + name_len {
return SniResult::NeedMoreData;
}
match std::str::from_utf8(&data[5..5 + name_len]) {
Ok(hostname) => SniResult::Found(hostname.to_lowercase()),
Err(_) => SniResult::NoSni,
}
}
/// Check if the initial bytes look like a TLS ClientHello.
pub fn is_tls(data: &[u8]) -> bool {
data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03
}
/// Extract the HTTP request path from initial data.
/// E.g., from "GET /foo/bar HTTP/1.1\r\n..." returns Some("/foo/bar").
pub fn extract_http_path(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
// Find first space (after method)
let method_end = text.find(' ')?;
let rest = &text[method_end + 1..];
// Find end of path (next space before "HTTP/...")
let path_end = rest.find(' ').unwrap_or(rest.len());
let path = &rest[..path_end];
// Strip query string for path matching
let path = path.split('?').next().unwrap_or(path);
if path.starts_with('/') {
Some(path.to_string())
} else {
None
}
}
/// Extract the HTTP Host header from initial data.
/// E.g., from "GET / HTTP/1.1\r\nHost: example.com\r\n..." returns Some("example.com").
pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
// Strip port if present
let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() {
return Some(host.to_lowercase());
}
}
}
None
}
/// Check if the initial bytes look like HTTP.
pub fn is_http(data: &[u8]) -> bool {
if data.len() < 4 {
return false;
}
// Check for common HTTP methods
let starts = [
b"GET " as &[u8],
b"POST",
b"PUT ",
b"HEAD",
b"DELE",
b"PATC",
b"OPTI",
b"CONN",
];
starts.iter().any(|s| data.starts_with(s))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_not_tls() {
let http_data = b"GET / HTTP/1.1\r\n";
assert!(matches!(extract_sni(http_data), SniResult::NotTls));
}
#[test]
fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
}
#[test]
fn test_is_tls() {
assert!(is_tls(&[0x16, 0x03, 0x01]));
assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET"
}
#[test]
fn test_is_http() {
assert!(is_http(b"GET /"));
assert!(is_http(b"POST /api"));
assert!(!is_http(&[0x16, 0x03, 0x01]));
}
#[test]
fn test_real_client_hello() {
// A minimal TLS 1.2 ClientHello with SNI "example.com"
let client_hello: Vec<u8> = build_test_client_hello("example.com");
match extract_sni(&client_hello) {
SniResult::Found(sni) => assert_eq!(sni, "example.com"),
other => panic!("Expected Found, got {:?}", other),
}
}
/// Build a minimal TLS ClientHello for testing.
fn build_test_client_hello(hostname: &str) -> Vec<u8> {
let hostname_bytes = hostname.as_bytes();
// SNI extension
let sni_ext_data = {
let mut d = Vec::new();
// Server name list length
let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name
d.push(((name_entry_len >> 8) & 0xFF) as u8);
d.push((name_entry_len & 0xFF) as u8);
// Host name type = 0
d.push(0x00);
// Host name length
d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8);
d.push((hostname_bytes.len() & 0xFF) as u8);
// Host name
d.extend_from_slice(hostname_bytes);
d
};
// Extension: type=0x0000 (SNI), length, data
let sni_extension = {
let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data);
e
};
// Extensions block
let extensions = {
let mut ext = Vec::new();
ext.push(((sni_extension.len() >> 8) & 0xFF) as u8);
ext.push((sni_extension.len() & 0xFF) as u8);
ext.extend_from_slice(&sni_extension);
ext
};
// ClientHello body
let hello_body = {
let mut h = Vec::new();
// Client version TLS 1.2
h.push(0x03); h.push(0x03);
// Random (32 bytes)
h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0
h.push(0x00);
// Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01); h.push(0x00);
// Extensions
h.extend_from_slice(&extensions);
h
};
// Handshake: type=1 (ClientHello), length
let handshake = {
let mut hs = Vec::new();
hs.push(0x01); // ClientHello
// 3-byte length
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
hs.push((hello_body.len() & 0xFF) as u8);
hs.extend_from_slice(&hello_body);
hs
};
// TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new();
record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake);
record
}
}

View File

@@ -0,0 +1,126 @@
//! Socket handler relay for connecting client connections to a TypeScript handler
//! via a Unix domain socket.
//!
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
use tokio::net::UnixStream;
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::net::TcpStream;
use serde::Serialize;
use tracing::debug;
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct RelayMetadata {
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data_base64: Option<String>,
}
/// Relay a client connection to a TypeScript handler via Unix domain socket.
///
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
pub async fn relay_to_handler(
client: TcpStream,
relay_socket_path: &str,
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data: Option<&[u8]>,
) -> std::io::Result<()> {
debug!(
"Relaying connection {} to handler socket {}",
connection_id, relay_socket_path
);
// Connect to TypeScript handler Unix socket
let mut handler = UnixStream::connect(relay_socket_path).await?;
// Build and send metadata header
let initial_data_base64 = initial_data.map(base64_encode);
let metadata = RelayMetadata {
connection_id,
remote_ip,
remote_port,
local_port,
sni,
route_name,
initial_data_base64,
};
let metadata_json = serde_json::to_string(&metadata)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
handler.write_all(metadata_json.as_bytes()).await?;
handler.write_all(b"\n").await?;
// Bidirectional relay between client and handler
let (mut client_read, mut client_write) = client.into_split();
let (mut handler_read, mut handler_write) = handler.into_split();
let c2h = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if handler_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = handler_write.shutdown().await;
});
let h2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match handler_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = client_write.shutdown().await;
});
let _ = tokio::join!(c2h, h2c);
debug!("Relay connection {} completed", connection_id);
Ok(())
}
/// Simple base64 encoding without external dependency.
fn base64_encode(data: &[u8]) -> String {
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
for chunk in data.chunks(3) {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let n = (b0 << 16) | (b1 << 8) | b2;
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
if chunk.len() > 1 {
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
} else {
result.push('=');
}
if chunk.len() > 2 {
result.push(CHARS[(n & 0x3F) as usize] as char);
} else {
result.push('=');
}
}
result
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
use std::collections::HashMap;
use std::io::BufReader;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
/// Ensure the default crypto provider is installed.
fn ensure_crypto_provider() {
let _ = rustls::crypto::ring::default_provider().install_default();
}
/// SNI-based certificate resolver with pre-parsed CertifiedKeys.
/// Enables shared ServerConfig across connections — avoids per-connection PEM parsing
/// and enables TLS session resumption.
#[derive(Debug)]
pub struct CertResolver {
certs: HashMap<String, Arc<CertifiedKey>>,
fallback: Option<Arc<CertifiedKey>>,
}
impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new();
let mut fallback = None;
for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
if domain == "*" {
fallback = Some(Arc::clone(&ck));
}
certs.insert(domain.clone(), ck);
}
// If no explicit "*" fallback, use the first available cert
if fallback.is_none() {
fallback = certs.values().next().map(Arc::clone);
}
Ok(Self { certs, fallback })
}
}
impl ResolvesServerCert for CertResolver {
fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let domain = match client_hello.server_name() {
Some(name) => name,
None => return self.fallback.clone(),
};
// Exact match
if let Some(ck) = self.certs.get(domain) {
return Some(Arc::clone(ck));
}
// Wildcard: sub.example.com → *.example.com
if let Some(dot) = domain.find('.') {
let wc = format!("*.{}", &domain[dot + 1..]);
if let Some(ck) = self.certs.get(&wc) {
return Some(Arc::clone(ck));
}
}
self.fallback.clone()
}
}
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver));
// Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new()
.map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096) and ticket support");
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
pub fn build_tls_acceptor_with_config(
cert_pem: &str,
key_pem: &str,
tls_config: Option<&rustproxy_config::RouteTls>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?;
let mut config = if let Some(route_tls) = tls_config {
// Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder
.with_no_client_auth()
.with_single_cert(certs, key)?
} else {
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?
};
// Apply session timeout if configured
if let Some(route_tls) = tls_config {
if let Some(timeout_secs) = route_tls.session_timeout {
config.session_storage = rustls::server::ServerSessionMemoryCache::new(
256, // max sessions
);
debug!("TLS session timeout configured: {}s", timeout_secs);
}
}
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions {
Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
};
let mut result = Vec::new();
for v in versions {
match v.as_str() {
"TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => {
if !result.contains(&&rustls::version::TLS12) {
result.push(&rustls::version::TLS12);
}
}
"TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => {
if !result.contains(&&rustls::version::TLS13) {
result.push(&rustls::version::TLS13);
}
}
other => {
debug!("Unknown TLS version '{}', ignoring", other);
}
}
}
if result.is_empty() {
// Fallback to both if no valid versions specified
vec![&rustls::version::TLS12, &rustls::version::TLS13]
} else {
result
}
}
/// Accept a TLS connection from a client stream.
pub async fn accept_tls(
stream: TcpStream,
acceptor: &TlsAcceptor,
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
let tls_stream = acceptor.accept(stream).await?;
debug!("TLS handshake completed");
Ok(tls_stream)
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
pub async fn connect_tls(
host: &str,
port: u16,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
let tls_stream = connector.connect(server_name, stream).await?;
debug!("Backend TLS connection established to {}:{}", host, port);
Ok(tls_stream)
}
/// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in PEM data".into());
}
Ok(certs)
}
/// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)?
.ok_or("No private key found in PEM data")?;
Ok(key)
}
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
/// In internal networks, backends may use self-signed certs.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}

View File

@@ -0,0 +1,16 @@
[package]
name = "rustproxy-routing"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "Route matching engine for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
glob-match = { workspace = true }
ipnet = { workspace = true }
regex = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
arc-swap = { workspace = true }

View File

@@ -0,0 +1,9 @@
//! # rustproxy-routing
//!
//! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
pub mod route_manager;
pub mod matchers;
pub use route_manager::*;

View File

@@ -0,0 +1,86 @@
/// Match a domain against a pattern supporting wildcards.
///
/// Supported patterns:
/// - `*` matches any domain
/// - `*.example.com` matches any subdomain of example.com
/// - `example.com` exact match
/// - `**.example.com` matches any depth of subdomain
pub fn domain_matches(pattern: &str, domain: &str) -> bool {
let pattern = pattern.trim().to_lowercase();
let domain = domain.trim().to_lowercase();
if pattern == "*" {
return true;
}
if pattern == domain {
return true;
}
// Wildcard patterns
if pattern.starts_with("*.") {
let suffix = &pattern[2..]; // e.g., "example.com"
// Match exact parent or any single-level subdomain
if domain == suffix {
return true;
}
if domain.ends_with(&format!(".{}", suffix)) {
// Check it's a single level subdomain for `*.`
let prefix = &domain[..domain.len() - suffix.len() - 1];
return !prefix.contains('.');
}
return false;
}
if pattern.starts_with("**.") {
let suffix = &pattern[3..];
// Match exact parent or any depth of subdomain
return domain == suffix || domain.ends_with(&format!(".{}", suffix));
}
// Use glob-match for more complex patterns
glob_match::glob_match(&pattern, &domain)
}
/// Check if a domain matches any of the given patterns.
pub fn domain_matches_any(patterns: &[&str], domain: &str) -> bool {
patterns.iter().any(|p| domain_matches(p, domain))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_match() {
assert!(domain_matches("example.com", "example.com"));
assert!(!domain_matches("example.com", "other.com"));
}
#[test]
fn test_wildcard_all() {
assert!(domain_matches("*", "anything.com"));
assert!(domain_matches("*", "sub.domain.example.com"));
}
#[test]
fn test_wildcard_subdomain() {
assert!(domain_matches("*.example.com", "www.example.com"));
assert!(domain_matches("*.example.com", "api.example.com"));
assert!(domain_matches("*.example.com", "example.com"));
assert!(!domain_matches("*.example.com", "deep.sub.example.com"));
}
#[test]
fn test_double_wildcard() {
assert!(domain_matches("**.example.com", "www.example.com"));
assert!(domain_matches("**.example.com", "deep.sub.example.com"));
assert!(domain_matches("**.example.com", "example.com"));
}
#[test]
fn test_case_insensitive() {
assert!(domain_matches("Example.COM", "example.com"));
assert!(domain_matches("*.EXAMPLE.com", "WWW.example.COM"));
}
}

View File

@@ -0,0 +1,98 @@
use std::collections::HashMap;
use regex::Regex;
/// Match HTTP headers against a set of patterns.
///
/// Pattern values can be:
/// - Exact string: `"application/json"`
/// - Regex (surrounded by /): `"/^text\/.*/"`
pub fn headers_match(
patterns: &HashMap<String, String>,
headers: &HashMap<String, String>,
) -> bool {
for (key, pattern) in patterns {
let key_lower = key.to_lowercase();
// Find the header (case-insensitive)
let header_value = headers
.iter()
.find(|(k, _)| k.to_lowercase() == key_lower)
.map(|(_, v)| v.as_str());
let header_value = match header_value {
Some(v) => v,
None => return false, // Required header not present
};
// Check if pattern is a regex (surrounded by /)
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
let regex_str = &pattern[1..pattern.len() - 1];
match Regex::new(regex_str) {
Ok(re) => {
if !re.is_match(header_value) {
return false;
}
}
Err(_) => {
// Invalid regex, fall back to exact match
if header_value != pattern {
return false;
}
}
}
} else {
// Exact match
if header_value != pattern {
return false;
}
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_header_match() {
let patterns: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("Content-Type".to_string(), "application/json".to_string());
m
};
let headers: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("content-type".to_string(), "application/json".to_string());
m
};
assert!(headers_match(&patterns, &headers));
}
#[test]
fn test_regex_header_match() {
let patterns: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("Content-Type".to_string(), "/^text\\/.*/".to_string());
m
};
let headers: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("content-type".to_string(), "text/html".to_string());
m
};
assert!(headers_match(&patterns, &headers));
}
#[test]
fn test_missing_header() {
let patterns: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("X-Custom".to_string(), "value".to_string());
m
};
let headers: HashMap<String, String> = HashMap::new();
assert!(!headers_match(&patterns, &headers));
}
}

View File

@@ -0,0 +1,126 @@
use std::net::IpAddr;
use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern.
///
/// Supported patterns:
/// - `*` matches any IP
/// - `192.168.1.0/24` CIDR range
/// - `192.168.1.100` exact match
/// - `192.168.1.*` wildcard (converted to CIDR)
/// - `::ffff:192.168.1.100` IPv6-mapped IPv4
pub fn ip_matches(pattern: &str, ip: &str) -> bool {
let pattern = pattern.trim();
if pattern == "*" {
return true;
}
// Normalize IPv4-mapped IPv6
let normalized_ip = normalize_ip_str(ip);
// Try CIDR match
if pattern.contains('/') {
if let Ok(net) = IpNet::from_str(pattern) {
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
return net.contains(&addr);
}
}
return false;
}
// Handle wildcard patterns like 192.168.1.*
if pattern.contains('*') {
let pattern_cidr = wildcard_to_cidr(pattern);
if let Some(cidr) = pattern_cidr {
if let Ok(net) = IpNet::from_str(&cidr) {
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
return net.contains(&addr);
}
}
}
return false;
}
// Exact match
let normalized_pattern = normalize_ip_str(pattern);
normalized_ip == normalized_pattern
}
/// Check if an IP matches any of the given patterns.
pub fn ip_matches_any(patterns: &[String], ip: &str) -> bool {
patterns.iter().any(|p| ip_matches(p, ip))
}
/// Normalize IPv4-mapped IPv6 addresses.
fn normalize_ip_str(ip: &str) -> String {
let ip = ip.trim();
if ip.starts_with("::ffff:") {
return ip[7..].to_string();
}
ip.to_string()
}
/// Convert a wildcard IP pattern to CIDR notation.
/// e.g., "192.168.1.*" -> "192.168.1.0/24"
fn wildcard_to_cidr(pattern: &str) -> Option<String> {
let parts: Vec<&str> = pattern.split('.').collect();
if parts.len() != 4 {
return None;
}
let mut octets = [0u8; 4];
let mut prefix_len = 0;
for (i, part) in parts.iter().enumerate() {
if *part == "*" {
break;
}
if let Ok(n) = part.parse::<u8>() {
octets[i] = n;
prefix_len += 8;
} else {
return None;
}
}
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_wildcard_all() {
assert!(ip_matches("*", "192.168.1.100"));
assert!(ip_matches("*", "::1"));
}
#[test]
fn test_exact_match() {
assert!(ip_matches("192.168.1.100", "192.168.1.100"));
assert!(!ip_matches("192.168.1.100", "192.168.1.101"));
}
#[test]
fn test_cidr() {
assert!(ip_matches("192.168.1.0/24", "192.168.1.100"));
assert!(ip_matches("192.168.1.0/24", "192.168.1.1"));
assert!(!ip_matches("192.168.1.0/24", "192.168.2.1"));
}
#[test]
fn test_wildcard_pattern() {
assert!(ip_matches("192.168.1.*", "192.168.1.100"));
assert!(ip_matches("192.168.1.*", "192.168.1.1"));
assert!(!ip_matches("192.168.1.*", "192.168.2.1"));
}
#[test]
fn test_ipv6_mapped() {
assert!(ip_matches("192.168.1.100", "::ffff:192.168.1.100"));
assert!(ip_matches("192.168.1.0/24", "::ffff:192.168.1.50"));
}
}

View File

@@ -0,0 +1,9 @@
pub mod domain;
pub mod path;
pub mod ip;
pub mod header;
pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*;

View File

@@ -0,0 +1,65 @@
/// Match a URL path against a pattern supporting wildcards.
///
/// Supported patterns:
/// - `/api/*` matches `/api/anything` (single level)
/// - `/api/**` matches `/api/any/depth/here`
/// - `/exact/path` exact match
/// - `/prefix*` prefix match
pub fn path_matches(pattern: &str, path: &str) -> bool {
// Exact match
if pattern == path {
return true;
}
// Double-star: match any depth
if pattern.ends_with("/**") {
let prefix = &pattern[..pattern.len() - 3];
return path == prefix || path.starts_with(&format!("{}/", prefix));
}
// Single-star at end: match single path segment
if pattern.ends_with("/*") {
let prefix = &pattern[..pattern.len() - 2];
if path == prefix {
return true;
}
if path.starts_with(&format!("{}/", prefix)) {
let rest = &path[prefix.len() + 1..];
// Single level means no more slashes
return !rest.contains('/');
}
return false;
}
// Star anywhere: use glob matching
if pattern.contains('*') {
return glob_match::glob_match(pattern, path);
}
false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_path() {
assert!(path_matches("/api/users", "/api/users"));
assert!(!path_matches("/api/users", "/api/posts"));
}
#[test]
fn test_single_wildcard() {
assert!(path_matches("/api/*", "/api/users"));
assert!(path_matches("/api/*", "/api/posts"));
assert!(!path_matches("/api/*", "/api/users/123"));
}
#[test]
fn test_double_wildcard() {
assert!(path_matches("/api/**", "/api/users"));
assert!(path_matches("/api/**", "/api/users/123"));
assert!(path_matches("/api/**", "/api/users/123/posts"));
}
}

View File

@@ -0,0 +1,776 @@
use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
use crate::matchers;
/// Context for route matching (subset of connection info).
pub struct MatchContext<'a> {
pub port: u16,
pub domain: Option<&'a str>,
pub path: Option<&'a str>,
pub client_ip: Option<&'a str>,
pub tls_version: Option<&'a str>,
pub headers: Option<&'a HashMap<String, String>>,
pub is_tls: bool,
/// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination).
pub protocol: Option<&'a str>,
}
/// Result of a route match.
pub struct RouteMatchResult<'a> {
pub route: &'a RouteConfig,
pub target: Option<&'a RouteTarget>,
}
/// Port-indexed route lookup with priority-based matching.
/// This is the core routing engine.
pub struct RouteManager {
/// Routes indexed by port for O(1) port lookup.
port_index: HashMap<u16, Vec<usize>>,
/// All routes, sorted by priority (highest first).
routes: Vec<RouteConfig>,
}
impl RouteManager {
/// Create a new RouteManager from a list of routes.
pub fn new(routes: Vec<RouteConfig>) -> Self {
let mut manager = Self {
port_index: HashMap::new(),
routes: Vec::new(),
};
// Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes
.into_iter()
.filter(|r| r.is_enabled())
.collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index
for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() {
manager.port_index
.entry(port)
.or_default()
.push(idx);
}
}
manager.routes = enabled_routes;
manager
}
/// Find the best matching route for the given context.
pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option<RouteMatchResult<'a>> {
// Get routes for this port
let route_indices = self.port_index.get(&ctx.port)?;
for &idx in route_indices {
let route = &self.routes[idx];
if self.matches_route(route, ctx) {
// Find the best matching target within the route
let target = self.find_target(route, ctx);
return Some(RouteMatchResult { route, target });
}
}
None
}
/// Check if a route matches the given context.
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
let rm = &route.route_match;
// Domain matching
if let Some(ref domains) = rm.domains {
if let Some(domain) = ctx.domain {
let patterns = domains.to_vec();
if !matchers::domain_matches_any(&patterns, domain) {
return false;
}
} else if ctx.is_tls {
// TLS connection without SNI cannot match a domain-restricted route.
// This prevents session-ticket resumption from misrouting when clients
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
let patterns = domains.to_vec();
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
if !is_wildcard_only {
return false;
}
}
}
// Path matching
if let Some(ref pattern) = rm.path {
if let Some(path) = ctx.path {
if !matchers::path_matches(pattern, path) {
return false;
}
} else {
// Route requires path but none provided
return false;
}
}
// Client IP matching
if let Some(ref client_ips) = rm.client_ip {
if let Some(ip) = ctx.client_ip {
if !matchers::ip_matches_any(client_ips, ip) {
return false;
}
} else {
return false;
}
}
// TLS version matching
if let Some(ref tls_versions) = rm.tls_version {
if let Some(version) = ctx.tls_version {
if !tls_versions.iter().any(|v| v == version) {
return false;
}
} else {
return false;
}
}
// Header matching
if let Some(ref patterns) = rm.headers {
if let Some(headers) = ctx.headers {
if !matchers::headers_match(patterns, headers) {
return false;
}
} else {
return false;
}
}
// Protocol matching
if let Some(ref required_protocol) = rm.protocol {
if let Some(protocol) = ctx.protocol {
if required_protocol != protocol {
return false;
}
}
// If protocol not yet known (None), allow match — protocol will be
// validated after detection (post-TLS-termination peek)
}
true
}
/// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() {
return Some(&targets[0]);
}
// Sort candidates by priority (already in order from config)
let mut best: Option<&RouteTarget> = None;
let mut best_priority = i32::MIN;
for target in targets {
let priority = target.priority.unwrap_or(0);
if let Some(ref tm) = target.target_match {
if !self.matches_target(tm, ctx) {
continue;
}
}
if priority > best_priority || best.is_none() {
best = Some(target);
best_priority = priority;
}
}
// Fall back to first target without match criteria
best.or_else(|| {
targets.iter().find(|t| t.target_match.is_none())
})
}
/// Check if a target match criteria matches the context.
fn matches_target(
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
// Port matching
if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) {
return false;
}
}
// Path matching
if let Some(ref pattern) = tm.path {
if let Some(path) = ctx.path {
if !matchers::path_matches(pattern, path) {
return false;
}
} else {
return false;
}
}
// Header matching
if let Some(ref patterns) = tm.headers {
if let Some(headers) = ctx.headers {
if !matchers::headers_match(patterns, headers) {
return false;
}
} else {
return false;
}
}
true
}
/// Get all unique listening ports.
pub fn listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.port_index.keys().copied().collect();
ports.sort();
ports
}
/// Get all routes for a specific port.
pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> {
self.port_index
.get(&port)
.map(|indices| indices.iter().map(|&i| &self.routes[i]).collect())
.unwrap_or_default()
}
/// Get the total number of enabled routes.
pub fn route_count(&self) -> usize {
self.routes.len()
}
/// Check if any route on the given port requires SNI.
pub fn port_requires_sni(&self, port: u16) -> bool {
let routes = self.routes_for_port(port);
// If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes
.iter()
.filter(|r| {
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.collect();
if passthrough_routes.len() > 1 {
return true;
}
// Single passthrough route with specific domain restriction needs SNI
if let Some(route) = passthrough_routes.first() {
if let Some(ref domains) = route.route_match.domains {
let domain_list = domains.to_vec();
// If it's not just a wildcard, SNI is needed
if !domain_list.iter().all(|d| *d == "*") {
return true;
}
}
}
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use rustproxy_config::*;
fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
domains: domain.map(|d| DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single("localhost".to_string()),
port: PortSpec::Fixed(8080),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: Some(priority),
tags: None,
enabled: None,
}
}
#[test]
fn test_basic_routing() {
let routes = vec![
make_route(80, Some("example.com"), 0),
make_route(80, Some("other.com"), 0),
];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: Some("example.com"),
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
let result = manager.find_route(&ctx);
assert!(result.is_some());
}
#[test]
fn test_priority_ordering() {
let routes = vec![
make_route(80, Some("*.example.com"), 0),
make_route(80, Some("api.example.com"), 10), // Higher priority
];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: Some("api.example.com"),
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
let result = manager.find_route(&ctx).unwrap();
// Should match the higher-priority specific route
assert!(result.route.route_match.domains.as_ref()
.map(|d| d.to_vec())
.unwrap()
.contains(&"api.example.com"));
}
#[test]
fn test_no_match() {
let routes = vec![make_route(80, Some("example.com"), 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443, // Different port
domain: Some("example.com"),
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_disabled_routes_excluded() {
let mut route = make_route(80, Some("example.com"), 0);
route.enabled = Some(false);
let manager = RouteManager::new(vec![route]);
assert_eq!(manager.route_count(), 0);
}
#[test]
fn test_listening_ports() {
let routes = vec![
make_route(80, Some("a.com"), 0),
make_route(443, Some("b.com"), 0),
make_route(80, Some("c.com"), 0), // duplicate port
];
let manager = RouteManager::new(routes);
let ports = manager.listening_ports();
assert_eq!(ports, vec![80, 443]);
}
#[test]
fn test_port_requires_sni_single_passthrough() {
let mut route = make_route(443, Some("example.com"), 0);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
let manager = RouteManager::new(vec![route]);
// Single passthrough route with specific domain needs SNI
assert!(manager.port_requires_sni(443));
}
#[test]
fn test_port_requires_sni_wildcard_only() {
let mut route = make_route(443, Some("*"), 0);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
let manager = RouteManager::new(vec![route]);
// Single passthrough route with wildcard doesn't need SNI
assert!(!manager.port_requires_sni(443));
}
#[test]
fn test_routes_for_port() {
let routes = vec![
make_route(80, Some("a.com"), 0),
make_route(80, Some("b.com"), 0),
make_route(443, Some("c.com"), 0),
];
let manager = RouteManager::new(routes);
assert_eq!(manager.routes_for_port(80).len(), 2);
assert_eq!(manager.routes_for_port(443).len(), 1);
assert_eq!(manager.routes_for_port(8080).len(), 0);
}
#[test]
fn test_wildcard_domain_matches_any() {
let routes = vec![make_route(80, Some("*"), 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: Some("anything.example.com"),
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_tls_no_sni_rejects_domain_restricted_route() {
let routes = vec![make_route(443, Some("example.com"), 0)];
let manager = RouteManager::new(routes);
// TLS connection without SNI should NOT match a domain-restricted route
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_tls_no_sni_rejects_wildcard_subdomain_route() {
let routes = vec![make_route(443, Some("*.example.com"), 0)];
let manager = RouteManager::new(routes);
// TLS connection without SNI should NOT match *.example.com
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_tls_no_sni_matches_wildcard_only_route() {
let routes = vec![make_route(443, Some("*"), 0)];
let manager = RouteManager::new(routes);
// TLS connection without SNI SHOULD match a wildcard-only route
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_tls_no_sni_skips_domain_restricted_matches_fallback() {
// Two routes: first is domain-restricted, second is wildcard catch-all
let routes = vec![
make_route(443, Some("specific.com"), 10),
make_route(443, Some("*"), 0),
];
let manager = RouteManager::new(routes);
// TLS without SNI should skip specific.com and fall through to wildcard
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
};
let result = manager.find_route(&ctx);
assert!(result.is_some());
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
.map(|d| d.to_vec()).unwrap();
assert!(matched_domains.contains(&"*"));
}
#[test]
fn test_non_tls_no_domain_still_matches_domain_restricted() {
// Non-TLS (plain HTTP) without domain should still match domain-restricted routes
// (the HTTP proxy layer handles Host-based routing)
let routes = vec![make_route(80, Some("example.com"), 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_no_domain_route_matches_any_domain() {
let routes = vec![make_route(80, None, 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: Some("example.com"),
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_target_sub_matching() {
let mut route = make_route(80, Some("example.com"), 0);
route.action.targets = Some(vec![
RouteTarget {
target_match: Some(rustproxy_config::TargetMatch {
ports: None,
path: Some("/api/*".to_string()),
headers: None,
method: None,
}),
host: HostSpec::Single("api-backend".to_string()),
port: PortSpec::Fixed(3000),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: Some(10),
},
RouteTarget {
target_match: None,
host: HostSpec::Single("default-backend".to_string()),
port: PortSpec::Fixed(8080),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
},
]);
let manager = RouteManager::new(vec![route]);
// Should match the API target
let ctx = MatchContext {
port: 80,
domain: Some("example.com"),
path: Some("/api/users"),
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "api-backend");
// Should fall back to default target
let ctx = MatchContext {
port: 80,
domain: Some("example.com"),
path: Some("/home"),
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "default-backend");
}
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
let mut route = make_route(port, domain, 0);
route.route_match.protocol = protocol.map(|s| s.to_string());
route
}
#[test]
fn test_protocol_http_matches_http() {
let routes = vec![make_route_with_protocol(80, None, Some("http"))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("http"),
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_protocol_http_rejects_tcp() {
let routes = vec![make_route_with_protocol(80, None, Some("http"))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("tcp"),
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_protocol_none_matches_any() {
// Route with no protocol restriction matches any protocol
let routes = vec![make_route_with_protocol(80, None, None)];
let manager = RouteManager::new(routes);
let ctx_http = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("http"),
};
assert!(manager.find_route(&ctx_http).is_some());
let ctx_tcp = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("tcp"),
};
assert!(manager.find_route(&ctx_tcp).is_some());
}
#[test]
fn test_protocol_http_matches_when_unknown() {
// Route with protocol: "http" should match when ctx.protocol is None
// (pre-TLS-termination, protocol not yet known)
let routes = vec![make_route_with_protocol(443, None, Some("http"))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
};
assert!(manager.find_route(&ctx).is_some());
}
}

View File

@@ -0,0 +1,17 @@
[package]
name = "rustproxy-security"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "IP filtering, rate limiting, and authentication for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
dashmap = { workspace = true }
ipnet = { workspace = true }
jsonwebtoken = { workspace = true }
base64 = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }

View File

@@ -0,0 +1,111 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
/// Basic auth validator.
pub struct BasicAuthValidator {
users: Vec<(String, String)>,
realm: String,
}
impl BasicAuthValidator {
pub fn new(users: Vec<(String, String)>, realm: Option<String>) -> Self {
Self {
users,
realm: realm.unwrap_or_else(|| "Restricted".to_string()),
}
}
/// Validate an Authorization header value.
/// Returns the username if valid.
pub fn validate(&self, auth_header: &str) -> Option<String> {
let auth_header = auth_header.trim();
if !auth_header.starts_with("Basic ") {
return None;
}
let encoded = &auth_header[6..];
let decoded = BASE64.decode(encoded).ok()?;
let credentials = String::from_utf8(decoded).ok()?;
let mut parts = credentials.splitn(2, ':');
let username = parts.next()?;
let password = parts.next()?;
for (u, p) in &self.users {
if u == username && p == password {
return Some(username.to_string());
}
}
None
}
/// Get the realm for WWW-Authenticate header.
pub fn realm(&self) -> &str {
&self.realm
}
/// Generate the WWW-Authenticate header value.
pub fn www_authenticate(&self) -> String {
format!("Basic realm=\"{}\"", self.realm)
}
}
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
fn make_validator() -> BasicAuthValidator {
BasicAuthValidator::new(
vec![
("admin".to_string(), "secret".to_string()),
("user".to_string(), "pass".to_string()),
],
Some("TestRealm".to_string()),
)
}
fn encode_basic(user: &str, pass: &str) -> String {
let encoded = BASE64.encode(format!("{}:{}", user, pass));
format!("Basic {}", encoded)
}
#[test]
fn test_valid_credentials() {
let validator = make_validator();
let header = encode_basic("admin", "secret");
assert_eq!(validator.validate(&header), Some("admin".to_string()));
}
#[test]
fn test_invalid_password() {
let validator = make_validator();
let header = encode_basic("admin", "wrong");
assert_eq!(validator.validate(&header), None);
}
#[test]
fn test_not_basic_scheme() {
let validator = make_validator();
assert_eq!(validator.validate("Bearer sometoken"), None);
}
#[test]
fn test_malformed_base64() {
let validator = make_validator();
assert_eq!(validator.validate("Basic !!!not-base64!!!"), None);
}
#[test]
fn test_www_authenticate_format() {
let validator = make_validator();
assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\"");
}
#[test]
fn test_default_realm() {
let validator = BasicAuthValidator::new(vec![], None);
assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\"");
}
}

View File

@@ -0,0 +1,189 @@
use ipnet::IpNet;
use std::net::IpAddr;
use std::str::FromStr;
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
pub struct IpFilter {
allow_list: Vec<IpPattern>,
block_list: Vec<IpPattern>,
}
/// Represents an IP pattern for matching.
#[derive(Debug)]
enum IpPattern {
/// Exact IP match
Exact(IpAddr),
/// CIDR range match
Cidr(IpNet),
/// Wildcard (matches everything)
Wildcard,
}
impl IpPattern {
fn parse(s: &str) -> Self {
let s = s.trim();
if s == "*" {
return IpPattern::Wildcard;
}
if let Ok(net) = IpNet::from_str(s) {
return IpPattern::Cidr(net);
}
if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr);
}
// Try as CIDR by appending default prefix
if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr);
}
// Fallback: treat as exact, will never match an invalid string
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
}
fn matches(&self, ip: &IpAddr) -> bool {
match self {
IpPattern::Wildcard => true,
IpPattern::Exact(addr) => addr == ip,
IpPattern::Cidr(net) => net.contains(ip),
}
}
}
impl IpFilter {
/// Create a new IP filter from allow and block lists.
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
Self {
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
/// Check if an IP is allowed.
/// If allow_list is non-empty, IP must match at least one entry.
/// If block_list is non-empty, IP must NOT match any entry.
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
// Check block list first
if !self.block_list.is_empty() {
for pattern in &self.block_list {
if pattern.matches(ip) {
return false;
}
}
}
// If allow list is non-empty, must match at least one
if !self.allow_list.is_empty() {
return self.allow_list.iter().any(|p| p.matches(ip));
}
true
}
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpAddr::V4(v4)
} else {
*ip
}
}
_ => *ip,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_lists_allow_all() {
let filter = IpFilter::new(&[], &[]);
let ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&ip));
}
#[test]
fn test_allow_list_exact() {
let filter = IpFilter::new(
&["10.0.0.1".to_string()],
&[],
);
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
let denied: IpAddr = "10.0.0.2".parse().unwrap();
assert!(filter.is_allowed(&allowed));
assert!(!filter.is_allowed(&denied));
}
#[test]
fn test_allow_list_cidr() {
let filter = IpFilter::new(
&["10.0.0.0/8".to_string()],
&[],
);
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
let denied: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&allowed));
assert!(!filter.is_allowed(&denied));
}
#[test]
fn test_block_list() {
let filter = IpFilter::new(
&[],
&["192.168.1.100".to_string()],
);
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
assert!(!filter.is_allowed(&blocked));
assert!(filter.is_allowed(&allowed));
}
#[test]
fn test_block_trumps_allow() {
let filter = IpFilter::new(
&["10.0.0.0/8".to_string()],
&["10.0.0.5".to_string()],
);
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
assert!(!filter.is_allowed(&blocked));
assert!(filter.is_allowed(&allowed));
}
#[test]
fn test_wildcard_allow() {
let filter = IpFilter::new(
&["*".to_string()],
&[],
);
let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(filter.is_allowed(&ip));
}
#[test]
fn test_wildcard_block() {
let filter = IpFilter::new(
&[],
&["*".to_string()],
);
let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(!filter.is_allowed(&ip));
}
#[test]
fn test_normalize_ipv4_mapped_ipv6() {
let mapped: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
let normalized = IpFilter::normalize_ip(&mapped);
let expected: IpAddr = "192.168.1.1".parse().unwrap();
assert_eq!(normalized, expected);
}
#[test]
fn test_normalize_pure_ipv4() {
let ip: IpAddr = "10.0.0.1".parse().unwrap();
let normalized = IpFilter::normalize_ip(&ip);
assert_eq!(normalized, ip);
}
}

View File

@@ -0,0 +1,174 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure).
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: Option<String>,
pub exp: Option<u64>,
pub iss: Option<String>,
pub aud: Option<String>,
}
/// JWT auth validator.
pub struct JwtValidator {
decoding_key: DecodingKey,
validation: Validation,
}
impl JwtValidator {
pub fn new(
secret: &str,
algorithm: Option<&str>,
issuer: Option<&str>,
audience: Option<&str>,
) -> Self {
let algo = match algorithm {
Some("HS384") => Algorithm::HS384,
Some("HS512") => Algorithm::HS512,
Some("RS256") => Algorithm::RS256,
_ => Algorithm::HS256,
};
let mut validation = Validation::new(algo);
if let Some(iss) = issuer {
validation.set_issuer(&[iss]);
}
if let Some(aud) = audience {
validation.set_audience(&[aud]);
}
Self {
decoding_key: DecodingKey::from_secret(secret.as_bytes()),
validation,
}
}
/// Validate a JWT token string (without "Bearer " prefix).
/// Returns the claims if valid.
pub fn validate(&self, token: &str) -> Result<Claims, String> {
decode::<Claims>(token, &self.decoding_key, &self.validation)
.map(|data| data.claims)
.map_err(|e| e.to_string())
}
/// Extract token from Authorization header.
pub fn extract_token(auth_header: &str) -> Option<&str> {
let header = auth_header.trim();
if header.starts_with("Bearer ") {
Some(&header[7..])
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use jsonwebtoken::{encode, EncodingKey, Header};
fn make_token(secret: &str, claims: &Claims) -> String {
encode(
&Header::default(),
claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap()
}
fn future_exp() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600
}
fn past_exp() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
- 3600
}
#[test]
fn test_valid_token() {
let secret = "test-secret";
let claims = Claims {
sub: Some("user123".to_string()),
exp: Some(future_exp()),
iss: None,
aud: None,
};
let token = make_token(secret, &claims);
let validator = JwtValidator::new(secret, None, None, None);
let result = validator.validate(&token);
assert!(result.is_ok());
assert_eq!(result.unwrap().sub, Some("user123".to_string()));
}
#[test]
fn test_expired_token() {
let secret = "test-secret";
let claims = Claims {
sub: Some("user123".to_string()),
exp: Some(past_exp()),
iss: None,
aud: None,
};
let token = make_token(secret, &claims);
let validator = JwtValidator::new(secret, None, None, None);
assert!(validator.validate(&token).is_err());
}
#[test]
fn test_wrong_secret() {
let claims = Claims {
sub: Some("user123".to_string()),
exp: Some(future_exp()),
iss: None,
aud: None,
};
let token = make_token("correct-secret", &claims);
let validator = JwtValidator::new("wrong-secret", None, None, None);
assert!(validator.validate(&token).is_err());
}
#[test]
fn test_issuer_validation() {
let secret = "test-secret";
let claims = Claims {
sub: Some("user123".to_string()),
exp: Some(future_exp()),
iss: Some("my-issuer".to_string()),
aud: None,
};
let token = make_token(secret, &claims);
// Correct issuer
let validator = JwtValidator::new(secret, None, Some("my-issuer"), None);
assert!(validator.validate(&token).is_ok());
// Wrong issuer
let validator = JwtValidator::new(secret, None, Some("other-issuer"), None);
assert!(validator.validate(&token).is_err());
}
#[test]
fn test_extract_token_bearer() {
assert_eq!(
JwtValidator::extract_token("Bearer abc123"),
Some("abc123")
);
}
#[test]
fn test_extract_token_non_bearer() {
assert_eq!(JwtValidator::extract_token("Basic abc123"), None);
assert_eq!(JwtValidator::extract_token("abc123"), None);
}
}

View File

@@ -0,0 +1,13 @@
//! # rustproxy-security
//!
//! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth;
pub mod jwt_auth;
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*;
pub use jwt_auth::*;

View File

@@ -0,0 +1,97 @@
use dashmap::DashMap;
use std::time::Instant;
/// Sliding window rate limiter.
pub struct RateLimiter {
/// Map of key -> list of request timestamps
windows: DashMap<String, Vec<Instant>>,
/// Maximum requests per window
max_requests: u64,
/// Window duration in seconds
window_seconds: u64,
}
impl RateLimiter {
pub fn new(max_requests: u64, window_seconds: u64) -> Self {
Self {
windows: DashMap::new(),
max_requests,
window_seconds,
}
}
/// Check if a request is allowed for the given key.
/// Returns true if allowed, false if rate limited.
pub fn check(&self, key: &str) -> bool {
let now = Instant::now();
let window = std::time::Duration::from_secs(self.window_seconds);
let mut entry = self.windows.entry(key.to_string()).or_default();
let timestamps = entry.value_mut();
// Remove expired entries
timestamps.retain(|t| now.duration_since(*t) < window);
if timestamps.len() as u64 >= self.max_requests {
false
} else {
timestamps.push(now);
true
}
}
/// Clean up expired entries (call periodically).
pub fn cleanup(&self) {
let now = Instant::now();
let window = std::time::Duration::from_secs(self.window_seconds);
self.windows.retain(|_, timestamps| {
timestamps.retain(|t| now.duration_since(*t) < window);
!timestamps.is_empty()
});
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_allow_under_limit() {
let limiter = RateLimiter::new(5, 60);
for _ in 0..5 {
assert!(limiter.check("client-1"));
}
}
#[test]
fn test_block_over_limit() {
let limiter = RateLimiter::new(3, 60);
assert!(limiter.check("client-1"));
assert!(limiter.check("client-1"));
assert!(limiter.check("client-1"));
assert!(!limiter.check("client-1")); // 4th request blocked
}
#[test]
fn test_different_keys_independent() {
let limiter = RateLimiter::new(2, 60);
assert!(limiter.check("client-a"));
assert!(limiter.check("client-a"));
assert!(!limiter.check("client-a")); // blocked
// Different key should still be allowed
assert!(limiter.check("client-b"));
assert!(limiter.check("client-b"));
}
#[test]
fn test_cleanup_removes_expired() {
let limiter = RateLimiter::new(100, 0); // 0 second window = immediately expired
limiter.check("client-1");
// Sleep briefly to let entries expire
std::thread::sleep(std::time::Duration::from_millis(10));
limiter.cleanup();
// After cleanup, the key should be allowed again (entries expired)
assert!(limiter.check("client-1"));
}
}

View File

@@ -0,0 +1,20 @@
[package]
name = "rustproxy-tls"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "TLS certificate management for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
rustls = { workspace = true }
instant-acme = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
serde = { workspace = true }
rcgen = { workspace = true }
[dev-dependencies]

View File

@@ -0,0 +1,275 @@
//! ACME (Let's Encrypt) integration using instant-acme.
//!
//! This module handles HTTP-01 challenge creation and certificate provisioning.
//! Account credentials are ephemeral — the consumer owns all persistence.
use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
AccountCredentials,
};
use rcgen::{CertificateParams, KeyPair};
use thiserror::Error;
use tracing::{debug, info};
#[derive(Debug, Error)]
pub enum AcmeError {
#[error("ACME account creation failed: {0}")]
AccountCreation(String),
#[error("ACME order failed: {0}")]
OrderFailed(String),
#[error("Challenge failed: {0}")]
ChallengeFailed(String),
#[error("Certificate finalization failed: {0}")]
FinalizationFailed(String),
#[error("No HTTP-01 challenge found")]
NoHttp01Challenge,
#[error("Timeout waiting for order: {0}")]
Timeout(String),
}
/// Pending HTTP-01 challenge that needs to be served.
pub struct PendingChallenge {
pub token: String,
pub key_authorization: String,
pub domain: String,
}
/// ACME client wrapper around instant-acme.
pub struct AcmeClient {
use_production: bool,
email: String,
}
impl AcmeClient {
pub fn new(email: String, use_production: bool) -> Self {
Self {
use_production,
email,
}
}
/// Create a new ACME account (ephemeral — not persisted).
async fn get_or_create_account(&self) -> Result<Account, AcmeError> {
let directory_url = self.directory_url();
let contact = format!("mailto:{}", self.email);
let (account, _credentials) = Account::create(
&NewAccount {
contact: &[&contact],
terms_of_service_agreed: true,
only_return_existing: false,
},
directory_url,
None,
)
.await
.map_err(|e| AcmeError::AccountCreation(e.to_string()))?;
debug!("ACME account created");
Ok(account)
}
/// Request a certificate for a domain using the HTTP-01 challenge.
///
/// Returns (cert_chain_pem, private_key_pem) on success.
///
/// The caller must serve the HTTP-01 challenge at:
/// `http://<domain>/.well-known/acme-challenge/<token>`
///
/// The `challenge_handler` closure is called with a `PendingChallenge`
/// and must arrange for the challenge response to be served. It should
/// return once the challenge is ready to be validated.
pub async fn provision<F, Fut>(
&self,
domain: &str,
challenge_handler: F,
) -> Result<(String, String), AcmeError>
where
F: FnOnce(PendingChallenge) -> Fut,
Fut: std::future::Future<Output = Result<(), AcmeError>>,
{
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
// 1. Get or create ACME account
let account = self.get_or_create_account().await?;
// 2. Create order
let identifier = Identifier::Dns(domain.to_string());
let mut order = account
.new_order(&NewOrder {
identifiers: &[identifier],
})
.await
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
debug!("ACME order created");
// 3. Get authorizations and find HTTP-01 challenge
let authorizations = order
.authorizations()
.await
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
// Find the HTTP-01 challenge
let (challenge_token, challenge_url) = authorizations
.iter()
.flat_map(|auth| auth.challenges.iter())
.find(|c| c.r#type == ChallengeType::Http01)
.map(|c| {
let key_auth = order.key_authorization(c);
(
PendingChallenge {
token: c.token.clone(),
key_authorization: key_auth.as_str().to_string(),
domain: domain.to_string(),
},
c.url.clone(),
)
})
.ok_or(AcmeError::NoHttp01Challenge)?;
// Call the handler to set up challenge serving
challenge_handler(challenge_token).await?;
// 4. Notify ACME server that challenge is ready
order
.set_challenge_ready(&challenge_url)
.await
.map_err(|e| AcmeError::ChallengeFailed(e.to_string()))?;
debug!("Challenge marked as ready, waiting for validation...");
// 5. Poll for order to become ready
let mut attempts = 0;
let state = loop {
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
let state = order
.refresh()
.await
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
match state.status {
OrderStatus::Ready | OrderStatus::Valid => break state.status,
OrderStatus::Invalid => {
return Err(AcmeError::ChallengeFailed(
"Order became invalid (challenge failed)".to_string(),
));
}
_ => {
attempts += 1;
if attempts > 30 {
return Err(AcmeError::Timeout(
"Order did not become ready within 60 seconds".to_string(),
));
}
}
}
};
debug!("Order ready, finalizing...");
// 6. Generate CSR and finalize
let key_pair = KeyPair::generate().map_err(|e| {
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
})?;
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
})?;
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
let csr = params.serialize_request(&key_pair).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
})?;
if state == OrderStatus::Ready {
order
.finalize(csr.der())
.await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?;
}
// 7. Wait for certificate to be issued
let mut attempts = 0;
loop {
let state = order
.refresh()
.await
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
if state.status == OrderStatus::Valid {
break;
}
if state.status == OrderStatus::Invalid {
return Err(AcmeError::FinalizationFailed(
"Order became invalid during finalization".to_string(),
));
}
attempts += 1;
if attempts > 15 {
return Err(AcmeError::Timeout(
"Certificate not issued within 30 seconds".to_string(),
));
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
}
// 8. Download certificate
let cert_chain_pem = order
.certificate()
.await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
.ok_or_else(|| {
AcmeError::FinalizationFailed("No certificate returned".to_string())
})?;
let private_key_pem = key_pair.serialize_pem();
info!("Certificate provisioned successfully for {}", domain);
Ok((cert_chain_pem, private_key_pem))
}
/// Restore an ACME account from stored credentials.
pub async fn restore_account(
&self,
credentials: AccountCredentials,
) -> Result<Account, AcmeError> {
Account::from_credentials(credentials)
.await
.map_err(|e| AcmeError::AccountCreation(e.to_string()))
}
/// Get the ACME directory URL based on production/staging.
pub fn directory_url(&self) -> &str {
if self.use_production {
"https://acme-v02.api.letsencrypt.org/directory"
} else {
"https://acme-staging-v02.api.letsencrypt.org/directory"
}
}
/// Whether this client is configured for production.
pub fn is_production(&self) -> bool {
self.use_production
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_directory_url_staging() {
let client = AcmeClient::new("test@example.com".to_string(), false);
assert!(client.directory_url().contains("staging"));
assert!(!client.is_production());
}
#[test]
fn test_directory_url_production() {
let client = AcmeClient::new("test@example.com".to_string(), true);
assert!(!client.directory_url().contains("staging"));
assert!(client.is_production());
}
}

View File

@@ -0,0 +1,168 @@
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient;
#[derive(Debug, Error)]
pub enum CertManagerError {
#[error("ACME provisioning failed for {domain}: {message}")]
AcmeFailure { domain: String, message: String },
#[error("No ACME email configured")]
NoEmail,
}
/// Certificate lifecycle manager.
/// Handles ACME provisioning, static cert loading, and renewal.
pub struct CertManager {
store: CertStore,
acme_email: Option<String>,
use_production: bool,
renew_before_days: u32,
}
impl CertManager {
pub fn new(
store: CertStore,
acme_email: Option<String>,
use_production: bool,
renew_before_days: u32,
) -> Self {
Self {
store,
acme_email,
use_production,
renew_before_days,
}
}
/// Get a certificate for a domain (from cache).
pub fn get_cert(&self, domain: &str) -> Option<&CertBundle> {
self.store.get(domain)
}
/// Create an ACME client using this manager's configuration.
/// Returns None if no ACME email is configured.
pub fn acme_client(&self) -> Option<AcmeClient> {
self.acme_email.as_ref().map(|email| {
AcmeClient::new(email.clone(), self.use_production)
})
}
/// Load a static certificate into the store (infallible — pure cache insert).
pub fn load_static(
&mut self,
domain: String,
bundle: CertBundle,
) {
self.store.store(domain, bundle);
}
/// Check and return domains that need certificate renewal.
///
/// A certificate needs renewal if it expires within `renew_before_days`.
/// Returns a list of domain names needing renewal.
pub fn check_renewals(&self) -> Vec<String> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let renewal_threshold = self.renew_before_days as u64 * 86400;
let mut needs_renewal = Vec::new();
for (domain, bundle) in self.store.iter() {
// Only auto-renew ACME certs
if bundle.metadata.source != CertSource::Acme {
continue;
}
let time_until_expiry = bundle.metadata.expires_at.saturating_sub(now);
if time_until_expiry < renewal_threshold {
info!(
"Certificate for {} needs renewal (expires in {} days)",
domain,
time_until_expiry / 86400
);
needs_renewal.push(domain.clone());
}
}
needs_renewal
}
/// Renew a certificate for a domain.
///
/// Performs the full ACME provision+store flow. The `challenge_setup` closure
/// is called to arrange for the HTTP-01 challenge to be served. It receives
/// (token, key_authorization) and must make the challenge response available.
///
/// Returns the new CertBundle on success.
pub async fn renew_domain<F, Fut>(
&mut self,
domain: &str,
challenge_setup: F,
) -> Result<CertBundle, CertManagerError>
where
F: FnOnce(String, String) -> Fut,
Fut: std::future::Future<Output = ()>,
{
let acme_client = self.acme_client()
.ok_or(CertManagerError::NoEmail)?;
info!("Renewing certificate for {}", domain);
let domain_owned = domain.to_string();
let result = acme_client.provision(&domain_owned, |pending| {
let token = pending.token.clone();
let key_auth = pending.key_authorization.clone();
async move {
challenge_setup(token, key_auth).await;
Ok(())
}
}).await.map_err(|e| CertManagerError::AcmeFailure {
domain: domain.to_string(),
message: e.to_string(),
})?;
let (cert_pem, key_pem) = result;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem,
key_pem,
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Acme,
issued_at: now,
expires_at: now + 90 * 86400,
renewed_at: Some(now),
},
};
self.store.store(domain.to_string(), bundle.clone());
info!("Certificate renewed and stored for {}", domain);
Ok(bundle)
}
/// Whether this manager has an ACME email configured.
pub fn has_acme(&self) -> bool {
self.acme_email.is_some()
}
/// Get reference to the underlying store.
pub fn store(&self) -> &CertStore {
&self.store
}
/// Get mutable reference to the underlying store.
pub fn store_mut(&mut self) -> &mut CertStore {
&mut self.store
}
}

View File

@@ -0,0 +1,174 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Certificate metadata stored alongside certs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CertMetadata {
pub domain: String,
pub source: CertSource,
pub issued_at: u64,
pub expires_at: u64,
pub renewed_at: Option<u64>,
}
/// How a certificate was obtained.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CertSource {
Acme,
Static,
Custom,
SelfSigned,
}
/// An in-memory certificate bundle.
#[derive(Debug, Clone)]
pub struct CertBundle {
pub key_pem: String,
pub cert_pem: String,
pub ca_pem: Option<String>,
pub metadata: CertMetadata,
}
/// In-memory certificate store.
///
/// All persistence is owned by the consumer (TypeScript side).
/// This struct is a thin HashMap wrapper used as a runtime cache.
pub struct CertStore {
cache: HashMap<String, CertBundle>,
}
impl CertStore {
/// Create a new empty cert store.
pub fn new() -> Self {
Self {
cache: HashMap::new(),
}
}
/// Get a certificate by domain.
pub fn get(&self, domain: &str) -> Option<&CertBundle> {
self.cache.get(domain)
}
/// Store a certificate in the cache.
pub fn store(&mut self, domain: String, bundle: CertBundle) {
self.cache.insert(domain, bundle);
}
/// Check if a certificate exists for a domain.
pub fn has(&self, domain: &str) -> bool {
self.cache.contains_key(domain)
}
/// Get the number of cached certificates.
pub fn count(&self) -> usize {
self.cache.len()
}
/// Iterate over all cached certificates.
pub fn iter(&self) -> impl Iterator<Item = (&String, &CertBundle)> {
self.cache.iter()
}
/// Remove a certificate from the cache.
pub fn remove(&mut self, domain: &str) -> bool {
self.cache.remove(domain).is_some()
}
}
impl Default for CertStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Static,
issued_at: 1700000000,
expires_at: 1700000000 + 90 * 86400,
renewed_at: None,
},
}
}
#[test]
fn test_store_and_get() {
let mut store = CertStore::new();
let bundle = make_test_bundle("example.com");
store.store("example.com".to_string(), bundle.clone());
let loaded = store.get("example.com").unwrap();
assert_eq!(loaded.key_pem, bundle.key_pem);
assert_eq!(loaded.cert_pem, bundle.cert_pem);
assert_eq!(loaded.metadata.domain, "example.com");
assert_eq!(loaded.metadata.source, CertSource::Static);
}
#[test]
fn test_store_with_ca_cert() {
let mut store = CertStore::new();
let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle);
let loaded = store.get("secure.com").unwrap();
assert!(loaded.ca_pem.is_some());
}
#[test]
fn test_multiple_certs() {
let mut store = CertStore::new();
store.store("a.com".to_string(), make_test_bundle("a.com"));
store.store("b.com".to_string(), make_test_bundle("b.com"));
store.store("c.com".to_string(), make_test_bundle("c.com"));
assert_eq!(store.count(), 3);
assert!(store.has("a.com"));
assert!(store.has("b.com"));
assert!(store.has("c.com"));
}
#[test]
fn test_remove_cert() {
let mut store = CertStore::new();
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com");
assert!(removed);
assert!(!store.has("remove-me.com"));
}
#[test]
fn test_remove_nonexistent() {
let mut store = CertStore::new();
assert!(!store.remove("nonexistent.com"));
}
#[test]
fn test_wildcard_domain() {
let mut store = CertStore::new();
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
assert!(store.has("*.example.com"));
let loaded = store.get("*.example.com").unwrap();
assert_eq!(loaded.metadata.domain, "*.example.com");
}
}

View File

@@ -0,0 +1,13 @@
//! # rustproxy-tls
//!
//! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
pub mod cert_store;
pub mod cert_manager;
pub mod acme;
pub mod sni_resolver;
pub use cert_store::*;
pub use cert_manager::*;
pub use sni_resolver::*;

View File

@@ -0,0 +1,139 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::cert_store::CertBundle;
/// Dynamic SNI-based certificate resolver.
/// Used by the TLS stack to select the right certificate based on client SNI.
pub struct SniResolver {
/// Domain -> certificate bundle mapping
certs: RwLock<HashMap<String, Arc<CertBundle>>>,
/// Fallback certificate (used when no SNI or no match)
fallback: RwLock<Option<Arc<CertBundle>>>,
}
impl SniResolver {
pub fn new() -> Self {
Self {
certs: RwLock::new(HashMap::new()),
fallback: RwLock::new(None),
}
}
/// Register a certificate for a domain.
pub fn add_cert(&self, domain: String, bundle: CertBundle) {
let mut certs = self.certs.write().unwrap();
certs.insert(domain, Arc::new(bundle));
}
/// Set the fallback certificate.
pub fn set_fallback(&self, bundle: CertBundle) {
let mut fallback = self.fallback.write().unwrap();
*fallback = Some(Arc::new(bundle));
}
/// Resolve a certificate for the given SNI domain.
pub fn resolve(&self, domain: &str) -> Option<Arc<CertBundle>> {
let certs = self.certs.read().unwrap();
// Try exact match
if let Some(bundle) = certs.get(domain) {
return Some(Arc::clone(bundle));
}
// Try wildcard match (e.g., *.example.com)
if let Some(dot_pos) = domain.find('.') {
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
if let Some(bundle) = certs.get(&wildcard) {
return Some(Arc::clone(bundle));
}
}
// Fallback
let fallback = self.fallback.read().unwrap();
fallback.clone()
}
/// Remove a certificate for a domain.
pub fn remove_cert(&self, domain: &str) {
let mut certs = self.certs.write().unwrap();
certs.remove(domain);
}
/// Get the number of registered certificates.
pub fn cert_count(&self) -> usize {
self.certs.read().unwrap().len()
}
}
impl Default for SniResolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cert_store::{CertBundle, CertMetadata, CertSource};
fn make_bundle(domain: &str) -> CertBundle {
CertBundle {
key_pem: format!("KEY-{}", domain),
cert_pem: format!("CERT-{}", domain),
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Static,
issued_at: 0,
expires_at: 0,
renewed_at: None,
},
}
}
#[test]
fn test_exact_domain_resolve() {
let resolver = SniResolver::new();
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
let result = resolver.resolve("example.com");
assert!(result.is_some());
assert_eq!(result.unwrap().cert_pem, "CERT-example.com");
}
#[test]
fn test_wildcard_resolve() {
let resolver = SniResolver::new();
resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
let result = resolver.resolve("sub.example.com");
assert!(result.is_some());
assert_eq!(result.unwrap().cert_pem, "CERT-*.example.com");
}
#[test]
fn test_fallback() {
let resolver = SniResolver::new();
resolver.set_fallback(make_bundle("fallback"));
let result = resolver.resolve("unknown.com");
assert!(result.is_some());
assert_eq!(result.unwrap().cert_pem, "CERT-fallback");
}
#[test]
fn test_no_match_no_fallback() {
let resolver = SniResolver::new();
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
let result = resolver.resolve("other.com");
assert!(result.is_none());
}
#[test]
fn test_remove_cert() {
let resolver = SniResolver::new();
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
assert_eq!(resolver.cert_count(), 1);
resolver.remove_cert("example.com");
assert_eq!(resolver.cert_count(), 0);
assert!(resolver.resolve("example.com").is_none());
}
}

View File

@@ -0,0 +1,44 @@
[package]
name = "rustproxy"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration"
[[bin]]
name = "rustproxy"
path = "src/main.rs"
[lib]
name = "rustproxy"
path = "src/lib.rs"
[dependencies]
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-tls = { workspace = true }
rustproxy-passthrough = { workspace = true }
rustproxy-http = { workspace = true }
rustproxy-nftables = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
clap = { workspace = true }
anyhow = { workspace = true }
arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
dashmap = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
http-body-util = { workspace = true }
bytes = { workspace = true }
[dev-dependencies]
rcgen = { workspace = true }

View File

@@ -0,0 +1,177 @@
//! HTTP-01 ACME challenge server.
//!
//! A lightweight HTTP server that serves ACME challenge responses at
//! `/.well-known/acme-challenge/<token>`.
use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
/// ACME HTTP-01 challenge server.
pub struct ChallengeServer {
/// Token -> key authorization mapping
challenges: Arc<DashMap<String, String>>,
/// Cancellation token to stop the server
cancel: CancellationToken,
/// Server task handle
handle: Option<tokio::task::JoinHandle<()>>,
}
impl ChallengeServer {
/// Create a new challenge server (not yet started).
pub fn new() -> Self {
Self {
challenges: Arc::new(DashMap::new()),
cancel: CancellationToken::new(),
handle: None,
}
}
/// Register a challenge token -> key_authorization mapping.
pub fn set_challenge(&self, token: String, key_authorization: String) {
debug!("Registered ACME challenge: token={}", token);
self.challenges.insert(token, key_authorization);
}
/// Remove a challenge token.
pub fn remove_challenge(&self, token: &str) {
self.challenges.remove(token);
}
/// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
let challenges = Arc::clone(&self.challenges);
let cancel = self.cancel.clone();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
_ = cancel.cancelled() => {
info!("ACME challenge server stopping");
break;
}
result = listener.accept() => {
match result {
Ok((stream, _)) => {
let challenges = Arc::clone(&challenges);
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let challenges = Arc::clone(&challenges);
async move {
Self::handle_request(req, &challenges)
}
});
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service);
if let Err(e) = conn.await {
debug!("Challenge server connection error: {}", e);
}
});
}
Err(e) => {
error!("Challenge server accept error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
}
}
}
});
self.handle = Some(handle);
Ok(())
}
/// Stop the challenge server.
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
info!("ACME challenge server stopped");
}
/// Handle an HTTP request for ACME challenges.
fn handle_request(
req: Request<Incoming>,
challenges: &DashMap<String, String>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let path = req.uri().path();
if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
if let Some(key_auth) = challenges.get(token) {
debug!("Serving ACME challenge for token: {}", token);
return Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/plain")
.body(Full::new(Bytes::from(key_auth.value().clone())))
.unwrap());
}
}
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found")))
.unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_challenge_server_lifecycle() {
let mut server = ChallengeServer::new();
// Set a challenge before starting
server.set_challenge("test-token".to_string(), "test-key-auth".to_string());
// Start on a random port
server.start(19900).await.unwrap();
// Give server a moment to start
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Test 404 for unknown token
let req = Request::get("/.well-known/acme-challenge/unknown")
.body(Full::new(Bytes::new()))
.unwrap();
let resp = sender.send_request(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
server.stop().await;
}
}

View File

@@ -0,0 +1,972 @@
//! # RustProxy
//!
//! High-performance multi-protocol proxy built on Rust,
//! compatible with SmartProxy configuration.
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use rustproxy::RustProxy;
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let options = RustProxyOptions {
//! routes: vec![
//! create_https_passthrough_route("example.com", "backend", 443),
//! ],
//! ..Default::default()
//! };
//!
//! let mut proxy = RustProxy::new(options)?;
//! proxy.start().await?;
//! Ok(())
//! }
//! ```
pub mod challenge_server;
pub mod management;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result;
use tracing::{info, warn, debug, error};
// Re-export key types
pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_nftables;
pub use rustproxy_metrics;
pub use rustproxy_security;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_nftables::{NftManager, rule_builder};
/// Certificate status.
#[derive(Debug, Clone)]
pub struct CertStatus {
pub domain: String,
pub source: String,
pub expires_at: u64,
pub is_valid: bool,
}
/// The main RustProxy struct.
/// This is the primary public API matching SmartProxy's interface.
pub struct RustProxy {
options: RustProxyOptions,
route_table: ArcSwap<RouteManager>,
listener_manager: Option<TcpListenerManager>,
metrics: Arc<MetricsCollector>,
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
challenge_server: Option<challenge_server::ChallengeServer>,
renewal_handle: Option<tokio::task::JoinHandle<()>>,
sampling_handle: Option<tokio::task::JoinHandle<()>>,
nft_manager: Option<NftManager>,
started: bool,
started_at: Option<Instant>,
/// Shared path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
/// Dynamically loaded certificates (via loadCertificate IPC), independent of CertManager.
loaded_certs: HashMap<String, TlsCertConfig>,
}
impl RustProxy {
/// Create a new RustProxy instance with the given configuration.
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
// Apply defaults to routes before validation
Self::apply_defaults(&mut options);
// Validate routes
if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
for err in &errors {
warn!("Route validation error: {}", err);
}
if !errors.is_empty() {
anyhow::bail!("Route validation failed with {} errors", errors.len());
}
}
let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options)
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let retention = options.metrics.as_ref()
.and_then(|m| m.retention_seconds)
.unwrap_or(3600) as usize;
Ok(Self {
options,
route_table: ArcSwap::from(Arc::new(route_manager)),
listener_manager: None,
metrics: Arc::new(MetricsCollector::with_retention(retention)),
cert_manager,
challenge_server: None,
renewal_handle: None,
sampling_handle: None,
nft_manager: None,
started: false,
started_at: None,
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
loaded_certs: HashMap::new(),
})
}
/// Apply default configuration to routes that lack targets or security.
fn apply_defaults(options: &mut RustProxyOptions) {
let defaults = match &options.defaults {
Some(d) => d.clone(),
None => return,
};
for route in &mut options.routes {
// Apply default target if route has no targets
if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}",
default_target.host, default_target.port,
route.name.as_deref().unwrap_or("unnamed"));
route.action.targets = Some(vec![
rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
port: rustproxy_config::PortSpec::Fixed(default_target.port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}
]);
}
}
// Apply default security if route has no security
if route.security.is_none() {
if let Some(ref default_security) = defaults.security {
let mut security = rustproxy_config::RouteSecurity {
ip_allow_list: None,
ip_block_list: None,
max_connections: default_security.max_connections,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
};
if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(allow_list.clone());
}
if let Some(ref block_list) = default_security.ip_block_list {
security.ip_block_list = Some(block_list.clone());
}
// Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed"));
route.security = Some(security);
}
}
}
}
}
/// Build a CertManager from options.
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
let acme = options.acme.as_ref()?;
if !acme.enabled.unwrap_or(false) {
return None;
}
let email = acme.email.clone()
.or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new();
Some(CertManager::new(store, email, use_production, renew_before_days))
}
/// Build ConnectionConfig from RustProxyOptions.
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
ConnectionConfig {
connection_timeout_ms: options.effective_connection_timeout(),
initial_data_timeout_ms: options.effective_initial_data_timeout(),
socket_timeout_ms: options.effective_socket_timeout(),
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
max_connections_per_ip: options.max_connections_per_ip,
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
keep_alive_treatment: options.keep_alive_treatment.clone(),
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
}
}
/// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> {
if self.started {
anyhow::bail!("Proxy is already started");
}
info!("Starting RustProxy...");
// Auto-provision certificates for routes with certificate: 'auto'
self.auto_provision_certificates().await;
let route_manager = self.route_table.load();
let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
// Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics(
Arc::clone(&*route_manager),
Arc::clone(&self.metrics),
);
// Apply connection config from options
let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms,
conn_config.max_connection_lifetime_ms,
);
listener.set_connection_config(conn_config);
// Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
// Extract TLS configurations from routes and cert manager
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
// Also load certs from cert manager into TLS config
if let Some(ref cm) = self.cert_manager {
let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
// Merge dynamically loaded certs (from loadCertificate IPC)
for (d, c) in &self.loaded_certs {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), c.clone());
}
}
if !tls_configs.is_empty() {
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
listener.set_tls_configs(tls_configs);
}
// Bind all ports
for port in &ports {
listener.add_port(*port).await?;
}
self.listener_manager = Some(listener);
self.started = true;
self.started_at = Some(Instant::now());
// Start the throughput sampling task
let metrics = Arc::clone(&self.metrics);
let interval_ms = self.options.metrics.as_ref()
.and_then(|m| m.sample_interval_ms)
.unwrap_or(1000);
self.sampling_handle = Some(tokio::spawn(async move {
let mut interval = tokio::time::interval(
std::time::Duration::from_millis(interval_ms)
);
loop {
interval.tick().await;
metrics.sample_all();
}
}));
// Apply NFTables rules for routes using nftables forwarding engine
self.apply_nftables_rules(&self.options.routes.clone()).await;
// Start renewal timer if ACME is enabled
self.start_renewal_timer();
info!("RustProxy started successfully on ports: {:?}", ports);
Ok(())
}
/// Auto-provision certificates for routes that use certificate: 'auto'.
async fn auto_provision_certificates(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let mut domains_to_provision = Vec::new();
for route in &self.options.routes {
let tls_mode = route.tls_mode();
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
let domain = domain.to_string();
// Skip if we already have a valid cert
let cm = cm_arc.lock().await;
if cm.store().has(&domain) {
debug!("Already have cert for {}, skipping auto-provision", domain);
continue;
}
drop(cm);
domains_to_provision.push(domain);
}
}
}
}
if domains_to_provision.is_empty() {
return;
}
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
// Start challenge server
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
return;
}
for domain in &domains_to_provision {
info!("Provisioning certificate for {}", domain);
let cm = cm_arc.lock().await;
let acme_client = cm.acme_client();
drop(cm);
if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| {
challenge_server_ref.set_challenge(
pending.token.clone(),
pending.key_authorization.clone(),
);
async move { Ok(()) }
}).await;
match result {
Ok((cert_pem, key_pem)) => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem,
key_pem,
ca_pem: None,
metadata: CertMetadata {
domain: domain.clone(),
source: CertSource::Acme,
issued_at: now,
expires_at: now + 90 * 86400, // 90 days
renewed_at: None,
},
};
let mut cm = cm_arc.lock().await;
cm.load_static(domain.clone(), bundle);
info!("Certificate provisioned for {}", domain);
}
Err(e) => {
error!("Failed to provision certificate for {}: {}", domain, e);
}
}
}
}
challenge_server.stop().await;
}
/// Start the renewal timer background task.
/// The background task checks for expiring certificates and renews them.
fn start_renewal_timer(&mut self) {
let cm_arc = match self.cert_manager {
Some(ref cm) => Arc::clone(cm),
None => return,
};
let auto_renew = self.options.acme.as_ref()
.and_then(|a| a.auto_renew)
.unwrap_or(true);
if !auto_renew {
return;
}
let check_interval_hours = self.options.acme.as_ref()
.and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24);
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
let handle = tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
// Check which domains need renewal
let domains = {
let cm = cm_arc.lock().await;
cm.check_renewals()
};
if domains.is_empty() {
debug!("No certificates need renewal");
continue;
}
info!("Renewing {} certificate(s)", domains.len());
// Start challenge server for renewals
let mut cs = challenge_server::ChallengeServer::new();
if let Err(e) = cs.start(acme_port).await {
error!("Failed to start challenge server for renewal: {}", e);
continue;
}
for domain in &domains {
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
match result {
Ok(_bundle) => {
info!("Successfully renewed certificate for {}", domain);
}
Err(e) => {
error!("Failed to renew certificate for {}: {}", domain, e);
}
}
}
cs.stop().await;
}
});
self.renewal_handle = Some(handle);
}
/// Stop the proxy gracefully.
pub async fn stop(&mut self) -> Result<()> {
if !self.started {
return Ok(());
}
info!("Stopping RustProxy...");
// Stop sampling task
if let Some(handle) = self.sampling_handle.take() {
handle.abort();
}
// Stop renewal timer
if let Some(handle) = self.renewal_handle.take() {
handle.abort();
}
// Stop challenge server if running
if let Some(ref mut cs) = self.challenge_server {
cs.stop().await;
}
self.challenge_server = None;
// Clean up NFTables rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup failed: {}", e);
}
}
self.nft_manager = None;
if let Some(ref mut listener) = self.listener_manager {
listener.graceful_stop().await;
}
self.listener_manager = None;
self.started = false;
info!("RustProxy stopped");
Ok(())
}
/// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes
rustproxy_config::validate_routes(&routes)
.map_err(|errors| {
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
})?;
let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports",
new_manager.route_count(), new_ports.len());
// Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
listener.listening_ports()
} else {
vec![]
};
// Atomically swap the route table
let new_manager = Arc::new(new_manager);
self.route_table.store(Arc::clone(&new_manager));
// Update listener manager
if let Some(ref mut listener) = self.listener_manager {
listener.update_route_manager(Arc::clone(&new_manager));
// Update TLS configs
let mut tls_configs = Self::extract_tls_configs(&routes);
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
}
}
}
// Merge dynamically loaded certs (from loadCertificate IPC)
for (d, c) in &self.loaded_certs {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), c.clone());
}
}
listener.set_tls_configs(tls_configs);
// Add new ports
for port in &new_ports {
if !old_ports.contains(port) {
listener.add_port(*port).await?;
}
}
// Remove old ports no longer needed
for port in &old_ports {
if !new_ports.contains(port) {
listener.remove_port(*port);
}
}
}
// Update NFTables rules: remove old, apply new
self.update_nftables_rules(&routes).await;
self.options.routes = routes;
Ok(())
}
/// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref()
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
// Find the route by name
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
// Start challenge server
let acme_port = self.options.acme.as_ref()
.and_then(|a| a.port)
.unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs;
let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| {
cs_ref.set_challenge(token, key_auth);
async {}
}).await;
drop(cm);
cs.stop().await;
let bundle = result
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
}
}
listener.set_tls_configs(tls_configs);
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
Ok(())
}
/// Renew a certificate for a named route.
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
// Renewal is just re-provisioning
self.provision_certificate(route_name).await
}
/// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter()
.find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
if let Some(bundle) = cm.get_cert(&domain) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
return Some(CertStatus {
domain,
source: format!("{:?}", bundle.metadata.source),
expires_at: bundle.metadata.expires_at,
is_valid: bundle.metadata.expires_at > now,
});
}
}
None
}
/// Get current metrics snapshot.
pub fn get_metrics(&self) -> Metrics {
self.metrics.snapshot()
}
/// Add a listening port at runtime.
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.add_port(port).await?;
}
Ok(())
}
/// Remove a listening port at runtime.
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
if let Some(ref mut listener) = self.listener_manager {
listener.remove_port(port);
}
Ok(())
}
/// Get all currently listening ports.
pub fn get_listening_ports(&self) -> Vec<u16> {
self.listener_manager
.as_ref()
.map(|l| l.listening_ports())
.unwrap_or_default()
}
/// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
Statistics {
active_connections: self.metrics.active_connections(),
total_connections: self.metrics.total_connections(),
routes_count: self.route_table.load().route_count() as u64,
listening_ports: self.get_listening_ports(),
uptime_seconds: uptime,
}
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections.
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
info!("Socket handler relay path set to: {:?}", path);
*self.socket_handler_relay.write().unwrap() = path;
}
/// Get the current socket handler relay path.
pub fn get_socket_handler_relay_path(&self) -> Option<String> {
self.socket_handler_relay.read().unwrap().clone()
}
/// Load a certificate for a domain and hot-swap the TLS configuration.
pub async fn load_certificate(
&mut self,
domain: &str,
cert_pem: String,
key_pem: String,
ca_pem: Option<String>,
) -> Result<()> {
info!("Loading certificate for domain: {}", domain);
// Store in cert manager if available
if let Some(ref cm_arc) = self.cert_manager {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let bundle = CertBundle {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
ca_pem: ca_pem.clone(),
metadata: CertMetadata {
domain: domain.to_string(),
source: CertSource::Static,
issued_at: now,
expires_at: now + 90 * 86400, // assume 90 days
renewed_at: None,
},
};
let mut cm = cm_arc.lock().await;
cm.load_static(domain.to_string(), bundle);
}
// Persist in loaded_certs so future rebuild calls include this cert
self.loaded_certs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
});
// Hot-swap TLS config on the listener
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
// Add the new cert
tls_configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
});
// Also include all existing certs from cert manager
if let Some(ref cm_arc) = self.cert_manager {
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig {
cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
});
}
}
}
// Merge dynamically loaded certs from previous loadCertificate calls
for (d, c) in &self.loaded_certs {
if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), c.clone());
}
}
listener.set_tls_configs(tls_configs);
}
info!("Certificate loaded and TLS config updated for {}", domain);
Ok(())
}
/// Get NFTables status.
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
match &self.nft_manager {
Some(nft) => Ok(nft.status()),
None => Ok(HashMap::new()),
}
}
/// Apply NFTables rules for routes using the nftables forwarding engine.
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
let nft_routes: Vec<&RouteConfig> = routes.iter()
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
.collect();
if nft_routes.is_empty() {
return;
}
info!("Applying NFTables rules for {} routes", nft_routes.len());
let table_name = nft_routes.iter()
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
.unwrap_or_else(|| "rustproxy".to_string());
let mut nft = NftManager::new(Some(table_name));
for route in &nft_routes {
let route_id = route.id.as_deref()
.or(route.name.as_deref())
.unwrap_or("unnamed");
let nft_options = match &route.action.nftables {
Some(opts) => opts.clone(),
None => rustproxy_config::NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
},
};
let targets = match &route.action.targets {
Some(targets) => targets,
None => {
warn!("NFTables route '{}' has no targets, skipping", route_id);
continue;
}
};
let source_ports = route.route_match.ports.to_ports();
for target in targets {
let target_host = target.host.first().to_string();
let target_port_spec = &target.port;
for &source_port in &source_ports {
let resolved_port = target_port_spec.resolve(source_port);
let rules = rule_builder::build_dnat_rule(
nft.table_name(),
"prerouting",
source_port,
&target_host,
resolved_port,
&nft_options,
);
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
}
}
}
}
self.nft_manager = Some(nft);
}
/// Update NFTables rules when routes change.
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
// Clean up old rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup during update failed: {}", e);
}
}
self.nft_manager = None;
// Apply new rules
self.apply_nftables_rules(new_routes).await;
}
/// Extract TLS configurations from route configs.
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
let mut configs = HashMap::new();
for route in routes {
let tls_mode = route.tls_mode();
let needs_cert = matches!(
tls_mode,
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
);
if !needs_cert {
continue;
}
let cert_spec = route.action.tls.as_ref()
.and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig {
cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
});
}
}
}
}
configs
}
}

View File

@@ -0,0 +1,95 @@
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy
#[derive(Parser, Debug)]
#[command(name = "rustproxy", version, about)]
struct Cli {
/// Path to JSON configuration file
#[arg(short, long, default_value = "config.json")]
config: String,
/// Log level (trace, debug, info, warn, error)
#[arg(short, long, default_value = "info")]
log_level: String,
/// Validate configuration without starting
#[arg(long)]
validate: bool,
/// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
#[arg(long)]
management: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
// Install the default CryptoProvider early, before any TLS or ACME code runs.
// This prevents panics from instant-acme/hyper-rustls calling ClientConfig::builder()
// before TLS listeners have started. Idempotent — later calls harmlessly return Err.
let _ = rustls::crypto::ring::default_provider().install_default();
let cli = Cli::parse();
// Initialize tracing - write to stderr so stdout is reserved for management IPC
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
)
.init();
// Management mode: JSON IPC over stdin/stdout
if cli.management {
tracing::info!("RustProxy starting in management mode...");
return management::management_loop().await;
}
tracing::info!("RustProxy starting...");
// Load configuration
let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!(
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
// Validate-only mode
if cli.validate {
match rustproxy_config::validate_routes(&options.routes) {
Ok(()) => {
tracing::info!("Configuration is valid");
return Ok(());
}
Err(errors) => {
for err in &errors {
tracing::error!("Validation error: {}", err);
}
anyhow::bail!("{} validation errors found", errors.len());
}
}
}
// Create and start proxy
let mut proxy = RustProxy::new(options)?;
proxy.start().await?;
// Wait for shutdown signal
tracing::info!("RustProxy is running. Press Ctrl+C to stop.");
tokio::signal::ctrl_c().await?;
tracing::info!("Shutdown signal received");
proxy.stop().await?;
tracing::info!("RustProxy shutdown complete");
Ok(())
}

View File

@@ -0,0 +1,470 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
/// A management request from the TypeScript wrapper.
#[derive(Debug, Deserialize)]
pub struct ManagementRequest {
pub id: String,
pub method: String,
#[serde(default)]
pub params: serde_json::Value,
}
/// A management response back to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementResponse {
pub id: String,
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// An unsolicited event from the proxy to the TypeScript wrapper.
#[derive(Debug, Serialize)]
pub struct ManagementEvent {
pub event: String,
pub data: serde_json::Value,
}
impl ManagementResponse {
fn ok(id: String, result: serde_json::Value) -> Self {
Self {
id,
success: true,
result: Some(result),
error: None,
}
}
fn err(id: String, message: String) -> Self {
Self {
id,
success: false,
result: None,
error: Some(message),
}
}
}
fn send_line(line: &str) {
// Use blocking stdout write - we're writing short JSON lines
use std::io::Write;
let stdout = std::io::stdout();
let mut handle = stdout.lock();
let _ = handle.write_all(line.as_bytes());
let _ = handle.write_all(b"\n");
let _ = handle.flush();
}
fn send_response(response: &ManagementResponse) {
match serde_json::to_string(response) {
Ok(json) => send_line(&json),
Err(e) => error!("Failed to serialize management response: {}", e),
}
}
fn send_event(event: &str, data: serde_json::Value) {
let evt = ManagementEvent {
event: event.to_string(),
data,
};
match serde_json::to_string(&evt) {
Ok(json) => send_line(&json),
Err(e) => error!("Failed to serialize management event: {}", e),
}
}
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
pub async fn management_loop() -> Result<()> {
let stdin = BufReader::new(tokio::io::stdin());
let mut lines = stdin.lines();
let mut proxy: Option<RustProxy> = None;
send_event("ready", serde_json::json!({}));
loop {
let line = match lines.next_line().await {
Ok(Some(line)) => line,
Ok(None) => {
// stdin closed - parent process exited
info!("Management stdin closed, shutting down");
if let Some(ref mut p) = proxy {
let _ = p.stop().await;
}
break;
}
Err(e) => {
error!("Error reading management stdin: {}", e);
break;
}
};
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
let request: ManagementRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
error!("Failed to parse management request: {}", e);
// Send error response without an ID
send_response(&ManagementResponse::err(
"unknown".to_string(),
format!("Failed to parse request: {}", e),
));
continue;
}
};
let response = handle_request(&request, &mut proxy).await;
send_response(&response);
}
Ok(())
}
async fn handle_request(
request: &ManagementRequest,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let id = request.id.clone();
match request.method.as_str() {
"start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
_ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)),
}
}
async fn handle_start(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
if proxy.is_some() {
return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string());
}
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
};
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
Ok(o) => o,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
};
match RustProxy::new(options) {
Ok(mut p) => {
match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
}
}
async fn handle_stop(
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_mut() {
Some(p) => {
match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_update_routes(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let routes = match params.get("routes") {
Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
};
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
Ok(r) => r,
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)),
};
match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
}
}
fn handle_get_metrics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let metrics = p.get_metrics();
match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
fn handle_get_statistics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let stats = p.get_statistics();
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
async fn handle_provision_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
}
}
async fn handle_renew_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
}
}
async fn handle_get_certificate_status(
id: &str,
params: &serde_json::Value,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_ref() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
};
match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
})),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
}
}
fn handle_get_listening_ports(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let ports = p.get_listening_ports();
ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports }))
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": [] })),
}
}
async fn handle_get_nftables_status(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
match p.get_nftables_status().await {
Ok(status) => {
match serde_json::to_value(&status) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_set_socket_handler_relay(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("setSocketHandlerRelay: socket_path={:?}", socket_path);
p.set_socket_handler_relay_path(socket_path);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_add_listening_port(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
};
match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
}
}
async fn handle_remove_listening_port(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
};
match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
}
}
async fn handle_load_certificate(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
};
let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
};
let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
};
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
}
}

View File

@@ -0,0 +1,553 @@
use std::sync::atomic::{AtomicU16, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
/// Atomic port allocator starting at 19000 to avoid collisions.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(19000);
/// Get the next available port for testing.
pub fn next_port() -> u16 {
PORT_COUNTER.fetch_add(1, Ordering::SeqCst)
}
/// Start a simple TCP echo server that echoes back whatever it receives.
/// Returns the join handle for the server task.
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind echo server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Start a TCP echo server that prefixes responses to identify which backend responded.
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
let prefix = prefix.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind prefix echo server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let pfx = prefix.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
let mut response = pfx.as_bytes().to_vec();
response.extend_from_slice(&buf[..n]);
if stream.write_all(&response).await.is_err() {
break;
}
}
});
}
})
}
/// Start a simple HTTP server that responds with a fixed status and body.
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
let body = body.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind HTTP server");
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let b = body.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
// Read the request
let _n = stream.read(&mut buf).await.unwrap_or(0);
// Send response
let response = format!(
"HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
b.len(),
b,
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.shutdown().await;
});
}
})
}
/// Start an HTTP backend server that echoes back request details as JSON.
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
/// Supports keep-alive by reading HTTP requests properly.
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
let name = backend_name.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let backend = name.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 16384];
// Read request data
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Parse first line: METHOD PATH HTTP/x.x
let first_line = req_str.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.split_whitespace().collect();
let method = parts.first().copied().unwrap_or("UNKNOWN");
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
let body = format!(
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
method, path, host, backend
);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body,
);
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.shutdown().await;
});
}
})
}
/// Wrap a future with a timeout, preventing tests from hanging.
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
where
F: std::future::Future<Output = T>,
{
match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await {
Ok(result) => Ok(result),
Err(_) => Err("Test timed out"),
}
}
/// Wait briefly for a server to be ready by attempting TCP connections.
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
while start.elapsed() < timeout {
if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.is_ok()
{
return true;
}
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
false
}
/// Start a TLS HTTP echo backend: accepts TLS, then responds with HTTP JSON
/// containing request details. Combines TLS acceptance with HTTP echo behavior.
pub async fn start_tls_http_backend(
port: u16,
backend_name: &str,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
.expect("Failed to build TLS acceptor");
let acceptor = Arc::new(acceptor);
let name = backend_name.to_string();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind TLS HTTP backend on port {}", port));
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acc = acceptor.clone();
let backend = name.clone();
tokio::spawn(async move {
let mut tls_stream = match acc.accept(stream).await {
Ok(s) => s,
Err(_) => return,
};
let mut buf = vec![0u8; 16384];
let n = match tls_stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Parse first line: METHOD PATH HTTP/x.x
let first_line = req_str.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.split_whitespace().collect();
let method = parts.first().copied().unwrap_or("UNKNOWN");
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
let body = format!(
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
method, path, host, backend
);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body,
);
let _ = tls_stream.write_all(response.as_bytes()).await;
let _ = tls_stream.shutdown().await;
});
}
})
}
/// Helper to create a minimal route config for testing.
pub fn make_test_route(
port: u16,
domain: Option<&str>,
target_host: &str,
target_port: u16,
) -> rustproxy_config::RouteConfig {
rustproxy_config::RouteConfig {
id: None,
route_match: rustproxy_config::RouteMatch {
ports: rustproxy_config::PortRange::Single(port),
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: rustproxy_config::RouteAction {
action_type: rustproxy_config::RouteActionType::Forward,
targets: Some(vec![rustproxy_config::RouteTarget {
target_match: None,
host: rustproxy_config::HostSpec::Single(target_host.to_string()),
port: rustproxy_config::PortSpec::Fixed(target_port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Start a simple WebSocket echo backend.
///
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
/// then echoes all data received on the connection.
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
tokio::spawn(async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
tokio::spawn(async move {
// Read the HTTP upgrade request
let mut buf = vec![0u8; 4096];
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
// Compute Sec-WebSocket-Accept (simplified - just echo for test purposes)
// Real implementation would compute SHA-1 + base64
let accept_response = format!(
"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n",
ws_key
);
if stream.write_all(accept_response.as_bytes()).await.is_err() {
return;
}
// Echo all data back (raw TCP after upgrade)
let mut echo_buf = vec![0u8; 65536];
loop {
let n = match stream.read(&mut echo_buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if stream.write_all(&echo_buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Generate a self-signed certificate for testing using rcgen.
/// Returns (cert_pem, key_pem).
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
(cert.pem(), key_pair.serialize_pem())
}
/// Start a TLS echo server using the given cert/key.
/// Returns the join handle.
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
.expect("Failed to build TLS acceptor");
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("Failed to bind TLS echo server");
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acc = acceptor.clone();
tokio::spawn(async move {
let mut tls_stream = match acc.accept(stream).await {
Ok(s) => s,
Err(_) => return,
};
let mut buf = vec![0u8; 65536];
loop {
let n = match tls_stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if tls_stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Helper to create a TLS terminate route with static cert for testing.
pub fn make_tls_terminate_route(
port: u16,
domain: &str,
target_host: &str,
target_port: u16,
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_test_route(port, Some(domain), target_host, target_port);
route.action.tls = Some(rustproxy_config::RouteTls {
mode: rustproxy_config::TlsMode::Terminate,
certificate: Some(rustproxy_config::CertificateSpec::Static(
rustproxy_config::CertificateConfig {
cert: cert_pem.to_string(),
key: key_pem.to_string(),
ca: None,
key_file: None,
cert_file: None,
},
)),
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend(
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
.expect("Failed to build TLS acceptor");
let acceptor = Arc::new(acceptor);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap_or_else(|_| panic!("Failed to bind TLS WS echo backend on port {}", port));
tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => break,
};
let acc = acceptor.clone();
tokio::spawn(async move {
let mut tls_stream = match acc.accept(stream).await {
Ok(s) => s,
Err(_) => return,
};
// Read the HTTP upgrade request
let mut buf = vec![0u8; 4096];
let n = match tls_stream.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for handshake
let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
// Send 101 Switching Protocols
let accept_response = format!(
"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: {}\r\n\
\r\n",
ws_key
);
if tls_stream
.write_all(accept_response.as_bytes())
.await
.is_err()
{
return;
}
// Echo all data back (raw TCP after upgrade)
let mut echo_buf = vec![0u8; 65536];
loop {
let n = match tls_stream.read(&mut echo_buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if tls_stream.write_all(&echo_buf[..n]).await.is_err() {
break;
}
}
});
}
})
}
/// Helper to create a TLS passthrough route for testing.
pub fn make_tls_passthrough_route(
port: u16,
domain: Option<&str>,
target_host: &str,
target_port: u16,
) -> rustproxy_config::RouteConfig {
let mut route = make_test_route(port, domain, target_host, target_port);
route.action.tls = Some(rustproxy_config::RouteTls {
mode: rustproxy_config::TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}

View File

@@ -0,0 +1,752 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Send a raw HTTP request and return the full response as a string.
async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let request = format!(
"{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
method, path, host,
);
stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}
/// Extract the body from a raw HTTP response string (after the \r\n\r\n).
fn extract_body(response: &str) -> &str {
response.split("\r\n\r\n").nth(1).unwrap_or("")
}
#[tokio::test]
async fn test_http_forward_basic() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "main").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
}, 10)
.await
.unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_host_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_http_echo_backend(backend1_port, "alpha").await;
let _b2 = start_http_echo_backend(backend2_port, "beta").await;
let options = RustProxyOptions {
routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let alpha_result = with_timeout(async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
// Test beta domain
let beta_result = with_timeout(async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_path_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_http_echo_backend(backend1_port, "api").await;
let _b2 = start_http_echo_backend(backend2_port, "web").await;
let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port);
api_route.route_match.path = Some("/api/**".to_string());
api_route.priority = Some(10);
let web_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port);
let options = RustProxyOptions {
routes: vec![api_route, web_route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test API path
let api_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
// Test web path (no /api prefix)
let web_result = with_timeout(async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_cors_preflight() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "main").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send CORS preflight request
let request = format!(
"OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n",
);
stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
// Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
assert!(result.to_lowercase().contains("access-control-allow-origin"),
"Expected CORS header, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_backend_error() {
let backend_port = next_port();
let proxy_port = next_port();
// Start an HTTP server that returns 500
let _backend = start_http_server(backend_port, 500, "Internal Error").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
}, 10)
.await
.unwrap();
// Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_no_route_matched() {
let proxy_port = next_port();
// Create a route only for a specific domain
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
}, 10)
.await
.unwrap();
// Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_http_forward_backend_unavailable() {
let proxy_port = next_port();
let dead_port = next_port(); // No server running here
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
}, 10)
.await
.unwrap();
// Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_https_terminate_http_forward() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "httpproxy.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_http_echo_backend(backend_port, "tls-backend").await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send HTTP request through TLS
let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain
);
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_websocket_through_proxy() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_ws_echo_backend(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send WebSocket upgrade request
let request = format!(
"GET /ws HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n"
);
stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; }
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break;
}
}
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap();
// Read echoed data
let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "ok");
proxy.stop().await.unwrap();
}
/// Test that terminate-and-reencrypt mode routes HTTP traffic through the
/// full HTTP proxy with per-request Host-based routing.
///
/// This verifies the new behavior: after TLS termination, HTTP data is detected
/// and routed through HttpProxyService (like nginx) instead of being blindly tunneled.
#[tokio::test]
async fn test_terminate_and_reencrypt_http_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
let (cert2, key2) = generate_self_signed_cert("beta.example.com");
// Generate separate backend certs (backends are independent TLS servers)
let (backend_cert1, backend_key1) = generate_self_signed_cert("localhost");
let (backend_cert2, backend_key2) = generate_self_signed_cert("localhost");
// Start TLS HTTP echo backends (proxy re-encrypts to these)
let _b1 = start_tls_http_backend(backend1_port, "alpha", &backend_cert1, &backend_key1).await;
let _b2 = start_tls_http_backend(backend2_port, "beta", &backend_cert2, &backend_key2).await;
// Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
);
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
);
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let options = RustProxyOptions {
routes: vec![route1, route2],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
let alpha_body = extract_body(&alpha_result);
assert!(
alpha_body.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_body
);
assert!(
alpha_body.contains(r#""method":"GET"#),
"Expected GET method, got: {}",
alpha_body
);
assert!(
alpha_body.contains(r#""path":"/api/data"#),
"Expected /api/data path, got: {}",
alpha_body
);
// Verify original Host header is preserved (not replaced with backend IP:port)
assert!(
alpha_body.contains(r#""host":"alpha.example.com"#),
"Expected original Host header alpha.example.com, got: {}",
alpha_body
);
// Test beta domain - different host goes to different backend
let beta_result = with_timeout(async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
.await
.unwrap();
let beta_body = extract_body(&beta_result);
assert!(
beta_body.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_body
);
assert!(
beta_body.contains(r#""path":"/other"#),
"Expected /other path, got: {}",
beta_body
);
// Verify original Host header is preserved for beta too
assert!(
beta_body.contains(r#""host":"beta.example.com"#),
"Expected original Host header beta.example.com, got: {}",
beta_body
);
proxy.stop().await.unwrap();
}
/// Test that WebSocket upgrade works through terminate-and-reencrypt mode.
///
/// Verifies the full chain: client→TLS→proxy terminates→re-encrypts→TLS→backend WebSocket.
/// The proxy's `handle_websocket_upgrade` checks `upstream.use_tls` and calls
/// `connect_tls_backend()` when true. This test covers that path.
#[tokio::test]
async fn test_terminate_and_reencrypt_websocket() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "ws.example.com";
// Frontend cert (client→proxy TLS)
let (frontend_cert, frontend_key) = generate_self_signed_cert(domain);
// Backend cert (proxy→backend TLS)
let (backend_cert, backend_key) = generate_self_signed_cert("localhost");
// Start TLS WebSocket echo backend
let _backend = start_tls_ws_echo_backend(backend_port, &backend_cert, &backend_key).await;
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port,
domain,
"127.0.0.1",
backend_port,
&frontend_cert,
&frontend_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector =
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS
let request = format!(
"GET /ws HTTP/1.1\r\n\
Host: {}\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n",
domain
);
tls_stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response (byte-by-byte until \r\n\r\n)
let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1];
loop {
let n = tls_stream.read(&mut temp).await.unwrap();
if n == 0 {
break;
}
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len - 4..] == *b"\r\n\r\n" {
break;
}
}
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(
response_str.contains("101"),
"Expected 101 Switching Protocols, got: {}",
response_str
);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo
let test_data = b"Hello TLS WebSocket!";
tls_stream.write_all(test_data).await.unwrap();
// Read echoed data
let mut echo_buf = vec![0u8; 256];
let n = tls_stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
},
10,
)
.await
.unwrap();
assert_eq!(result, "ok");
proxy.stop().await.unwrap();
}
/// Test that the protocol field on route config is accepted and processed.
#[tokio::test]
async fn test_protocol_field_in_route_config() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "main").await;
// Create a route with protocol: "http" - should only match HTTP traffic
let mut route = make_test_route(proxy_port, None, "127.0.0.1", backend_port);
route.route_match.protocol = Some("http".to_string());
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied
let result = with_timeout(async {
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
extract_body(&response).to_string()
}, 10)
.await
.unwrap();
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
assert!(
result.contains(r#""path":"/test"#),
"Expected /test path, got: {}",
result
);
proxy.stop().await.unwrap();
}
/// InsecureVerifier for test TLS client connections.
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}

View File

@@ -0,0 +1,250 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[tokio::test]
async fn test_start_and_stop() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
// Not listening before start
assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
proxy.stop().await.unwrap();
// Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
}
#[tokio::test]
async fn test_double_start_fails() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Second start should fail
let result = proxy.start().await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("already started"));
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_update_routes_hot_reload() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Update routes atomically
let new_routes = vec![
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
];
let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok());
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_add_remove_listening_port() {
let port1 = next_port();
let port2 = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
// Add a new port
proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
// Remove the port
proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
// Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_get_statistics() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
let stats = proxy.get_statistics();
assert_eq!(stats.routes_count, 1);
assert!(stats.listening_ports.contains(&port));
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_invalid_routes_rejected() {
let options = RustProxyOptions {
routes: vec![{
let mut route = make_test_route(80, None, "127.0.0.1", 8080);
route.action.targets = None; // Invalid: forward without targets
route
}],
..Default::default()
};
let result = RustProxy::new(options);
assert!(result.is_err());
}
#[tokio::test]
async fn test_metrics_track_connections() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// No connections yet
let stats = proxy.get_statistics();
assert_eq!(stats.total_connections, 0);
// Make a connection and send data
{
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello").await.unwrap();
let mut buf = vec![0u8; 16];
let _ = stream.read(&mut buf).await;
}
// Small delay for metrics to update
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_metrics_track_bytes() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_http_echo_backend(backend_port, "metrics-test").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send HTTP request through proxy
{
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let request = b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n";
stream.write_all(request).await.unwrap();
let mut response = Vec::new();
stream.read_to_end(&mut response).await.unwrap();
assert!(!response.is_empty(), "Expected non-empty response");
}
// Small delay for metrics to update
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0,
"Expected some connections tracked, got {}", stats.total_connections);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_hot_reload_port_changes() {
let port1 = next_port();
let port2 = next_port();
let backend_port = next_port();
let _backend = start_echo_server(backend_port).await;
// Start with port1
let options = RustProxyOptions {
routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
// Update routes to use port2 instead
let new_routes = vec![
make_test_route(port2, None, "127.0.0.1", backend_port),
];
proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
// Verify port2 works
let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,197 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
#[tokio::test]
async fn test_tcp_forward_echo() {
let backend_port = next_port();
let proxy_port = next_port();
// Start echo backend
let _backend = start_echo_server(backend_port).await;
// Configure proxy
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
// Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
// Connect and send data
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"hello world").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert_eq!(result, "hello world");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_large_payload() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
// Send 1MB of data
let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap();
// Read all back
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 10)
.await
.unwrap();
assert_eq!(result, 1_000_000);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_multiple_connections() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
handles.push(tokio::spawn(async move {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 10)
.await
.unwrap();
assert_eq!(result.len(), 10);
for (i, r) in result.iter().enumerate() {
assert_eq!(r, &format!("connection-{}", i));
}
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_backend_unreachable() {
let proxy_port = next_port();
let dead_port = next_port(); // No server on this port
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
}, 5)
.await
.unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tcp_forward_bidirectional() {
let backend_port = next_port();
let proxy_port = next_port();
// Start a prefix echo server to verify data flows in both directions
let _backend = start_prefix_echo_server(backend_port, "REPLY:").await;
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
stream.write_all(b"test data").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert_eq!(result, "REPLY:test data");
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,247 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Build a minimal TLS ClientHello with the given SNI domain.
/// This is enough for the proxy's SNI parser to extract the domain.
fn build_client_hello(domain: &str) -> Vec<u8> {
let domain_bytes = domain.as_bytes();
let sni_length = domain_bytes.len() as u16;
// Server Name extension (type 0x0000)
let mut sni_ext = Vec::new();
sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name
sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length
sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length
sni_ext.push(0x00); // host_name type
sni_ext.extend_from_slice(&sni_length.to_be_bytes());
sni_ext.extend_from_slice(domain_bytes);
let extensions_length = sni_ext.len() as u16;
// ClientHello message
let mut client_hello = Vec::new();
client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version
client_hello.extend_from_slice(&[0x00; 32]); // random
client_hello.push(0x00); // session_id length
client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite)
client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null)
client_hello.extend_from_slice(&extensions_length.to_be_bytes());
client_hello.extend_from_slice(&sni_ext);
let hello_len = client_hello.len() as u32;
// Handshake wrapper (type 1 = ClientHello)
let mut handshake = Vec::new();
handshake.push(0x01); // ClientHello
handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length
handshake.extend_from_slice(&client_hello);
let hs_len = handshake.len() as u16;
// TLS record
let mut record = Vec::new();
record.push(0x16); // ContentType: Handshake
record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version)
record.extend_from_slice(&hs_len.to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[tokio::test]
async fn test_tls_passthrough_sni_routing() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
// Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
// Now test routing to backend2
let result2 = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_unknown_sni() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
}, 5)
.await
.unwrap();
assert!(result, "Unknown SNI should result in dropped connection");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_wildcard_domain() {
let backend_port = next_port();
let proxy_port = next_port();
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_passthrough_multiple_domains() {
let b1_port = next_port();
let b2_port = next_port();
let b3_port = next_port();
let proxy_port = next_port();
let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
for (domain, expected_prefix) in [
("alpha.example.com", "B1:"),
("beta.example.com", "B2:"),
("gamma.example.com", "B3:"),
] {
let result = with_timeout(async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
.await
.unwrap();
assert!(
result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}",
domain, expected_prefix, result
);
}
proxy.stop().await.unwrap();
}

View File

@@ -0,0 +1,324 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Create a rustls client config that trusts self-signed certs.
fn make_insecure_tls_client_config() -> Arc<rustls::ClientConfig> {
let _ = rustls::crypto::ring::default_provider().install_default();
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
Arc::new(config)
}
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}
#[tokio::test]
async fn test_tls_terminate_basic() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "test.example.com";
// Generate self-signed cert
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
// Start plain TCP echo backend (proxy terminates TLS, sends plain to backend)
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "hello TLS");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_and_reencrypt() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "reencrypt.example.com";
let backend_domain = "backend.internal";
// Generate certs
let (proxy_cert, proxy_key) = generate_self_signed_cert(domain);
let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain);
// Start TLS echo backend
let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await;
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert_eq!(result, "hello reencrypt");
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_sni_cert_selection() {
let backend1_port = next_port();
let backend2_port = next_port();
let proxy_port = next_port();
let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
let (cert2, key2) = generate_self_signed_cert("beta.example.com");
let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await;
let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
.await
.unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_large_payload() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "large.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send 1MB of data
let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap();
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 15)
.await
.unwrap();
assert_eq!(result, 1_000_000);
proxy.stop().await.unwrap();
}
#[tokio::test]
async fn test_tls_terminate_concurrent() {
let backend_port = next_port();
let proxy_port = next_port();
let domain = "concurrent.example.com";
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
)],
..Default::default()
};
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
let dom = domain.to_string();
handles.push(tokio::spawn(async move {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap());
}
results
}, 15)
.await
.unwrap();
assert_eq!(result.len(), 10);
for (i, r) in result.iter().enumerate() {
assert_eq!(r, &format!("conn-{}", i));
}
proxy.stop().await.unwrap();
}

View File

@@ -1,13 +1,37 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
import * as http from 'http';
import { SmartProxy, SocketHandlers } from '../ts/index.js';
/**
* Helper to make HTTP requests using Node's http module (unlike fetch/undici,
* http.request doesn't keep the event loop alive via a connection pool).
*/
function httpRequest(url: string, options: { method?: string; headers?: Record<string, string> } = {}): Promise<{ status: number; headers: http.IncomingHttpHeaders; body: string }> {
return new Promise((resolve, reject) => {
const parsed = new URL(url);
const req = http.request({
hostname: parsed.hostname,
port: parsed.port,
path: parsed.pathname + parsed.search,
method: options.method || 'GET',
headers: options.headers,
}, (res) => {
let body = '';
res.on('data', (chunk: Buffer) => { body += chunk.toString(); });
res.on('end', () => resolve({ status: res.statusCode!, headers: res.headers, body }));
});
req.on('error', reject);
req.end();
});
}
tap.test('should handle HTTP requests on port 80 for ACME challenges', async (tools) => {
tools.timeout(10000);
// Track HTTP requests that are handled
const handledRequests: any[] = [];
const settings = {
routes: [
{
@@ -24,7 +48,7 @@ tap.test('should handle HTTP requests on port 80 for ACME challenges', async (to
method: req.method,
headers: req.headers
});
// Simulate ACME challenge response
const token = req.url?.split('/').pop() || '';
res.header('Content-Type', 'text/plain');
@@ -34,40 +58,31 @@ tap.test('should handle HTTP requests on port 80 for ACME challenges', async (to
}
]
};
const proxy = new SmartProxy(settings);
// Mock NFTables manager
(proxy as any).nftablesManager = {
ensureNFTablesSetup: async () => {},
stop: async () => {}
};
await proxy.start();
// Make an HTTP request to the challenge endpoint
const response = await fetch('http://localhost:18080/.well-known/acme-challenge/test-token', {
method: 'GET'
});
const response = await httpRequest('http://localhost:18080/.well-known/acme-challenge/test-token');
// Verify response
expect(response.status).toEqual(200);
const body = await response.text();
expect(body).toEqual('challenge-response-for-test-token');
expect(response.body).toEqual('challenge-response-for-test-token');
// Verify request was handled
expect(handledRequests.length).toEqual(1);
expect(handledRequests[0].path).toEqual('/.well-known/acme-challenge/test-token');
expect(handledRequests[0].method).toEqual('GET');
await proxy.stop();
});
tap.test('should parse HTTP headers correctly', async (tools) => {
tools.timeout(10000);
const capturedContext: any = {};
const settings = {
routes: [
{
@@ -92,36 +107,30 @@ tap.test('should parse HTTP headers correctly', async (tools) => {
}
]
};
const proxy = new SmartProxy(settings);
// Mock NFTables manager
(proxy as any).nftablesManager = {
ensureNFTablesSetup: async () => {},
stop: async () => {}
};
await proxy.start();
// Make request with custom headers
const response = await fetch('http://localhost:18081/test', {
const response = await httpRequest('http://localhost:18081/test', {
method: 'POST',
headers: {
'X-Custom-Header': 'test-value',
'User-Agent': 'test-agent'
}
});
expect(response.status).toEqual(200);
const body = await response.json();
const body = JSON.parse(response.body);
// Verify headers were parsed correctly
expect(capturedContext.headers['x-custom-header']).toEqual('test-value');
expect(capturedContext.headers['user-agent']).toEqual('test-agent');
expect(capturedContext.method).toEqual('POST');
expect(capturedContext.path).toEqual('/test');
await proxy.stop();
});
export default tap.start();
export default tap.start();

View File

@@ -1,218 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import * as plugins from '../ts/plugins.js';
/**
* Test that verifies ACME challenge routes are properly created
*/
tap.test('should create ACME challenge route', async (tools) => {
tools.timeout(5000);
// Create a challenge route manually to test its structure
const challengeRoute = {
name: 'acme-challenge',
priority: 1000,
match: {
ports: 18080,
path: '/.well-known/acme-challenge/*'
},
action: {
type: 'socket-handler' as const,
socketHandler: (socket: any, context: any) => {
socket.once('data', (data: Buffer) => {
const request = data.toString();
const lines = request.split('\r\n');
const [method, path] = lines[0].split(' ');
const token = path?.split('/').pop() || '';
const response = [
'HTTP/1.1 200 OK',
'Content-Type: text/plain',
`Content-Length: ${token.length}`,
'Connection: close',
'',
token
].join('\r\n');
socket.write(response);
socket.end();
});
}
}
};
// Test that the challenge route has the correct structure
expect(challengeRoute).toBeDefined();
expect(challengeRoute.match.path).toEqual('/.well-known/acme-challenge/*');
expect(challengeRoute.match.ports).toEqual(18080);
expect(challengeRoute.action.type).toEqual('socket-handler');
expect(challengeRoute.priority).toEqual(1000);
// Create a proxy with the challenge route
const settings = {
routes: [
{
name: 'secure-route',
match: {
ports: [18443],
domains: 'test.local'
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
},
challengeRoute
]
};
const proxy = new SmartProxy(settings);
// Mock NFTables manager
(proxy as any).nftablesManager = {
ensureNFTablesSetup: async () => {},
stop: async () => {}
};
// Mock certificate manager to prevent real ACME initialization
(proxy as any).createCertificateManager = async function() {
return {
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {},
provisionAllCertificates: async () => {},
stop: async () => {},
getAcmeOptions: () => ({}),
getState: () => ({ challengeRouteActive: false })
};
};
await proxy.start();
// Verify the challenge route is in the proxy's routes
const proxyRoutes = proxy.routeManager.getRoutes();
const foundChallengeRoute = proxyRoutes.find((r: any) => r.name === 'acme-challenge');
expect(foundChallengeRoute).toBeDefined();
expect(foundChallengeRoute?.match.path).toEqual('/.well-known/acme-challenge/*');
await proxy.stop();
});
tap.test('should handle HTTP request parsing correctly', async (tools) => {
tools.timeout(5000);
let handlerCalled = false;
let receivedContext: any;
let parsedRequest: any = {};
const settings = {
routes: [
{
name: 'test-static',
match: {
ports: [18090],
path: '/test/*'
},
action: {
type: 'socket-handler' as const,
socketHandler: (socket, context) => {
handlerCalled = true;
receivedContext = context;
// Parse HTTP request from socket
socket.once('data', (data) => {
const request = data.toString();
const lines = request.split('\r\n');
const [method, path, protocol] = lines[0].split(' ');
// Parse headers
const headers: any = {};
for (let i = 1; i < lines.length; i++) {
if (lines[i] === '') break;
const [key, value] = lines[i].split(': ');
if (key && value) {
headers[key.toLowerCase()] = value;
}
}
// Store parsed request data
parsedRequest = { method, path, headers };
// Send HTTP response
const response = [
'HTTP/1.1 200 OK',
'Content-Type: text/plain',
'Content-Length: 2',
'Connection: close',
'',
'OK'
].join('\r\n');
socket.write(response);
socket.end();
});
}
}
}
]
};
const proxy = new SmartProxy(settings);
// Mock NFTables manager
(proxy as any).nftablesManager = {
ensureNFTablesSetup: async () => {},
stop: async () => {}
};
await proxy.start();
// Create a simple HTTP request
const client = new plugins.net.Socket();
await new Promise<void>((resolve, reject) => {
client.connect(18090, 'localhost', () => {
// Send HTTP request
const request = [
'GET /test/example HTTP/1.1',
'Host: localhost:18090',
'User-Agent: test-client',
'',
''
].join('\r\n');
client.write(request);
// Wait for response
client.on('data', (data) => {
const response = data.toString();
expect(response).toContain('HTTP/1.1 200');
expect(response).toContain('OK');
client.end();
resolve();
});
});
client.on('error', reject);
});
// Verify handler was called
expect(handlerCalled).toBeTrue();
expect(receivedContext).toBeDefined();
// The context passed to socket handlers is IRouteContext, not HTTP request data
expect(receivedContext.port).toEqual(18090);
expect(receivedContext.routeName).toEqual('test-static');
// Verify the parsed HTTP request data
expect(parsedRequest.path).toEqual('/test/example');
expect(parsedRequest.method).toEqual('GET');
expect(parsedRequest.headers.host).toEqual('localhost:18090');
await proxy.stop();
});
export default tap.start();

View File

@@ -1,188 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { AcmeStateManager } from '../ts/proxies/smart-proxy/acme-state-manager.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('AcmeStateManager should track challenge routes correctly', async (tools) => {
const stateManager = new AcmeStateManager();
const challengeRoute: IRouteConfig = {
name: 'acme-challenge',
priority: 1000,
match: {
ports: 80,
path: '/.well-known/acme-challenge/*'
},
action: {
type: 'socket-handler',
socketHandler: async (socket, context) => {
// Mock handler that would write the challenge response
socket.end('challenge response');
}
}
};
// Initially no challenge routes
expect(stateManager.isChallengeRouteActive()).toBeFalse();
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
// Add challenge route
stateManager.addChallengeRoute(challengeRoute);
expect(stateManager.isChallengeRouteActive()).toBeTrue();
expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 1);
expect(stateManager.getPrimaryChallengeRoute()).toEqual(challengeRoute);
// Remove challenge route
stateManager.removeChallengeRoute('acme-challenge');
expect(stateManager.isChallengeRouteActive()).toBeFalse();
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
expect(stateManager.getPrimaryChallengeRoute()).toBeNull();
});
tap.test('AcmeStateManager should track port allocations', async (tools) => {
const stateManager = new AcmeStateManager();
const challengeRoute1: IRouteConfig = {
name: 'acme-challenge-1',
priority: 1000,
match: {
ports: 80,
path: '/.well-known/acme-challenge/*'
},
action: {
type: 'socket-handler'
}
};
const challengeRoute2: IRouteConfig = {
name: 'acme-challenge-2',
priority: 900,
match: {
ports: [80, 8080],
path: '/.well-known/acme-challenge/*'
},
action: {
type: 'socket-handler'
}
};
// Add first route
stateManager.addChallengeRoute(challengeRoute1);
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse();
expect(stateManager.getAcmePorts()).toEqual([80]);
// Add second route
stateManager.addChallengeRoute(challengeRoute2);
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue();
expect(stateManager.getAcmePorts()).toContain(80);
expect(stateManager.getAcmePorts()).toContain(8080);
// Remove first route - port 80 should still be allocated
stateManager.removeChallengeRoute('acme-challenge-1');
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue();
// Remove second route - all ports should be deallocated
stateManager.removeChallengeRoute('acme-challenge-2');
expect(stateManager.isPortAllocatedForAcme(80)).toBeFalse();
expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse();
expect(stateManager.getAcmePorts()).toEqual([]);
});
tap.test('AcmeStateManager should select primary route by priority', async (tools) => {
const stateManager = new AcmeStateManager();
const lowPriorityRoute: IRouteConfig = {
name: 'low-priority',
priority: 100,
match: {
ports: 80
},
action: {
type: 'socket-handler'
}
};
const highPriorityRoute: IRouteConfig = {
name: 'high-priority',
priority: 2000,
match: {
ports: 80
},
action: {
type: 'socket-handler'
}
};
const defaultPriorityRoute: IRouteConfig = {
name: 'default-priority',
// No priority specified - should default to 0
match: {
ports: 80
},
action: {
type: 'socket-handler'
}
};
// Add low priority first
stateManager.addChallengeRoute(lowPriorityRoute);
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority');
// Add high priority - should become primary
stateManager.addChallengeRoute(highPriorityRoute);
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority');
// Add default priority - primary should remain high priority
stateManager.addChallengeRoute(defaultPriorityRoute);
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority');
// Remove high priority - primary should fall back to low priority
stateManager.removeChallengeRoute('high-priority');
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority');
});
tap.test('AcmeStateManager should handle clear operation', async (tools) => {
const stateManager = new AcmeStateManager();
const challengeRoute1: IRouteConfig = {
name: 'route-1',
match: {
ports: [80, 443]
},
action: {
type: 'socket-handler'
}
};
const challengeRoute2: IRouteConfig = {
name: 'route-2',
match: {
ports: 8080
},
action: {
type: 'socket-handler'
}
};
// Add routes
stateManager.addChallengeRoute(challengeRoute1);
stateManager.addChallengeRoute(challengeRoute2);
// Verify state before clear
expect(stateManager.isChallengeRouteActive()).toBeTrue();
expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 2);
expect(stateManager.getAcmePorts()).toHaveProperty("length", 3);
// Clear all state
stateManager.clear();
// Verify state after clear
expect(stateManager.isChallengeRouteActive()).toBeFalse();
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
expect(stateManager.getAcmePorts()).toEqual([]);
expect(stateManager.getPrimaryChallengeRoute()).toBeNull();
});
export default tap.start();

View File

@@ -1,122 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
// Test that certificate provisioning is deferred until after ports are listening
tap.test('should defer certificate provisioning until ports are ready', async (tapTest) => {
// Track when operations happen
let portsListening = false;
let certProvisioningStarted = false;
let operationOrder: string[] = [];
// Create proxy with certificate route but without real ACME
const proxy = new SmartProxy({
routes: [{
name: 'test-route',
match: {
ports: 8443,
domains: ['test.local']
},
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8181 }],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'test@local.dev',
useProduction: false
}
}
}
}]
});
// Override the certificate manager creation to avoid real ACME
const originalCreateCertManager = proxy['createCertificateManager'];
proxy['createCertificateManager'] = async function(...args: any[]) {
console.log('Creating mock cert manager');
operationOrder.push('create-cert-manager');
const mockCertManager = {
certStore: null,
smartAcme: null,
httpProxy: null,
renewalTimer: null,
pendingChallenges: new Map(),
challengeRoute: null,
certStatus: new Map(),
globalAcmeDefaults: null,
updateRoutesCallback: undefined,
challengeRouteActive: false,
isProvisioning: false,
acmeStateManager: null,
initialize: async () => {
operationOrder.push('cert-manager-init');
console.log('Mock cert manager initialized');
},
provisionAllCertificates: async () => {
operationOrder.push('cert-provisioning');
certProvisioningStarted = true;
// Check that ports are listening when provisioning starts
if (!portsListening) {
throw new Error('Certificate provisioning started before ports ready!');
}
console.log('Mock certificate provisioning (ports are ready)');
},
stop: async () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
setUpdateRoutesCallback: () => {},
getAcmeOptions: () => ({}),
getState: () => ({ challengeRouteActive: false }),
getCertStatus: () => new Map(),
checkAndRenewCertificates: async () => {},
addChallengeRoute: async () => {},
removeChallengeRoute: async () => {},
getCertificate: async () => null,
isValidCertificate: () => false,
waitForProvisioning: async () => {}
} as any;
// Call initialize immediately as the real createCertificateManager does
await mockCertManager.initialize();
return mockCertManager;
};
// Track port manager operations
const originalAddPorts = proxy['portManager'].addPorts;
proxy['portManager'].addPorts = async function(ports: number[]) {
operationOrder.push('ports-starting');
const result = await originalAddPorts.call(this, ports);
operationOrder.push('ports-ready');
portsListening = true;
console.log('Ports are now listening');
return result;
};
// Start the proxy
await proxy.start();
// Log the operation order for debugging
console.log('Operation order:', operationOrder);
// Verify operations happened in the correct order
expect(operationOrder).toContain('create-cert-manager');
expect(operationOrder).toContain('cert-manager-init');
expect(operationOrder).toContain('ports-starting');
expect(operationOrder).toContain('ports-ready');
expect(operationOrder).toContain('cert-provisioning');
// Verify ports were ready before certificate provisioning
const portsReadyIndex = operationOrder.indexOf('ports-ready');
const certProvisioningIndex = operationOrder.indexOf('cert-provisioning');
expect(portsReadyIndex).toBeLessThan(certProvisioningIndex);
expect(certProvisioningStarted).toEqual(true);
expect(portsListening).toEqual(true);
await proxy.stop();
});
export default tap.start();

View File

@@ -1,204 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import * as net from 'net';
// Test that certificate provisioning waits for ports to be ready
tap.test('should defer certificate provisioning until after ports are listening', async (tapTest) => {
// Track the order of operations
const operationLog: string[] = [];
// Create a mock server to verify ports are listening
let port80Listening = false;
// Try to use port 8080 instead of 80 to avoid permission issues in testing
const acmePort = 8080;
// Create proxy with ACME certificate requirement
const proxy = new SmartProxy({
useHttpProxy: [acmePort],
httpProxyPort: 8845, // Use different port to avoid conflicts
acme: {
email: 'test@test.local',
useProduction: false,
port: acmePort
},
routes: [{
name: 'test-acme-route',
match: {
ports: 8443,
domains: ['test.local']
},
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8181 }],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'test@test.local',
useProduction: false
}
}
}
}]
});
// Mock some internal methods to track operation order
const originalAddPorts = proxy['portManager'].addPorts;
proxy['portManager'].addPorts = async function(ports: number[]) {
operationLog.push('Starting port listeners');
const result = await originalAddPorts.call(this, ports);
operationLog.push('Port listeners started');
port80Listening = true;
return result;
};
// Track that we created a certificate manager and SmartProxy will call provisionAllCertificates
let certManagerCreated = false;
// Override createCertificateManager to set up our tracking
const originalCreateCertManager = (proxy as any).createCertificateManager;
(proxy as any).certManagerCreated = false;
// Mock certificate manager to avoid real ACME initialization
(proxy as any).createCertificateManager = async function() {
operationLog.push('Creating certificate manager');
const mockCertManager = {
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {
operationLog.push('Certificate manager initialized');
},
provisionAllCertificates: async () => {
operationLog.push('Starting certificate provisioning');
if (!port80Listening) {
operationLog.push('ERROR: Certificate provisioning started before ports ready');
}
operationLog.push('Certificate provisioning completed');
},
stop: async () => {},
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
getState: () => ({ challengeRouteActive: false })
};
certManagerCreated = true;
(proxy as any).certManager = mockCertManager;
return mockCertManager;
};
// Start the proxy
await proxy.start();
// Verify the order of operations
expect(operationLog).toContain('Starting port listeners');
expect(operationLog).toContain('Port listeners started');
expect(operationLog).toContain('Starting certificate provisioning');
// Ensure port listeners started before certificate provisioning
const portStartIndex = operationLog.indexOf('Port listeners started');
const certStartIndex = operationLog.indexOf('Starting certificate provisioning');
expect(portStartIndex).toBeLessThan(certStartIndex);
expect(operationLog).not.toContain('ERROR: Certificate provisioning started before ports ready');
await proxy.stop();
});
// Test that ACME challenge route is available when certificate is requested
tap.test('should have ACME challenge route ready before certificate provisioning', async (tapTest) => {
let challengeRouteActive = false;
let certificateProvisioningStarted = false;
const proxy = new SmartProxy({
useHttpProxy: [8080],
httpProxyPort: 8846, // Use different port to avoid conflicts
acme: {
email: 'test@test.local',
useProduction: false,
port: 8080
},
routes: [{
name: 'test-route',
match: {
ports: 8443,
domains: ['test.example.com']
},
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8181 }],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}]
});
// Mock the certificate manager to track operations
const originalInitialize = proxy['certManager'] ?
proxy['certManager'].initialize : null;
if (proxy['certManager']) {
const certManager = proxy['certManager'];
// Track when challenge route is added
const originalAddChallenge = certManager['addChallengeRoute'];
certManager['addChallengeRoute'] = async function() {
await originalAddChallenge.call(this);
challengeRouteActive = true;
};
// Track when certificate provisioning starts
const originalProvisionAcme = certManager['provisionAcmeCertificate'];
certManager['provisionAcmeCertificate'] = async function(...args: any[]) {
certificateProvisioningStarted = true;
// Verify challenge route is active
expect(challengeRouteActive).toEqual(true);
// Don't actually provision in test
return;
};
}
// Mock certificate manager to avoid real ACME initialization
(proxy as any).createCertificateManager = async function() {
const mockCertManager = {
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {
challengeRouteActive = true;
},
provisionAllCertificates: async () => {
certificateProvisioningStarted = true;
expect(challengeRouteActive).toEqual(true);
},
stop: async () => {},
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
getState: () => ({ challengeRouteActive: false }),
addChallengeRoute: async () => {
challengeRouteActive = true;
},
provisionAcmeCertificate: async () => {
certificateProvisioningStarted = true;
expect(challengeRouteActive).toEqual(true);
}
};
// Call initialize like the real createCertificateManager does
await mockCertManager.initialize();
return mockCertManager;
};
await proxy.start();
// Give it a moment to complete initialization
await new Promise(resolve => setTimeout(resolve, 100));
// Verify challenge route was added before any certificate provisioning
expect(challengeRouteActive).toEqual(true);
await proxy.stop();
});
export default tap.start();

123
test/test.bun.ts Normal file
View File

@@ -0,0 +1,123 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpsTerminateRoute,
createCompleteHttpsServer,
createHttpRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
mergeRouteConfigs,
cloneRoute,
routeMatchesPath,
} from '../ts/proxies/smart-proxy/utils/route-utils.js';
import {
validateRoutes,
validateRouteConfig,
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpsTerminateRoute produces correct structure', async () => {
const route = createHttpsTerminateRoute('secure.example.com', { host: '127.0.0.1', port: 8443 });
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.action.type).toEqual('forward');
expect(route.action.tls).toBeDefined();
expect(route.action.tls!.mode).toEqual('terminate');
expect(route.match.domains).toEqual('secure.example.com');
});
tap.test('route creation - createCompleteHttpsServer returns redirect and main route', async () => {
const routes = createCompleteHttpsServer('app.example.com', { host: '127.0.0.1', port: 3000 });
expect(routes).toBeArray();
expect(routes.length).toBeGreaterThanOrEqual(2);
// Should have an HTTP→HTTPS redirect and an HTTPS route
const hasRedirect = routes.some((r) => r.action.type === 'forward' && r.action.redirect !== undefined);
const hasHttps = routes.some((r) => r.action.tls?.mode === 'terminate');
expect(hasRedirect || hasHttps).toBeTrue();
});
tap.test('route validation - validateRoutes on a set of routes', async () => {
const routes: IRouteConfig[] = [
createHttpRoute('a.com', { host: '127.0.0.1', port: 3000 }),
createHttpRoute('b.com', { host: '127.0.0.1', port: 4000 }),
];
const result = validateRoutes(routes);
expect(result.valid).toBeTrue();
expect(result.errors).toHaveLength(0);
});
tap.test('route validation - validateRoutes catches invalid route in set', async () => {
const routes: any[] = [
createHttpRoute('valid.com', { host: '127.0.0.1', port: 3000 }),
{ match: { ports: 80 } }, // missing action
];
const result = validateRoutes(routes);
expect(result.valid).toBeFalse();
expect(result.errors.length).toBeGreaterThan(0);
});
tap.test('path matching - routeMatchesPath with exact path', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
route.match.path = '/api';
expect(routeMatchesPath(route, '/api')).toBeTrue();
expect(routeMatchesPath(route, '/other')).toBeFalse();
});
tap.test('path matching - route without path matches everything', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// No path set, should match any path
expect(routeMatchesPath(route, '/anything')).toBeTrue();
expect(routeMatchesPath(route, '/')).toBeTrue();
});
tap.test('route merging - mergeRouteConfigs combines routes', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.priority = 10;
base.name = 'base-route';
const merged = mergeRouteConfigs(base, {
priority: 50,
name: 'merged-route',
});
expect(merged.priority).toEqual(50);
expect(merged.name).toEqual('merged-route');
// Original route fields should be preserved
expect(merged.match.domains).toEqual('example.com');
expect(merged.action.targets![0].host).toEqual('127.0.0.1');
});
tap.test('route merging - mergeRouteConfigs does not mutate original', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.name = 'original';
const merged = mergeRouteConfigs(base, { name: 'changed' });
expect(base.name).toEqual('original');
expect(merged.name).toEqual('changed');
});
tap.test('route cloning - cloneRoute produces independent copy', async () => {
const original = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
original.priority = 42;
original.name = 'original-route';
const cloned = cloneRoute(original);
// Should be equal in value
expect(cloned.match.domains).toEqual('example.com');
expect(cloned.priority).toEqual(42);
expect(cloned.name).toEqual('original-route');
expect(cloned.action.targets![0].host).toEqual('127.0.0.1');
expect(cloned.action.targets![0].port).toEqual(3000);
// Should be independent - modifying clone shouldn't affect original
cloned.name = 'cloned-route';
cloned.priority = 99;
expect(original.name).toEqual('original-route');
expect(original.priority).toEqual(42);
});
export default tap.start();

View File

@@ -1,77 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
import * as smartproxy from '../ts/index.js';
// This test verifies that SmartProxy correctly uses the updated SmartAcme v8.0.0 API
// with the optional wildcard parameter
tap.test('SmartCertManager should call getCertificateForDomain with wildcard option', async () => {
console.log('Testing SmartCertManager with SmartAcme v8.0.0 API...');
// Create a mock route with ACME certificate configuration
const mockRoute: smartproxy.IRouteConfig = {
match: {
domains: ['test.example.com'],
ports: 443
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'test@example.com',
useProduction: false
}
}
},
name: 'test-route'
};
// Create a certificate manager
const certManager = new smartproxy.SmartCertManager(
[mockRoute],
'./test-certs',
{
email: 'test@example.com',
useProduction: false
}
);
// Since we can't actually test ACME in a unit test, we'll just verify the logic
// The actual test would be that it builds and runs without errors
// Test the wildcard logic for different domain types and challenge handlers
const testCases = [
{ domain: 'example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
{ domain: 'example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
{ domain: 'sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
{ domain: 'sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
{ domain: '*.example.com', hasDnsChallenge: true, shouldIncludeWildcard: false },
{ domain: '*.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
{ domain: 'test', hasDnsChallenge: true, shouldIncludeWildcard: false }, // single label domain
{ domain: 'test', hasDnsChallenge: false, shouldIncludeWildcard: false },
{ domain: 'my.sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
{ domain: 'my.sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false }
];
for (const testCase of testCases) {
const shouldIncludeWildcard = !testCase.domain.startsWith('*.') &&
testCase.domain.includes('.') &&
testCase.domain.split('.').length >= 2 &&
testCase.hasDnsChallenge;
console.log(`Domain: ${testCase.domain}, DNS-01: ${testCase.hasDnsChallenge}, Should include wildcard: ${shouldIncludeWildcard}`);
expect(shouldIncludeWildcard).toEqual(testCase.shouldIncludeWildcard);
}
console.log('All wildcard logic tests passed!');
});
tap.start({
throwOnError: true
});

View File

@@ -1,360 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import type { TSmartProxyCertProvisionObject } from '../ts/index.js';
import * as fs from 'fs';
import * as path from 'path';
import { fileURLToPath } from 'url';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
let testProxy: SmartProxy;
// Load test certificates from helpers
const testCert = fs.readFileSync(path.join(__dirname, 'helpers/test-cert.pem'), 'utf8');
const testKey = fs.readFileSync(path.join(__dirname, 'helpers/test-key.pem'), 'utf8');
tap.test('SmartProxy should support custom certificate provision function', async () => {
// Create test certificate object matching ICert interface
const testCertObject = {
id: 'test-cert-1',
domainName: 'test.example.com',
created: Date.now(),
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, // 90 days
privateKey: testKey,
publicKey: testCert,
csr: ''
};
// Custom certificate store for testing
const customCerts = new Map<string, typeof testCertObject>();
customCerts.set('test.example.com', testCertObject);
// Create proxy with custom certificate provision
testProxy = new SmartProxy({
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
console.log(`Custom cert provision called for domain: ${domain}`);
// Return custom cert for known domains
if (customCerts.has(domain)) {
console.log(`Returning custom certificate for ${domain}`);
return customCerts.get(domain)!;
}
// Fallback to Let's Encrypt for other domains
console.log(`Falling back to Let's Encrypt for ${domain}`);
return 'http01';
},
certProvisionFallbackToAcme: true,
acme: {
email: 'test@example.com',
useProduction: false
},
routes: [
{
name: 'test-route',
match: {
ports: [443],
domains: ['test.example.com']
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}
]
});
expect(testProxy).toBeInstanceOf(SmartProxy);
});
tap.test('Custom certificate provision function should be called', async () => {
let provisionCalled = false;
const provisionedDomains: string[] = [];
const testProxy2 = new SmartProxy({
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
provisionCalled = true;
provisionedDomains.push(domain);
// Return a test certificate matching ICert interface
return {
id: `test-cert-${domain}`,
domainName: domain,
created: Date.now(),
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000,
privateKey: testKey,
publicKey: testCert,
csr: ''
};
},
acme: {
email: 'test@example.com',
useProduction: false,
port: 9080
},
routes: [
{
name: 'custom-cert-route',
match: {
ports: [9443],
domains: ['custom.example.com']
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}
]
});
// Mock the certificate manager to test our custom provision function
let certManagerCalled = false;
const origCreateCertManager = (testProxy2 as any).createCertificateManager;
(testProxy2 as any).createCertificateManager = async function(...args: any[]) {
const certManager = await origCreateCertManager.apply(testProxy2, args);
// Override provisionAllCertificates to track calls
const origProvisionAll = certManager.provisionAllCertificates;
certManager.provisionAllCertificates = async function() {
certManagerCalled = true;
await origProvisionAll.call(certManager);
};
return certManager;
};
// Start the proxy (this will trigger certificate provisioning)
await testProxy2.start();
expect(certManagerCalled).toBeTrue();
expect(provisionCalled).toBeTrue();
expect(provisionedDomains).toContain('custom.example.com');
await testProxy2.stop();
});
tap.test('Should fallback to ACME when custom provision fails', async () => {
const failedDomains: string[] = [];
let acmeAttempted = false;
const testProxy3 = new SmartProxy({
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
failedDomains.push(domain);
throw new Error('Custom provision failed for testing');
},
certProvisionFallbackToAcme: true,
acme: {
email: 'test@example.com',
useProduction: false,
port: 9080
},
routes: [
{
name: 'fallback-route',
match: {
ports: [9444],
domains: ['fallback.example.com']
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}
]
});
// Mock to track ACME attempts
const origCreateCertManager = (testProxy3 as any).createCertificateManager;
(testProxy3 as any).createCertificateManager = async function(...args: any[]) {
const certManager = await origCreateCertManager.apply(testProxy3, args);
// Mock SmartAcme to avoid real ACME calls
(certManager as any).smartAcme = {
getCertificateForDomain: async () => {
acmeAttempted = true;
throw new Error('Mocked ACME failure');
}
};
return certManager;
};
// Start the proxy
await testProxy3.start();
// Custom provision should have failed
expect(failedDomains).toContain('fallback.example.com');
// ACME should have been attempted as fallback
expect(acmeAttempted).toBeTrue();
await testProxy3.stop();
});
tap.test('Should not fallback when certProvisionFallbackToAcme is false', async () => {
let errorThrown = false;
let errorMessage = '';
const testProxy4 = new SmartProxy({
certProvisionFunction: async (_domain: string): Promise<TSmartProxyCertProvisionObject> => {
throw new Error('Custom provision failed for testing');
},
certProvisionFallbackToAcme: false,
routes: [
{
name: 'no-fallback-route',
match: {
ports: [9445],
domains: ['no-fallback.example.com']
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}
]
});
// Mock certificate manager to capture errors
const origCreateCertManager = (testProxy4 as any).createCertificateManager;
(testProxy4 as any).createCertificateManager = async function(...args: any[]) {
const certManager = await origCreateCertManager.apply(testProxy4, args);
// Override provisionAllCertificates to capture errors
const origProvisionAll = certManager.provisionAllCertificates;
certManager.provisionAllCertificates = async function() {
try {
await origProvisionAll.call(certManager);
} catch (e) {
errorThrown = true;
errorMessage = e.message;
throw e;
}
};
return certManager;
};
try {
await testProxy4.start();
} catch (e) {
// Expected to fail
}
expect(errorThrown).toBeTrue();
expect(errorMessage).toInclude('Custom provision failed for testing');
await testProxy4.stop();
});
tap.test('Should return http01 for unknown domains', async () => {
let returnedHttp01 = false;
let acmeAttempted = false;
const testProxy5 = new SmartProxy({
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
if (domain === 'known.example.com') {
return {
id: `test-cert-${domain}`,
domainName: domain,
created: Date.now(),
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000,
privateKey: testKey,
publicKey: testCert,
csr: ''
};
}
returnedHttp01 = true;
return 'http01';
},
acme: {
email: 'test@example.com',
useProduction: false,
port: 9081
},
routes: [
{
name: 'unknown-domain-route',
match: {
ports: [9446],
domains: ['unknown.example.com']
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 8080
}],
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
}
]
});
// Mock to track ACME attempts
const origCreateCertManager = (testProxy5 as any).createCertificateManager;
(testProxy5 as any).createCertificateManager = async function(...args: any[]) {
const certManager = await origCreateCertManager.apply(testProxy5, args);
// Mock SmartAcme to track attempts
(certManager as any).smartAcme = {
getCertificateForDomain: async () => {
acmeAttempted = true;
throw new Error('Mocked ACME failure');
}
};
return certManager;
};
await testProxy5.start();
// Should have returned http01 for unknown domain
expect(returnedHttp01).toBeTrue();
// ACME should have been attempted
expect(acmeAttempted).toBeTrue();
await testProxy5.stop();
});
tap.test('cleanup', async () => {
// Clean up any test proxies
if (testProxy) {
await testProxy.stop();
}
});
export default tap.start();

View File

@@ -1,241 +0,0 @@
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { expect, tap } from '@git.zone/tstest/tapbundle';
const testProxy = new SmartProxy({
routes: [{
name: 'test-route',
match: { ports: 9443, domains: 'test.local' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'test@test.local',
useProduction: false
}
}
}
}],
acme: {
port: 9080 // Use high port for ACME challenges
}
});
tap.test('should provision certificate automatically', async () => {
// Mock certificate manager to avoid real ACME initialization
const mockCertStatus = {
domain: 'test-route',
status: 'valid' as const,
source: 'acme' as const,
expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000),
issueDate: new Date()
};
(testProxy as any).createCertificateManager = async function() {
return {
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {},
provisionAllCertificates: async () => {},
stop: async () => {},
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
getState: () => ({ challengeRouteActive: false }),
getCertificateStatus: () => mockCertStatus
};
};
(testProxy as any).getCertificateStatus = () => mockCertStatus;
await testProxy.start();
const status = testProxy.getCertificateStatus('test-route');
expect(status).toBeDefined();
expect(status.status).toEqual('valid');
expect(status.source).toEqual('acme');
await testProxy.stop();
});
tap.test('should handle static certificates', async () => {
const proxy = new SmartProxy({
routes: [{
name: 'static-route',
match: { ports: 9444, domains: 'static.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: {
mode: 'terminate',
certificate: {
cert: '-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----',
key: '-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----'
}
}
}
}]
});
await proxy.start();
const status = proxy.getCertificateStatus('static-route');
expect(status).toBeDefined();
expect(status.status).toEqual('valid');
expect(status.source).toEqual('static');
await proxy.stop();
});
tap.test('should handle ACME challenge routes', async () => {
const proxy = new SmartProxy({
routes: [{
name: 'auto-cert-route',
match: { ports: 9445, domains: 'acme.local' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'acme@test.local',
useProduction: false,
challengePort: 9081
}
}
}
}, {
name: 'port-9081-route',
match: { ports: 9081, domains: 'acme.local' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }]
}
}],
acme: {
port: 9081 // Use high port for ACME challenges
}
});
// Mock certificate manager to avoid real ACME initialization
(proxy as any).createCertificateManager = async function() {
return {
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {},
provisionAllCertificates: async () => {},
stop: async () => {},
getAcmeOptions: () => ({ email: 'acme@test.local', useProduction: false }),
getState: () => ({ challengeRouteActive: false })
};
};
await proxy.start();
// Verify the proxy is configured with routes including the necessary port
const routes = proxy.settings.routes;
// Check that we have a route listening on the ACME challenge port
const acmeChallengePort = 9081;
const routesOnChallengePort = routes.filter((r: any) => {
const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports];
return ports.includes(acmeChallengePort);
});
expect(routesOnChallengePort.length).toBeGreaterThan(0);
expect(routesOnChallengePort[0].name).toEqual('port-9081-route');
// Verify the main route has ACME configuration
const mainRoute = routes.find((r: any) => r.name === 'auto-cert-route');
expect(mainRoute).toBeDefined();
expect(mainRoute?.action.tls?.certificate).toEqual('auto');
expect(mainRoute?.action.tls?.acme?.email).toEqual('acme@test.local');
expect(mainRoute?.action.tls?.acme?.challengePort).toEqual(9081);
await proxy.stop();
});
tap.test('should renew certificates', async () => {
const proxy = new SmartProxy({
routes: [{
name: 'renew-route',
match: { ports: 9446, domains: 'renew.local' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: {
mode: 'terminate',
certificate: 'auto',
acme: {
email: 'renew@test.local',
useProduction: false,
renewBeforeDays: 30
}
}
}
}],
acme: {
port: 9082 // Use high port for ACME challenges
}
});
// Mock certificate manager with renewal capability
let renewCalled = false;
const mockCertStatus = {
domain: 'renew-route',
status: 'valid' as const,
source: 'acme' as const,
expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000),
issueDate: new Date()
};
(proxy as any).certManager = {
renewCertificate: async (routeName: string) => {
renewCalled = true;
expect(routeName).toEqual('renew-route');
},
getCertificateStatus: () => mockCertStatus,
setUpdateRoutesCallback: () => {},
setHttpProxy: () => {},
setGlobalAcmeDefaults: () => {},
setAcmeStateManager: () => {},
initialize: async () => {},
provisionAllCertificates: async () => {},
stop: async () => {},
getAcmeOptions: () => ({ email: 'renew@test.local', useProduction: false }),
getState: () => ({ challengeRouteActive: false })
};
(proxy as any).createCertificateManager = async function() {
return this.certManager;
};
(proxy as any).getCertificateStatus = function(routeName: string) {
return this.certManager.getCertificateStatus(routeName);
};
(proxy as any).renewCertificate = async function(routeName: string) {
if (this.certManager) {
await this.certManager.renewCertificate(routeName);
}
};
await proxy.start();
// Force renewal
await proxy.renewCertificate('renew-route');
expect(renewCalled).toBeTrue();
const status = proxy.getCertificateStatus('renew-route');
expect(status).toBeDefined();
expect(status.status).toEqual('valid');
await proxy.stop();
});
export default tap.start();

View File

@@ -1,146 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
tap.test('cleanup queue bug - verify queue processing handles more than batch size', async () => {
console.log('\n=== Cleanup Queue Bug Test ===');
console.log('Purpose: Verify that the cleanup queue correctly processes all connections');
console.log('even when there are more than the batch size (100)');
// Create proxy
const proxy = new SmartProxy({
routes: [{
name: 'test-route',
match: { ports: 8588 },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 9996 }]
}
}],
enableDetailedLogging: false,
});
await proxy.start();
console.log('✓ Proxy started on port 8588');
// Access connection manager
const cm = (proxy as any).connectionManager;
// Create mock connection records
console.log('\n--- Creating 150 mock connections ---');
const mockConnections: any[] = [];
for (let i = 0; i < 150; i++) {
// Create mock socket objects with necessary methods
const mockIncoming = {
destroyed: true,
writable: false,
remoteAddress: '127.0.0.1',
removeAllListeners: () => {},
destroy: () => {},
end: () => {},
on: () => {},
once: () => {},
emit: () => {},
pause: () => {},
resume: () => {}
};
const mockOutgoing = {
destroyed: true,
writable: false,
removeAllListeners: () => {},
destroy: () => {},
end: () => {},
on: () => {},
once: () => {},
emit: () => {}
};
const mockRecord = {
id: `mock-${i}`,
incoming: mockIncoming,
outgoing: mockOutgoing,
connectionClosed: false,
incomingStartTime: Date.now(),
lastActivity: Date.now(),
remoteIP: '127.0.0.1',
remotePort: 10000 + i,
localPort: 8588,
bytesReceived: 100,
bytesSent: 100,
incomingTerminationReason: null,
cleanupTimer: null
};
// Add to connection records
cm.connectionRecords.set(mockRecord.id, mockRecord);
mockConnections.push(mockRecord);
}
console.log(`Created ${cm.getConnectionCount()} mock connections`);
expect(cm.getConnectionCount()).toEqual(150);
// Queue all connections for cleanup
console.log('\n--- Queueing all connections for cleanup ---');
// The cleanup queue processes immediately when it reaches batch size (100)
// So after queueing 150, the first 100 will be processed immediately
for (const conn of mockConnections) {
cm.initiateCleanupOnce(conn, 'test_cleanup');
}
// After queueing 150, the first 100 should have been processed immediately
// leaving 50 in the queue
console.log(`Cleanup queue size after queueing: ${cm.cleanupQueue.size}`);
console.log(`Active connections after initial batch: ${cm.getConnectionCount()}`);
// The first 100 should have been cleaned up immediately
expect(cm.cleanupQueue.size).toEqual(50);
expect(cm.getConnectionCount()).toEqual(50);
// Wait for remaining cleanup to complete
console.log('\n--- Waiting for remaining cleanup batches to process ---');
// The remaining 50 connections should be cleaned up in the next batch
let waitTime = 0;
let lastCount = cm.getConnectionCount();
while (cm.getConnectionCount() > 0 || cm.cleanupQueue.size > 0) {
await new Promise(resolve => setTimeout(resolve, 100));
waitTime += 100;
const currentCount = cm.getConnectionCount();
if (currentCount !== lastCount) {
console.log(`Active connections: ${currentCount}, Queue size: ${cm.cleanupQueue.size}`);
lastCount = currentCount;
}
if (waitTime > 5000) {
console.log('Timeout waiting for cleanup to complete');
break;
}
}
console.log(`All cleanup completed in ${waitTime}ms`);
// Check final state
const finalCount = cm.getConnectionCount();
console.log(`\nFinal connection count: ${finalCount}`);
console.log(`Final cleanup queue size: ${cm.cleanupQueue.size}`);
// All connections should be cleaned up
expect(finalCount).toEqual(0);
expect(cm.cleanupQueue.size).toEqual(0);
// Verify termination stats - all 150 should have been terminated
const stats = cm.getTerminationStats();
console.log('Termination stats:', stats);
expect(stats.incoming.test_cleanup).toEqual(150);
// Cleanup
console.log('\n--- Stopping proxy ---');
await proxy.stop();
console.log('\n✓ Test complete: Cleanup queue now correctly processes all connections');
});
export default tap.start();

View File

@@ -1,242 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as plugins from '../ts/plugins.js';
// Import SmartProxy and configurations
import { SmartProxy } from '../ts/index.js';
tap.test('should handle clients that connect and immediately disconnect without sending data', async () => {
console.log('\n=== Testing Connect-Disconnect Cleanup ===');
// Create a SmartProxy instance
const proxy = new SmartProxy({
ports: [8560],
enableDetailedLogging: false,
initialDataTimeout: 5000, // 5 second timeout for initial data
routes: [{
name: 'test-route',
match: { ports: 8560 },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 9999 // Non-existent port
}]
}
}]
});
// Start the proxy
await proxy.start();
console.log('✓ Proxy started on port 8560');
// Helper to get active connection count
const getActiveConnections = () => {
const connectionManager = (proxy as any).connectionManager;
return connectionManager ? connectionManager.getConnectionCount() : 0;
};
const initialCount = getActiveConnections();
console.log(`Initial connection count: ${initialCount}`);
// Test 1: Connect and immediately disconnect without sending data
console.log('\n--- Test 1: Immediate disconnect ---');
const connectionCounts: number[] = [];
for (let i = 0; i < 10; i++) {
const client = new net.Socket();
// Connect and immediately destroy
client.connect(8560, 'localhost', () => {
// Connected - immediately destroy without sending data
client.destroy();
});
// Wait a tiny bit
await new Promise(resolve => setTimeout(resolve, 10));
const count = getActiveConnections();
connectionCounts.push(count);
if ((i + 1) % 5 === 0) {
console.log(`After ${i + 1} connect/disconnect cycles: ${count} active connections`);
}
}
// Wait a bit for cleanup
await new Promise(resolve => setTimeout(resolve, 500));
const afterImmediateDisconnect = getActiveConnections();
console.log(`After immediate disconnect test: ${afterImmediateDisconnect} active connections`);
// Test 2: Connect, wait a bit, then disconnect without sending data
console.log('\n--- Test 2: Delayed disconnect ---');
for (let i = 0; i < 5; i++) {
const client = new net.Socket();
client.on('error', () => {
// Ignore errors
});
client.connect(8560, 'localhost', () => {
// Wait 100ms then disconnect without sending data
setTimeout(() => {
if (!client.destroyed) {
client.destroy();
}
}, 100);
});
}
// Check count immediately
const duringDelayed = getActiveConnections();
console.log(`During delayed disconnect test: ${duringDelayed} active connections`);
// Wait for cleanup
await new Promise(resolve => setTimeout(resolve, 1000));
const afterDelayedDisconnect = getActiveConnections();
console.log(`After delayed disconnect test: ${afterDelayedDisconnect} active connections`);
// Test 3: Mix of immediate and delayed disconnects
console.log('\n--- Test 3: Mixed disconnect patterns ---');
const promises = [];
for (let i = 0; i < 20; i++) {
promises.push(new Promise<void>((resolve) => {
const client = new net.Socket();
client.on('error', () => {
resolve();
});
client.on('close', () => {
resolve();
});
client.connect(8560, 'localhost', () => {
if (i % 2 === 0) {
// Half disconnect immediately
client.destroy();
} else {
// Half wait 50ms
setTimeout(() => {
if (!client.destroyed) {
client.destroy();
}
}, 50);
}
});
// Failsafe timeout
setTimeout(() => resolve(), 200);
}));
}
// Wait for all to complete
await Promise.all(promises);
const duringMixed = getActiveConnections();
console.log(`During mixed test: ${duringMixed} active connections`);
// Final cleanup wait
await new Promise(resolve => setTimeout(resolve, 1000));
const finalCount = getActiveConnections();
console.log(`\nFinal connection count: ${finalCount}`);
// Stop the proxy
await proxy.stop();
console.log('✓ Proxy stopped');
// Verify all connections were cleaned up
expect(finalCount).toEqual(initialCount);
expect(afterImmediateDisconnect).toEqual(initialCount);
expect(afterDelayedDisconnect).toEqual(initialCount);
// Check that connections didn't accumulate during the test
const maxCount = Math.max(...connectionCounts);
console.log(`\nMax connection count during immediate disconnect test: ${maxCount}`);
expect(maxCount).toBeLessThan(3); // Should stay very low
console.log('\n✅ PASS: Connect-disconnect cleanup working correctly!');
});
tap.test('should handle clients that error during connection', async () => {
console.log('\n=== Testing Connection Error Cleanup ===');
const proxy = new SmartProxy({
ports: [8561],
enableDetailedLogging: false,
routes: [{
name: 'test-route',
match: { ports: 8561 },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 9999
}]
}
}]
});
await proxy.start();
console.log('✓ Proxy started on port 8561');
const getActiveConnections = () => {
const connectionManager = (proxy as any).connectionManager;
return connectionManager ? connectionManager.getConnectionCount() : 0;
};
const initialCount = getActiveConnections();
console.log(`Initial connection count: ${initialCount}`);
// Create connections that will error
const promises = [];
for (let i = 0; i < 10; i++) {
promises.push(new Promise<void>((resolve) => {
const client = new net.Socket();
client.on('error', () => {
resolve();
});
client.on('close', () => {
resolve();
});
// Connect to proxy
client.connect(8561, 'localhost', () => {
// Force an error by writing invalid data then destroying
try {
client.write(Buffer.alloc(1024 * 1024)); // Large write
client.destroy();
} catch (e) {
// Ignore
}
});
// Timeout
setTimeout(() => resolve(), 500);
}));
}
await Promise.all(promises);
console.log('✓ All error connections completed');
// Wait for cleanup
await new Promise(resolve => setTimeout(resolve, 500));
const finalCount = getActiveConnections();
console.log(`Final connection count: ${finalCount}`);
await proxy.stop();
console.log('✓ Proxy stopped');
expect(finalCount).toEqual(initialCount);
console.log('\n✅ PASS: Connection error cleanup working correctly!');
});
export default tap.start();

View File

@@ -1,279 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as plugins from '../ts/plugins.js';
// Import SmartProxy and configurations
import { SmartProxy } from '../ts/index.js';
tap.test('comprehensive connection cleanup test - all scenarios', async () => {
console.log('\n=== Comprehensive Connection Cleanup Test ===');
// Create a SmartProxy instance
const proxy = new SmartProxy({
ports: [8570, 8571], // One for immediate routing, one for TLS
enableDetailedLogging: false,
initialDataTimeout: 2000,
socketTimeout: 5000,
routes: [
{
name: 'non-tls-route',
match: { ports: 8570 },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 9999 // Non-existent port
}]
}
},
{
name: 'tls-route',
match: { ports: 8571 },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 9999 // Non-existent port
}],
tls: {
mode: 'passthrough'
}
}
}
]
});
// Start the proxy
await proxy.start();
console.log('✓ Proxy started on ports 8570 (non-TLS) and 8571 (TLS)');
// Helper to get active connection count
const getActiveConnections = () => {
const connectionManager = (proxy as any).connectionManager;
return connectionManager ? connectionManager.getConnectionCount() : 0;
};
const initialCount = getActiveConnections();
console.log(`Initial connection count: ${initialCount}`);
// Test 1: Rapid ECONNREFUSED retries (from original issue)
console.log('\n--- Test 1: Rapid ECONNREFUSED retries ---');
for (let i = 0; i < 10; i++) {
await new Promise<void>((resolve) => {
const client = new net.Socket();
client.on('error', () => {
client.destroy();
resolve();
});
client.on('close', () => {
resolve();
});
client.connect(8570, 'localhost', () => {
// Send data to trigger routing
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
});
setTimeout(() => {
if (!client.destroyed) {
client.destroy();
}
resolve();
}, 100);
});
if ((i + 1) % 5 === 0) {
const count = getActiveConnections();
console.log(`After ${i + 1} ECONNREFUSED retries: ${count} active connections`);
}
}
// Test 2: Connect without sending data (immediate disconnect)
console.log('\n--- Test 2: Connect without sending data ---');
for (let i = 0; i < 10; i++) {
const client = new net.Socket();
client.on('error', () => {
// Ignore
});
// Connect to non-TLS port and immediately disconnect
client.connect(8570, 'localhost', () => {
client.destroy();
});
await new Promise(resolve => setTimeout(resolve, 10));
}
const afterNoData = getActiveConnections();
console.log(`After connect-without-data test: ${afterNoData} active connections`);
// Test 3: TLS connections that disconnect before handshake
console.log('\n--- Test 3: TLS early disconnect ---');
for (let i = 0; i < 10; i++) {
const client = new net.Socket();
client.on('error', () => {
// Ignore
});
// Connect to TLS port but disconnect before sending handshake
client.connect(8571, 'localhost', () => {
// Wait 50ms then disconnect (before initial data timeout)
setTimeout(() => {
client.destroy();
}, 50);
});
await new Promise(resolve => setTimeout(resolve, 100));
}
const afterTlsEarly = getActiveConnections();
console.log(`After TLS early disconnect test: ${afterTlsEarly} active connections`);
// Test 4: Mixed pattern - simulating real-world chaos
console.log('\n--- Test 4: Mixed chaos pattern ---');
const promises = [];
for (let i = 0; i < 30; i++) {
promises.push(new Promise<void>((resolve) => {
const client = new net.Socket();
const port = i % 2 === 0 ? 8570 : 8571;
client.on('error', () => {
resolve();
});
client.on('close', () => {
resolve();
});
client.connect(port, 'localhost', () => {
const scenario = i % 5;
switch (scenario) {
case 0:
// Immediate disconnect
client.destroy();
break;
case 1:
// Send data then disconnect
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
setTimeout(() => client.destroy(), 20);
break;
case 2:
// Disconnect after delay
setTimeout(() => client.destroy(), 100);
break;
case 3:
// Send partial TLS handshake
if (port === 8571) {
client.write(Buffer.from([0x16, 0x03, 0x01])); // Partial TLS
}
setTimeout(() => client.destroy(), 50);
break;
case 4:
// Just let it timeout
break;
}
});
// Failsafe
setTimeout(() => {
if (!client.destroyed) {
client.destroy();
}
resolve();
}, 500);
}));
// Small delay between connections
if (i % 5 === 0) {
await new Promise(resolve => setTimeout(resolve, 10));
}
}
await Promise.all(promises);
console.log('✓ Chaos test completed');
// Wait for any cleanup
await new Promise(resolve => setTimeout(resolve, 1000));
const afterChaos = getActiveConnections();
console.log(`After chaos test: ${afterChaos} active connections`);
// Test 5: NFTables route (should cleanup properly)
console.log('\n--- Test 5: NFTables route cleanup ---');
const nftProxy = new SmartProxy({
ports: [8572],
enableDetailedLogging: false,
routes: [{
name: 'nftables-route',
match: { ports: 8572 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{
host: 'localhost',
port: 9999
}]
}
}]
});
await nftProxy.start();
const getNftConnections = () => {
const connectionManager = (nftProxy as any).connectionManager;
return connectionManager ? connectionManager.getConnectionCount() : 0;
};
// Create NFTables connections
for (let i = 0; i < 5; i++) {
const client = new net.Socket();
client.on('error', () => {
// Ignore
});
client.connect(8572, 'localhost', () => {
setTimeout(() => client.destroy(), 50);
});
await new Promise(resolve => setTimeout(resolve, 100));
}
await new Promise(resolve => setTimeout(resolve, 500));
const nftFinal = getNftConnections();
console.log(`NFTables connections after test: ${nftFinal}`);
await nftProxy.stop();
// Final check on main proxy
const finalCount = getActiveConnections();
console.log(`\nFinal connection count: ${finalCount}`);
// Stop the proxy
await proxy.stop();
console.log('✓ Proxy stopped');
// Verify all connections were cleaned up
expect(finalCount).toEqual(initialCount);
expect(afterNoData).toEqual(initialCount);
expect(afterTlsEarly).toEqual(initialCount);
expect(afterChaos).toEqual(initialCount);
expect(nftFinal).toEqual(0);
console.log('\n✅ PASS: Comprehensive connection cleanup test passed!');
console.log('All connection scenarios properly cleaned up:');
console.log('- ECONNREFUSED rapid retries');
console.log('- Connect without sending data');
console.log('- TLS early disconnect');
console.log('- Mixed chaos patterns');
console.log('- NFTables connections');
});
export default tap.start();

View File

@@ -58,8 +58,7 @@ tap.test('should forward TCP connections correctly', async () => {
enableDetailedLogging: true,
routes: [
{
id: 'tcp-forward',
name: 'TCP Forward Route',
name: 'tcp-forward',
match: {
ports: 8080,
},
@@ -85,17 +84,15 @@ tap.test('should forward TCP connections correctly', async () => {
socket.on('error', reject);
});
// Test data transmission
// Test data transmission - wait for welcome message first
await new Promise<void>((resolve) => {
client.on('data', (data) => {
client.once('data', (data) => {
const response = data.toString();
console.log('Received:', response);
expect(response).toContain('Connected to TCP test server');
client.end();
resolve();
});
client.write('Hello from client');
});
await smartProxy.stop();
@@ -107,8 +104,7 @@ tap.test('should handle TLS passthrough correctly', async () => {
enableDetailedLogging: true,
routes: [
{
id: 'tls-passthrough',
name: 'TLS Passthrough Route',
name: 'tls-passthrough',
match: {
ports: 8443,
domains: 'test.example.com',
@@ -148,15 +144,13 @@ tap.test('should handle TLS passthrough correctly', async () => {
// Test data transmission over TLS
await new Promise<void>((resolve) => {
client.on('data', (data) => {
client.once('data', (data) => {
const response = data.toString();
console.log('TLS Received:', response);
expect(response).toContain('Connected to TLS test server');
client.end();
resolve();
});
client.write('Hello from TLS client');
});
await smartProxy.stop();
@@ -168,8 +162,7 @@ tap.test('should handle SNI-based forwarding', async () => {
enableDetailedLogging: true,
routes: [
{
id: 'domain-a',
name: 'Domain A Route',
name: 'domain-a',
match: {
ports: 8443,
domains: 'a.example.com',
@@ -186,8 +179,7 @@ tap.test('should handle SNI-based forwarding', async () => {
},
},
{
id: 'domain-b',
name: 'Domain B Route',
name: 'domain-b',
match: {
ports: 8443,
domains: 'b.example.com',
@@ -226,15 +218,13 @@ tap.test('should handle SNI-based forwarding', async () => {
});
await new Promise<void>((resolve) => {
clientA.on('data', (data) => {
clientA.once('data', (data) => {
const response = data.toString();
console.log('Domain A response:', response);
expect(response).toContain('Connected to TLS test server');
clientA.end();
resolve();
});
clientA.write('Hello from domain A');
});
// Test domain B should also use TLS since it's on port 8443
@@ -255,7 +245,7 @@ tap.test('should handle SNI-based forwarding', async () => {
});
await new Promise<void>((resolve) => {
clientB.on('data', (data) => {
clientB.once('data', (data) => {
const response = data.toString();
console.log('Domain B response:', response);
// Should be forwarded to TLS server
@@ -263,8 +253,6 @@ tap.test('should handle SNI-based forwarding', async () => {
clientB.end();
resolve();
});
clientB.write('Hello from domain B');
});
await smartProxy.stop();

View File

@@ -1,299 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { HttpProxy } from '../ts/proxies/http-proxy/index.js';
let testServer: net.Server;
let smartProxy: SmartProxy;
let httpProxy: HttpProxy;
const TEST_SERVER_PORT = 5100;
const PROXY_PORT = 5101;
const HTTP_PROXY_PORT = 5102;
// Track all created servers and connections for cleanup
const allServers: net.Server[] = [];
const allProxies: (SmartProxy | HttpProxy)[] = [];
const activeConnections: net.Socket[] = [];
// Helper: Creates a test TCP server
function createTestServer(port: number): Promise<net.Server> {
return new Promise((resolve) => {
const server = net.createServer((socket) => {
socket.on('data', (data) => {
socket.write(`Echo: ${data.toString()}`);
});
socket.on('error', () => {});
});
server.listen(port, 'localhost', () => {
console.log(`[Test Server] Listening on localhost:${port}`);
allServers.push(server);
resolve(server);
});
});
}
// Helper: Creates multiple concurrent connections
async function createConcurrentConnections(
port: number,
count: number,
fromIP?: string
): Promise<net.Socket[]> {
const connections: net.Socket[] = [];
const promises: Promise<net.Socket>[] = [];
for (let i = 0; i < count; i++) {
promises.push(
new Promise((resolve, reject) => {
const client = new net.Socket();
const timeout = setTimeout(() => {
client.destroy();
reject(new Error(`Connection ${i} timeout`));
}, 5000);
client.connect(port, 'localhost', () => {
clearTimeout(timeout);
activeConnections.push(client);
connections.push(client);
resolve(client);
});
client.on('error', (err) => {
clearTimeout(timeout);
reject(err);
});
})
);
}
await Promise.all(promises);
return connections;
}
// Helper: Clean up connections
function cleanupConnections(connections: net.Socket[]): void {
connections.forEach(conn => {
if (!conn.destroyed) {
conn.destroy();
}
});
}
tap.test('Setup test environment', async () => {
testServer = await createTestServer(TEST_SERVER_PORT);
// Create SmartProxy with low connection limits for testing
smartProxy = new SmartProxy({
routes: [{
name: 'test-route',
match: {
ports: PROXY_PORT
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: TEST_SERVER_PORT
}]
},
security: {
maxConnections: 5 // Low limit for testing
}
}],
maxConnectionsPerIP: 3, // Low per-IP limit
connectionRateLimitPerMinute: 10, // Low rate limit
defaults: {
security: {
maxConnections: 10 // Low global limit
}
}
});
await smartProxy.start();
allProxies.push(smartProxy);
});
tap.test('Per-IP connection limits', async () => {
// Test that we can create up to the per-IP limit
const connections1 = await createConcurrentConnections(PROXY_PORT, 3);
expect(connections1.length).toEqual(3);
// Try to create one more connection - should fail
try {
await createConcurrentConnections(PROXY_PORT, 1);
expect.fail('Should not allow more than 3 connections per IP');
} catch (err) {
expect(err.message).toInclude('ECONNRESET');
}
// Clean up first set of connections
cleanupConnections(connections1);
await new Promise(resolve => setTimeout(resolve, 100));
// Should be able to create new connections after cleanup
const connections2 = await createConcurrentConnections(PROXY_PORT, 2);
expect(connections2.length).toEqual(2);
cleanupConnections(connections2);
});
tap.test('Route-level connection limits', async () => {
// Create multiple connections up to route limit
const connections = await createConcurrentConnections(PROXY_PORT, 5);
expect(connections.length).toEqual(5);
// Try to exceed route limit
try {
await createConcurrentConnections(PROXY_PORT, 1);
expect.fail('Should not allow more than 5 connections for this route');
} catch (err) {
expect(err.message).toInclude('ECONNRESET');
}
cleanupConnections(connections);
});
tap.test('Connection rate limiting', async () => {
// Create connections rapidly
const connections: net.Socket[] = [];
// Create 10 connections rapidly (at rate limit)
for (let i = 0; i < 10; i++) {
try {
const conn = await createConcurrentConnections(PROXY_PORT, 1);
connections.push(...conn);
// Small delay to avoid per-IP limit
if (connections.length >= 3) {
cleanupConnections(connections.splice(0, 3));
await new Promise(resolve => setTimeout(resolve, 50));
}
} catch (err) {
// Expected to fail at some point due to rate limit
expect(i).toBeGreaterThan(0);
break;
}
}
cleanupConnections(connections);
});
tap.test('HttpProxy per-IP validation', async () => {
// Create HttpProxy
httpProxy = new HttpProxy({
port: HTTP_PROXY_PORT,
maxConnectionsPerIP: 2,
connectionRateLimitPerMinute: 10,
routes: []
});
await httpProxy.start();
allProxies.push(httpProxy);
// Update SmartProxy to use HttpProxy for TLS termination
await smartProxy.stop();
smartProxy = new SmartProxy({
routes: [{
name: 'https-route',
match: {
ports: PROXY_PORT + 10
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: TEST_SERVER_PORT
}],
tls: {
mode: 'terminate'
}
}
}],
useHttpProxy: [PROXY_PORT + 10],
httpProxyPort: HTTP_PROXY_PORT,
maxConnectionsPerIP: 3
});
await smartProxy.start();
// Test that HttpProxy enforces its own per-IP limits
const connections = await createConcurrentConnections(PROXY_PORT + 10, 2);
expect(connections.length).toEqual(2);
// Should reject additional connections
try {
await createConcurrentConnections(PROXY_PORT + 10, 1);
expect.fail('HttpProxy should enforce per-IP limits');
} catch (err) {
expect(err.message).toInclude('ECONNRESET');
}
cleanupConnections(connections);
});
tap.test('IP tracking cleanup', async (tools) => {
// Create and close many connections from different IPs
const connections: net.Socket[] = [];
for (let i = 0; i < 5; i++) {
const conn = await createConcurrentConnections(PROXY_PORT, 1);
connections.push(...conn);
}
// Close all connections
cleanupConnections(connections);
// Wait for cleanup interval (set to 60s in production, but we'll check immediately)
await tools.delayFor(100);
// Verify that IP tracking has been cleaned up
const securityManager = (smartProxy as any).securityManager;
const ipCount = (securityManager.connectionsByIP as Map<string, any>).size;
// Should have no IPs tracked after cleanup
expect(ipCount).toEqual(0);
});
tap.test('Cleanup queue race condition handling', async () => {
// Create many connections concurrently to trigger batched cleanup
const promises: Promise<net.Socket[]>[] = [];
for (let i = 0; i < 20; i++) {
promises.push(createConcurrentConnections(PROXY_PORT, 1).catch(() => []));
}
const results = await Promise.all(promises);
const allConnections = results.flat();
// Close all connections rapidly
allConnections.forEach(conn => conn.destroy());
// Give cleanup queue time to process
await new Promise(resolve => setTimeout(resolve, 500));
// Verify all connections were cleaned up
const connectionManager = (smartProxy as any).connectionManager;
const remainingConnections = connectionManager.getConnectionCount();
expect(remainingConnections).toEqual(0);
});
tap.test('Cleanup and shutdown', async () => {
// Clean up any remaining connections
cleanupConnections(activeConnections);
activeConnections.length = 0;
// Stop all proxies
for (const proxy of allProxies) {
await proxy.stop();
}
allProxies.length = 0;
// Close all test servers
for (const server of allServers) {
await new Promise<void>((resolve) => {
server.close(() => resolve());
});
}
allServers.length = 0;
});
export default tap.start();

111
test/test.deno.ts Normal file
View File

@@ -0,0 +1,111 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpRoute,
createHttpsTerminateRoute,
createLoadBalancerRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
findMatchingRoutes,
findBestMatchingRoute,
routeMatchesDomain,
routeMatchesPort,
routeMatchesPath,
} from '../ts/proxies/smart-proxy/utils/route-utils.js';
import {
validateRouteConfig,
isValidDomain,
isValidPort,
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpRoute produces correct structure', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.match.domains).toEqual('example.com');
expect(route.action.type).toEqual('forward');
expect(route.action.targets).toBeArray();
expect(route.action.targets![0].host).toEqual('127.0.0.1');
expect(route.action.targets![0].port).toEqual(3000);
});
tap.test('route creation - createHttpRoute with array of domains', async () => {
const route = createHttpRoute(['a.com', 'b.com'], { host: 'localhost', port: 8080 });
expect(route.match.domains).toEqual(['a.com', 'b.com']);
});
tap.test('route validation - validateRouteConfig accepts valid route', async () => {
const route = createHttpRoute('valid.example.com', { host: '10.0.0.1', port: 8080 });
const result = validateRouteConfig(route);
expect(result.valid).toBeTrue();
expect(result.errors).toHaveLength(0);
});
tap.test('route validation - validateRouteConfig rejects missing action', async () => {
const badRoute = { match: { ports: 80 } } as any;
const result = validateRouteConfig(badRoute);
expect(result.valid).toBeFalse();
expect(result.errors.length).toBeGreaterThan(0);
});
tap.test('route validation - isValidDomain checks correctly', async () => {
expect(isValidDomain('example.com')).toBeTrue();
expect(isValidDomain('*.example.com')).toBeTrue();
expect(isValidDomain('')).toBeFalse();
});
tap.test('route validation - isValidPort checks correctly', async () => {
expect(isValidPort(80)).toBeTrue();
expect(isValidPort(443)).toBeTrue();
expect(isValidPort(0)).toBeFalse();
expect(isValidPort(70000)).toBeFalse();
expect(isValidPort(-1)).toBeFalse();
});
tap.test('domain matching - exact domain', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
expect(routeMatchesDomain(route, 'example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'other.com')).toBeFalse();
});
tap.test('domain matching - wildcard domain', async () => {
const route = createHttpRoute('*.example.com', { host: '127.0.0.1', port: 3000 });
expect(routeMatchesDomain(route, 'sub.example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'example.com')).toBeFalse();
});
tap.test('port matching - single port', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// createHttpRoute defaults to port 80
expect(routeMatchesPort(route, 80)).toBeTrue();
expect(routeMatchesPort(route, 443)).toBeFalse();
});
tap.test('route finding - findBestMatchingRoute selects by priority', async () => {
const lowPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
lowPriority.priority = 10;
const highPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
highPriority.priority = 100;
const routes: IRouteConfig[] = [lowPriority, highPriority];
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(best).toBeDefined();
expect(best!.priority).toEqual(100);
expect(best!.action.targets![0].port).toEqual(4000);
});
tap.test('route finding - findMatchingRoutes returns all matches', async () => {
const route1 = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
const route2 = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
const route3 = createHttpRoute('other.com', { host: '127.0.0.1', port: 5000 });
const matches = findMatchingRoutes([route1, route2, route3], { domain: 'example.com', port: 80 });
expect(matches).toHaveLength(2);
});
export default tap.start();

View File

@@ -0,0 +1,189 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { RouteValidator } from '../ts/proxies/smart-proxy/utils/route-validator.js';
tap.test('Domain Validation - Standard wildcard patterns', async () => {
const testPatterns = [
{ pattern: '*.example.com', shouldPass: true, description: 'Standard wildcard subdomain' },
{ pattern: '*.sub.example.com', shouldPass: true, description: 'Nested wildcard subdomain' },
{ pattern: 'example.com', shouldPass: true, description: 'Plain domain' },
{ pattern: 'sub.example.com', shouldPass: true, description: 'Subdomain' },
{ pattern: '*', shouldPass: true, description: 'Catch-all wildcard' },
{ pattern: 'localhost', shouldPass: true, description: 'Localhost' },
{ pattern: '192.168.1.1', shouldPass: true, description: 'IPv4 address' },
];
for (const { pattern, shouldPass, description } of testPatterns) {
const route = {
name: 'test',
match: {
ports: 443,
domains: pattern
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
if (shouldPass) {
expect(result.valid).toEqual(true);
console.log(`✅ Domain '${pattern}' correctly accepted (${description})`);
} else {
expect(result.valid).toEqual(false);
console.log(`✅ Domain '${pattern}' correctly rejected (${description})`);
}
}
});
tap.test('Domain Validation - Prefix wildcard patterns (*domain)', async () => {
const testPatterns = [
{ pattern: '*nevermind.cloud', shouldPass: true, description: 'Prefix wildcard without dot' },
{ pattern: '*example.com', shouldPass: true, description: 'Prefix wildcard for TLD' },
{ pattern: '*sub.example.com', shouldPass: true, description: 'Prefix wildcard for subdomain' },
{ pattern: '*api.service.io', shouldPass: true, description: 'Prefix wildcard for nested domain' },
];
for (const { pattern, shouldPass, description } of testPatterns) {
const route = {
name: 'test',
match: {
ports: 443,
domains: pattern
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
if (shouldPass) {
expect(result.valid).toEqual(true);
console.log(`✅ Domain '${pattern}' correctly accepted (${description})`);
} else {
expect(result.valid).toEqual(false);
console.log(`✅ Domain '${pattern}' correctly rejected (${description})`);
}
}
});
tap.test('Domain Validation - Invalid patterns', async () => {
const invalidPatterns = [
// Note: Empty string validation is handled differently in the validator
// { pattern: '', description: 'Empty string' },
{ pattern: '*.', description: 'Wildcard with trailing dot' },
{ pattern: '.example.com', description: 'Leading dot' },
{ pattern: 'example..com', description: 'Double dots' },
{ pattern: 'exam ple.com', description: 'Space in domain' },
{ pattern: 'example-.com', description: 'Hyphen at end of label' },
{ pattern: '-example.com', description: 'Hyphen at start of label' },
];
for (const { pattern, description } of invalidPatterns) {
const route = {
name: 'test',
match: {
ports: 443,
domains: pattern
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
if (result.valid === false) {
console.log(`✅ Domain '${pattern}' correctly rejected (${description})`);
} else {
console.log(`❌ Domain '${pattern}' was unexpectedly accepted! (${description})`);
console.log(` Errors: ${result.errors.join(', ')}`);
}
expect(result.valid).toEqual(false);
}
});
tap.test('Domain Validation - Multiple domains in array', async () => {
const route = {
name: 'test',
match: {
ports: 443,
domains: [
'*.example.com',
'*nevermind.cloud',
'api.service.io',
'localhost'
]
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
expect(result.valid).toEqual(true);
console.log('✅ Multiple valid domains in array correctly accepted');
});
tap.test('Domain Validation - Mixed valid and invalid domains', async () => {
const route = {
name: 'test',
match: {
ports: 443,
domains: [
'*.example.com', // valid
'', // invalid - empty
'localhost' // valid
]
},
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
expect(result.valid).toEqual(false);
expect(result.errors.some(e => e.includes('Invalid domain pattern'))).toEqual(true);
console.log('✅ Mixed valid/invalid domains correctly rejected');
});
tap.test('Domain Validation - Real-world patterns from email routes', async () => {
// These are the patterns that were failing from the email conversion
const realWorldPatterns = [
{ pattern: '*nevermind.cloud', shouldPass: true, description: 'nevermind.cloud wildcard' },
{ pattern: '*push.email', shouldPass: true, description: 'push.email wildcard' },
{ pattern: '*.bleu.de', shouldPass: true, description: 'bleu.de subdomain wildcard' },
{ pattern: '*bleu.de', shouldPass: true, description: 'bleu.de prefix wildcard' },
];
for (const { pattern, shouldPass, description } of realWorldPatterns) {
const route = {
name: 'email-route',
match: {
ports: 443,
domains: pattern
},
action: {
type: 'forward' as const,
targets: [{ host: 'mail.server.com', port: 8080 }]
}
};
const result = RouteValidator.validateRoute(route);
if (shouldPass) {
expect(result.valid).toEqual(true);
console.log(`✅ Real-world domain '${pattern}' correctly accepted (${description})`);
} else {
expect(result.valid).toEqual(false);
console.log(`✅ Real-world domain '${pattern}' correctly rejected (${description})`);
}
}
});
export default tap.start();

Some files were not shown because too many files have changed in this diff Show More