ME + Admission + Cleanup Correctness: merge pull request #779 from telemt/flow

ME + Admission + Cleanup Correctness
This commit is contained in:
Alexey
2026-05-10 14:23:09 +03:00
committed by GitHub
41 changed files with 9547 additions and 5294 deletions
+8 -5
View File
@@ -191,6 +191,11 @@ When facing a non-trivial modification, follow this sequence:
4. **Implement**: Make the minimal, isolated change. 4. **Implement**: Make the minimal, isolated change.
5. **Verify**: Explain why the change preserves existing behavior and architectural integrity. 5. **Verify**: Explain why the change preserves existing behavior and architectural integrity.
When the repository contains a `PLAN.md` for the current task, maintain it as
a working checkbox plan while implementing changes. Mark completed and partial
items in `PLAN.md` as the code changes land, so the remaining work stays
explicit and future passes do not waste time rediscovering status.
--- ---
### 9. Context Awareness ### 9. Context Awareness
@@ -222,10 +227,9 @@ Your response MUST consist of two sections:
**Section 2: `## Changes`** **Section 2: `## Changes`**
- For each modified or created file: the filename on a separate line in backticks, followed by the code block. - For each modified or created file: the filename on a separate line in backticks, followed by a concise description of what changed.
- For files **under 200 lines**: return the full file with all changes applied. - Do not include full file contents or long code blocks in `## Changes` unless the user explicitly asks for code text.
- For files **over 200 lines**: return only the changed functions/blocks with at least 3 lines of surrounding context above and below. If the user requests the full file, provide it. - If code snippets are necessary, include only the minimal relevant excerpt.
- New files: full file content.
- End with a suggested git commit message in English. - End with a suggested git commit message in English.
#### Reporting Out-of-Scope Issues #### Reporting Out-of-Scope Issues
@@ -429,4 +433,3 @@ Every patch must be **atomic and production-safe**.
* **No transitional states** — no placeholders, incomplete refactors, or temporary inconsistencies. * **No transitional states** — no placeholders, incomplete refactors, or temporary inconsistencies.
**Invariant:** After any single patch, the repository remains fully functional and buildable. **Invariant:** After any single patch, the repository remains fully functional and buildable.
Generated
+3 -3
View File
@@ -2404,9 +2404,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.103.12" version = "0.103.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e"
dependencies = [ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"ring", "ring",
@@ -2791,7 +2791,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.4.10" version = "3.4.11"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",
+1 -1
View File
@@ -1,6 +1,6 @@
[package] [package]
name = "telemt" name = "telemt"
version = "3.4.10" version = "3.4.11"
edition = "2024" edition = "2024"
[features] [features]
+15
View File
@@ -178,6 +178,21 @@ Notes:
| `data_quota_bytes` | `u64` | no | Per-user traffic quota. | | `data_quota_bytes` | `u64` | no | Per-user traffic quota. |
| `max_unique_ips` | `usize` | no | Per-user unique source IP limit. | | `max_unique_ips` | `usize` | no | Per-user unique source IP limit. |
### `access.user_source_deny` via API
- In current API surface, per-user deny-list is **not** exposed as a dedicated field in `CreateUserRequest` / `PatchUserRequest`.
- Configure it in `config.toml` under `[access.user_source_deny]` and apply via normal config reload path.
- Runtime behavior after apply:
- auth succeeds for username/secret
- source IP is checked against `access.user_source_deny[username]`
- on match, handshake is rejected with the same fail-closed outcome as invalid auth
Example config:
```toml
[access.user_source_deny]
alice = ["203.0.113.0/24", "2001:db8:abcd::/48"]
bob = ["198.51.100.42/32"]
```
### `RotateSecretRequest` ### `RotateSecretRequest`
| Field | Type | Required | Description | | Field | Type | Required | Description |
| --- | --- | --- | --- | | --- | --- | --- | --- |
@@ -128,7 +128,48 @@ Recommended for cleaner testing:
Persisted cache artifacts are useful, but they are not required if packet captures already demonstrate the runtime result. Persisted cache artifacts are useful, but they are not required if packet captures already demonstrate the runtime result.
### 4. Capture a direct-origin trace ### 4. Check TLS-front profile health metrics
If the metrics endpoint is enabled, check the TLS-front profile health before packet-capture validation:
```bash
curl -s http://127.0.0.1:9999/metrics | grep -E 'telemt_tls_front_profile|telemt_tls_fetch_profile_cache|telemt_tls_front_full_cert'
```
The profile-health metrics expose the runtime state of configured TLS front domains:
- `telemt_tls_front_profile_domains` shows configured, emitted, and suppressed domain series.
- `telemt_tls_front_profile_info` shows profile source and feature flags per domain.
- `telemt_tls_front_profile_age_seconds` shows cached profile age.
- `telemt_tls_front_profile_app_data_records` shows cached AppData record count.
- `telemt_tls_front_profile_ticket_records` shows cached ticket-like tail record count.
- `telemt_tls_front_profile_change_cipher_spec_records` shows cached ChangeCipherSpec count.
- `telemt_tls_front_profile_app_data_bytes` shows total cached AppData bytes.
Interpretation:
- `source="merged"` or `source="raw"` means real TLS profile data is being used.
- `source="default"` or `is_default="true"` means the domain currently uses the synthetic default fallback.
- `has_cert_payload="true"` means certificate payload data is available for TLS emulation.
- Non-zero AppData/ticket/CCS counters show captured server-flight shape.
Example healthy output:
```text
telemt_tls_front_profile_domains{status="configured"} 1
telemt_tls_front_profile_domains{status="emitted"} 1
telemt_tls_front_profile_domains{status="suppressed"} 0
telemt_tls_front_profile_info{domain="itunes.apple.com",source="merged",is_default="false",has_cert_info="true",has_cert_payload="true"} 1
telemt_tls_front_profile_age_seconds{domain="itunes.apple.com"} 20
telemt_tls_front_profile_app_data_records{domain="itunes.apple.com"} 3
telemt_tls_front_profile_ticket_records{domain="itunes.apple.com"} 1
telemt_tls_front_profile_change_cipher_spec_records{domain="itunes.apple.com"} 1
telemt_tls_front_profile_app_data_bytes{domain="itunes.apple.com"} 5240
```
These metrics do not prove byte-level origin equivalence. They are an operational health signal that the configured domain is backed by real cached profile data instead of default fallback data.
### 5. Capture a direct-origin trace
From a separate client host, connect directly to the origin: From a separate client host, connect directly to the origin:
@@ -142,7 +183,7 @@ Capture with:
sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443 sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443
``` ```
### 5. Capture a Telemt FakeTLS success-path trace ### 6. Capture a Telemt FakeTLS success-path trace
Now connect to Telemt with a real Telegram client through an `ee` proxy link that targets the Telemt instance. Now connect to Telemt with a real Telegram client through an `ee` proxy link that targets the Telemt instance.
@@ -154,7 +195,7 @@ Capture with:
sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443 sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443
``` ```
### 6. Decode TLS record structure ### 7. Decode TLS record structure
Use `tshark` to print record-level structure: Use `tshark` to print record-level structure:
@@ -182,7 +223,7 @@ Focus on the server flight after ClientHello:
- `20` = ChangeCipherSpec - `20` = ChangeCipherSpec
- `23` = ApplicationData - `23` = ApplicationData
### 7. Build a comparison table ### 8. Build a comparison table
A compact table like the following is usually enough: A compact table like the following is usually enough:
@@ -126,9 +126,50 @@ openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN </dev/null
2. Дайте ему получить TLS front profile data для выбранного домена. 2. Дайте ему получить TLS front profile data для выбранного домена.
3. Если `tls_front_dir` хранится persistently, убедитесь, что TLS front cache заполнен. 3. Если `tls_front_dir` хранится persistently, убедитесь, что TLS front cache заполнен.
Persisted cache artifacts полезны, но не обязательны, если packet capture уже показывают runtime result. Сохранённые артефакты кэша полезны, но не обязательны, если packet capture уже показывает результат в runtime.
### 4. Снять direct-origin trace ### 4. Проверить метрики состояния TLS-front profile
Если endpoint метрик включён, перед проверкой через packet capture можно быстро проверить состояние TLS-front profile:
```bash
curl -s http://127.0.0.1:9999/metrics | grep -E 'telemt_tls_front_profile|telemt_tls_fetch_profile_cache|telemt_tls_front_full_cert'
```
Метрики состояния профиля показывают runtime-состояние настроенных TLS-front доменов:
- `telemt_tls_front_profile_domains` показывает количество настроенных, экспортируемых и скрытых из-за лимита доменов.
- `telemt_tls_front_profile_info` показывает источник профиля и флаги доступных данных по каждому домену.
- `telemt_tls_front_profile_age_seconds` показывает возраст закешированного профиля.
- `telemt_tls_front_profile_app_data_records` показывает количество закешированных AppData records.
- `telemt_tls_front_profile_ticket_records` показывает количество закешированных ticket-like tail records.
- `telemt_tls_front_profile_change_cipher_spec_records` показывает закешированное количество ChangeCipherSpec records.
- `telemt_tls_front_profile_app_data_bytes` показывает общий размер закешированных AppData bytes.
Интерпретация:
- `source="merged"` или `source="raw"` означает, что используются реальные данные TLS-профиля.
- `source="default"` или `is_default="true"` означает, что домен сейчас работает на synthetic default fallback.
- `has_cert_payload="true"` означает, что certificate payload доступен для TLS emulation.
- Ненулевые AppData/ticket/CCS counters показывают захваченную форму server flight.
Пример здорового состояния:
```text
telemt_tls_front_profile_domains{status="configured"} 1
telemt_tls_front_profile_domains{status="emitted"} 1
telemt_tls_front_profile_domains{status="suppressed"} 0
telemt_tls_front_profile_info{domain="itunes.apple.com",source="merged",is_default="false",has_cert_info="true",has_cert_payload="true"} 1
telemt_tls_front_profile_age_seconds{domain="itunes.apple.com"} 20
telemt_tls_front_profile_app_data_records{domain="itunes.apple.com"} 3
telemt_tls_front_profile_ticket_records{domain="itunes.apple.com"} 1
telemt_tls_front_profile_change_cipher_spec_records{domain="itunes.apple.com"} 1
telemt_tls_front_profile_app_data_bytes{domain="itunes.apple.com"} 5240
```
Эти метрики не доказывают побайтную эквивалентность с origin. Это эксплуатационный сигнал состояния: настроенный домен действительно основан на реальных закешированных данных профиля, а не на default fallback.
### 5. Снять direct-origin trace
С отдельной клиентской машины подключитесь напрямую к origin: С отдельной клиентской машины подключитесь напрямую к origin:
@@ -142,7 +183,7 @@ Capture:
sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443 sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443
``` ```
### 5. Снять Telemt FakeTLS success-path trace ### 6. Снять Telemt FakeTLS success-path trace
Теперь подключитесь к Telemt через реальный Telegram client с `ee` proxy link, который указывает на Telemt instance. Теперь подключитесь к Telemt через реальный Telegram client с `ee` proxy link, который указывает на Telemt instance.
@@ -154,7 +195,7 @@ Capture:
sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443 sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443
``` ```
### 6. Декодировать структуру TLS records ### 7. Декодировать структуру TLS records
Используйте `tshark`, чтобы вывести record-level structure: Используйте `tshark`, чтобы вывести record-level structure:
@@ -182,7 +223,7 @@ tshark -r telemt-emulated.pcap -Y "tls.record" -T fields \
- `20` = ChangeCipherSpec - `20` = ChangeCipherSpec
- `23` = ApplicationData - `23` = ApplicationData
### 7. Собрать сравнительную таблицу ### 8. Собрать сравнительную таблицу
Обычно достаточно короткой таблицы такого вида: Обычно достаточно короткой таблицы такого вида:
+15
View File
@@ -2886,6 +2886,7 @@ If your backend or network is very bandwidth-constrained, reduce cap first. If p
| [`user_max_unique_ips_global_each`](#user_max_unique_ips_global_each) | `usize` | `0` | | [`user_max_unique_ips_global_each`](#user_max_unique_ips_global_each) | `usize` | `0` |
| [`user_max_unique_ips_mode`](#user_max_unique_ips_mode) | `"active_window"`, `"time_window"`, or `"combined"` | `"active_window"` | | [`user_max_unique_ips_mode`](#user_max_unique_ips_mode) | `"active_window"`, `"time_window"`, or `"combined"` | `"active_window"` |
| [`user_max_unique_ips_window_secs`](#user_max_unique_ips_window_secs) | `u64` | `30` | | [`user_max_unique_ips_window_secs`](#user_max_unique_ips_window_secs) | `u64` | `30` |
| [`user_source_deny`](#user_source_deny) | `Map<String, IpNetwork[]>` | `{}` |
| [`replay_check_len`](#replay_check_len) | `usize` | `65536` | | [`replay_check_len`](#replay_check_len) | `usize` | `65536` |
| [`replay_window_secs`](#replay_window_secs) | `u64` | `120` | | [`replay_window_secs`](#replay_window_secs) | `u64` | `120` |
| [`ignore_time_skew`](#ignore_time_skew) | `bool` | `false` | | [`ignore_time_skew`](#ignore_time_skew) | `bool` | `false` |
@@ -2990,6 +2991,20 @@ If your backend or network is very bandwidth-constrained, reduce cap first. If p
[access] [access]
user_max_unique_ips_window_secs = 30 user_max_unique_ips_window_secs = 30
``` ```
## user_source_deny
- **Constraints / validation**: Table `username -> IpNetwork[]`. Each network must parse as CIDR (for example `203.0.113.0/24` or `2001:db8::/32`).
- **Description**: Per-user source IP/CIDR deny-list applied **after successful auth** in TLS and MTProto handshake paths. A matched source IP is rejected via the same fail-closed path as invalid auth.
- **Example**:
```toml
[access.user_source_deny]
alice = ["203.0.113.0/24", "2001:db8:abcd::/48"]
bob = ["198.51.100.42/32"]
```
- **How it works (quick check)**:
- connection from user `alice` and source `203.0.113.55` -> rejected (matches `203.0.113.0/24`)
- connection from user `alice` and source `198.51.100.10` -> allowed by this rule set (no match)
## replay_check_len ## replay_check_len
- **Constraints / validation**: `usize`. - **Constraints / validation**: `usize`.
- **Description**: Replay-protection storage length (number of entries tracked for duplicate detection). - **Description**: Replay-protection storage length (number of entries tracked for duplicate detection).
+7
View File
@@ -27,6 +27,8 @@ ACTION="install"
TARGET_VERSION="${VERSION:-latest}" TARGET_VERSION="${VERSION:-latest}"
LANG_CHOICE="en" LANG_CHOICE="en"
PATH="${PATH}:/usr/sbin:/sbin"
set_language() { set_language() {
case "$1" in case "$1" in
ru) ru)
@@ -102,6 +104,7 @@ set_language() {
L_OUT_SUCC_H="УСТАНОВКА УСПЕШНО ЗАВЕРШЕНА" L_OUT_SUCC_H="УСТАНОВКА УСПЕШНО ЗАВЕРШЕНА"
L_OUT_UNINST_H="УДАЛЕНИЕ ЗАВЕРШЕНО" L_OUT_UNINST_H="УДАЛЕНИЕ ЗАВЕРШЕНО"
L_OUT_LINK="Ваша ссылка для подключения к Telegram Proxy:\n" L_OUT_LINK="Ваша ссылка для подключения к Telegram Proxy:\n"
L_ERR_INCORR_ROOT_LOGIN="Используйте 'su -' или 'sudo -i' для входа под пользователем root"
;; ;;
*) *)
L_ERR_DOMAIN_REQ="requires a domain argument." L_ERR_DOMAIN_REQ="requires a domain argument."
@@ -176,6 +179,7 @@ set_language() {
L_OUT_SUCC_H="INSTALLATION SUCCESS" L_OUT_SUCC_H="INSTALLATION SUCCESS"
L_OUT_UNINST_H="UNINSTALLATION COMPLETE" L_OUT_UNINST_H="UNINSTALLATION COMPLETE"
L_OUT_LINK="Your Telegram Proxy connection link:\n" L_OUT_LINK="Your Telegram Proxy connection link:\n"
L_ERR_INCORR_ROOT_LOGIN="Use 'su -' or 'sudo -i' to login under root"
;; ;;
esac esac
} }
@@ -388,6 +392,9 @@ verify_common() {
if [ "$(id -u)" -eq 0 ]; then if [ "$(id -u)" -eq 0 ]; then
SUDO="" SUDO=""
if [ "$(id -u)" -ne 0 ]; then
die "$L_ERR_INCORR_ROOT_LOGIN"
fi
else else
command -v sudo >/dev/null 2>&1 || die "$L_ERR_ROOT" command -v sudo >/dev/null 2>&1 || die "$L_ERR_ROOT"
SUDO="sudo" SUDO="sudo"
+160 -55
View File
@@ -5,6 +5,7 @@ use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use http_body_util::Full; use http_body_util::Full;
use hyper::body::{Bytes, Incoming}; use hyper::body::{Bytes, Incoming};
@@ -12,8 +13,10 @@ use hyper::header::AUTHORIZATION;
use hyper::server::conn::http1; use hyper::server::conn::http1;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode}; use hyper::{Method, Request, Response, StatusCode};
use subtle::ConstantTimeEq;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::{Mutex, RwLock, watch}; use tokio::sync::{Mutex, RwLock, Semaphore, watch};
use tokio::time::timeout;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::config::{ApiGrayAction, ProxyConfig}; use crate::config::{ApiGrayAction, ProxyConfig};
@@ -43,7 +46,8 @@ use events::ApiEventStore;
use http_utils::{error_response, read_json, read_optional_json, success_response}; use http_utils::{error_response, read_json, read_optional_json, success_response};
use model::{ use model::{
ApiFailure, ClassCount, CreateUserRequest, DeleteUserResponse, HealthData, HealthReadyData, ApiFailure, ClassCount, CreateUserRequest, DeleteUserResponse, HealthData, HealthReadyData,
PatchUserRequest, RotateSecretRequest, SummaryData, UserActiveIps, PatchUserRequest, ResetUserQuotaResponse, RotateSecretRequest, SummaryData, UserActiveIps,
is_valid_username,
}; };
use runtime_edge::{ use runtime_edge::{
EdgeConnectionsCacheEntry, build_runtime_connections_summary_data, EdgeConnectionsCacheEntry, build_runtime_connections_summary_data,
@@ -66,6 +70,10 @@ use runtime_zero::{
}; };
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
const API_MAX_CONTROL_CONNECTIONS: usize = 1024;
const API_HTTP_CONNECTION_TIMEOUT: Duration = Duration::from_secs(15);
const ROUTE_USERNAME_ERROR: &str = "username must match [A-Za-z0-9_.-] and be 1..64 chars";
pub(super) struct ApiRuntimeState { pub(super) struct ApiRuntimeState {
pub(super) process_started_at_epoch_secs: u64, pub(super) process_started_at_epoch_secs: u64,
pub(super) config_reload_count: AtomicU64, pub(super) config_reload_count: AtomicU64,
@@ -80,6 +88,7 @@ pub(super) struct ApiShared {
pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>, pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
pub(super) upstream_manager: Arc<UpstreamManager>, pub(super) upstream_manager: Arc<UpstreamManager>,
pub(super) config_path: PathBuf, pub(super) config_path: PathBuf,
pub(super) quota_state_path: PathBuf,
pub(super) detected_ips_rx: watch::Receiver<(Option<IpAddr>, Option<IpAddr>)>, pub(super) detected_ips_rx: watch::Receiver<(Option<IpAddr>, Option<IpAddr>)>,
pub(super) mutation_lock: Arc<Mutex<()>>, pub(super) mutation_lock: Arc<Mutex<()>>,
pub(super) minimal_cache: Arc<Mutex<Option<MinimalCacheEntry>>>, pub(super) minimal_cache: Arc<Mutex<Option<MinimalCacheEntry>>>,
@@ -102,6 +111,18 @@ impl ApiShared {
} }
} }
fn auth_header_matches(actual: &str, expected: &str) -> bool {
actual.as_bytes().ct_eq(expected.as_bytes()).into()
}
fn parse_route_username(user: &str) -> Result<&str, ApiFailure> {
if is_valid_username(user) {
Ok(user)
} else {
Err(ApiFailure::bad_request(ROUTE_USERNAME_ERROR))
}
}
pub async fn serve( pub async fn serve(
listen: SocketAddr, listen: SocketAddr,
stats: Arc<Stats>, stats: Arc<Stats>,
@@ -112,6 +133,7 @@ pub async fn serve(
config_rx: watch::Receiver<Arc<ProxyConfig>>, config_rx: watch::Receiver<Arc<ProxyConfig>>,
admission_rx: watch::Receiver<bool>, admission_rx: watch::Receiver<bool>,
config_path: PathBuf, config_path: PathBuf,
quota_state_path: PathBuf,
detected_ips_rx: watch::Receiver<(Option<IpAddr>, Option<IpAddr>)>, detected_ips_rx: watch::Receiver<(Option<IpAddr>, Option<IpAddr>)>,
process_started_at_epoch_secs: u64, process_started_at_epoch_secs: u64,
startup_tracker: Arc<StartupTracker>, startup_tracker: Arc<StartupTracker>,
@@ -143,6 +165,7 @@ pub async fn serve(
me_pool, me_pool,
upstream_manager, upstream_manager,
config_path, config_path,
quota_state_path,
detected_ips_rx, detected_ips_rx,
mutation_lock: Arc::new(Mutex::new(())), mutation_lock: Arc::new(Mutex::new(())),
minimal_cache: Arc::new(Mutex::new(None)), minimal_cache: Arc::new(Mutex::new(None)),
@@ -164,6 +187,8 @@ pub async fn serve(
shared.runtime_events.clone(), shared.runtime_events.clone(),
); );
let connection_permits = Arc::new(Semaphore::new(API_MAX_CONTROL_CONNECTIONS));
loop { loop {
let (stream, peer) = match listener.accept().await { let (stream, peer) = match listener.accept().await {
Ok(v) => v, Ok(v) => v,
@@ -173,22 +198,47 @@ pub async fn serve(
} }
}; };
let connection_permit = match connection_permits.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
debug!(
peer = %peer,
max_connections = API_MAX_CONTROL_CONNECTIONS,
"Dropping API connection: control-plane connection budget exhausted"
);
continue;
}
};
let shared_conn = shared.clone(); let shared_conn = shared.clone();
let config_rx_conn = config_rx.clone(); let config_rx_conn = config_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let _connection_permit = connection_permit;
let svc = service_fn(move |req: Request<Incoming>| { let svc = service_fn(move |req: Request<Incoming>| {
let shared_req = shared_conn.clone(); let shared_req = shared_conn.clone();
let config_rx_req = config_rx_conn.clone(); let config_rx_req = config_rx_conn.clone();
async move { handle(req, peer, shared_req, config_rx_req).await } async move { handle(req, peer, shared_req, config_rx_req).await }
}); });
if let Err(error) = http1::Builder::new() match timeout(
.serve_connection(hyper_util::rt::TokioIo::new(stream), svc) API_HTTP_CONNECTION_TIMEOUT,
http1::Builder::new().serve_connection(hyper_util::rt::TokioIo::new(stream), svc),
)
.await .await
{ {
Ok(Ok(())) => {}
Ok(Err(error)) => {
if !error.is_user() { if !error.is_user() {
debug!(error = %error, "API connection error"); debug!(error = %error, "API connection error");
} }
} }
Err(_) => {
debug!(
peer = %peer,
timeout_ms = API_HTTP_CONNECTION_TIMEOUT.as_millis() as u64,
"API connection timed out"
);
}
}
}); });
} }
} }
@@ -242,7 +292,7 @@ async fn handle(
.headers() .headers()
.get(AUTHORIZATION) .get(AUTHORIZATION)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.map(|v| v == api_cfg.auth_header) .map(|v| auth_header_matches(v, &api_cfg.auth_header))
.unwrap_or(false); .unwrap_or(false);
if !auth_ok { if !auth_ok {
return Ok(error_response( return Ok(error_response(
@@ -491,10 +541,115 @@ async fn handle(
Ok(success_response(status, data, revision)) Ok(success_response(status, data, revision))
} }
_ => { _ => {
if method == Method::POST
&& let Some(user) = normalized_path
.strip_prefix("/v1/users/")
.and_then(|path| path.strip_suffix("/reset-quota"))
&& !user.is_empty()
&& !user.contains('/')
{
let user = parse_route_username(user)?;
if api_cfg.read_only {
return Ok(error_response(
request_id,
ApiFailure::new(
StatusCode::FORBIDDEN,
"read_only",
"API runs in read-only mode",
),
));
}
let snapshot = match crate::quota_state::reset_user_quota(
&shared.quota_state_path,
shared.stats.as_ref(),
user,
)
.await
{
Ok(snapshot) => snapshot,
Err(error) => {
shared.runtime_events.record(
"api.user.reset_quota.failed",
format!("username={} error={}", user, error),
);
return Err(ApiFailure::internal(format!(
"Failed to reset user quota: {}",
error
)));
}
};
shared
.runtime_events
.record("api.user.reset_quota.ok", format!("username={}", user));
let revision = current_revision(&shared.config_path).await?;
return Ok(success_response(
StatusCode::OK,
ResetUserQuotaResponse {
username: user.to_string(),
used_bytes: snapshot.used_bytes,
last_reset_epoch_secs: snapshot.last_reset_epoch_secs,
},
revision,
));
}
if method == Method::POST
&& let Some(base_user) = normalized_path
.strip_prefix("/v1/users/")
.and_then(|path| path.strip_suffix("/rotate-secret"))
&& !base_user.is_empty()
&& !base_user.contains('/')
{
let base_user = parse_route_username(base_user)?;
if api_cfg.read_only {
return Ok(error_response(
request_id,
ApiFailure::new(
StatusCode::FORBIDDEN,
"read_only",
"API runs in read-only mode",
),
));
}
let expected_revision = parse_if_match(req.headers());
let body =
read_optional_json::<RotateSecretRequest>(req.into_body(), body_limit)
.await?;
let result = rotate_secret(
base_user,
body.unwrap_or_default(),
expected_revision,
&shared,
)
.await;
let (mut data, revision) = match result {
Ok(ok) => ok,
Err(error) => {
shared.runtime_events.record(
"api.user.rotate_secret.failed",
format!("username={} code={}", base_user, error.code),
);
return Err(error);
}
};
let runtime_cfg = config_rx.borrow().clone();
data.user.in_runtime =
runtime_cfg.access.users.contains_key(&data.user.username);
shared.runtime_events.record(
"api.user.rotate_secret.ok",
format!("username={}", base_user),
);
let status = if data.user.in_runtime {
StatusCode::OK
} else {
StatusCode::ACCEPTED
};
return Ok(success_response(status, data, revision));
}
if let Some(user) = normalized_path.strip_prefix("/v1/users/") if let Some(user) = normalized_path.strip_prefix("/v1/users/")
&& !user.is_empty() && !user.is_empty()
&& !user.contains('/') && !user.contains('/')
{ {
let user = parse_route_username(user)?;
if method == Method::GET { if method == Method::GET {
let revision = current_revision(&shared.config_path).await?; let revision = current_revision(&shared.config_path).await?;
let disk_cfg = load_config_from_disk(&shared.config_path).await?; let disk_cfg = load_config_from_disk(&shared.config_path).await?;
@@ -595,56 +750,6 @@ async fn handle(
}; };
return Ok(success_response(status, response, revision)); return Ok(success_response(status, response, revision));
} }
if method == Method::POST
&& let Some(base_user) = user.strip_suffix("/rotate-secret")
&& !base_user.is_empty()
&& !base_user.contains('/')
{
if api_cfg.read_only {
return Ok(error_response(
request_id,
ApiFailure::new(
StatusCode::FORBIDDEN,
"read_only",
"API runs in read-only mode",
),
));
}
let expected_revision = parse_if_match(req.headers());
let body =
read_optional_json::<RotateSecretRequest>(req.into_body(), body_limit)
.await?;
let result = rotate_secret(
base_user,
body.unwrap_or_default(),
expected_revision,
&shared,
)
.await;
let (mut data, revision) = match result {
Ok(ok) => ok,
Err(error) => {
shared.runtime_events.record(
"api.user.rotate_secret.failed",
format!("username={} code={}", base_user, error.code),
);
return Err(error);
}
};
let runtime_cfg = config_rx.borrow().clone();
data.user.in_runtime =
runtime_cfg.access.users.contains_key(&data.user.username);
shared.runtime_events.record(
"api.user.rotate_secret.ok",
format!("username={}", base_user),
);
let status = if data.user.in_runtime {
StatusCode::OK
} else {
StatusCode::ACCEPTED
};
return Ok(success_response(status, data, revision));
}
if method == Method::POST { if method == Method::POST {
return Ok(error_response( return Ok(error_response(
request_id, request_id,
+7
View File
@@ -501,6 +501,13 @@ pub(super) struct DeleteUserResponse {
pub(super) in_runtime: bool, pub(super) in_runtime: bool,
} }
#[derive(Serialize)]
pub(super) struct ResetUserQuotaResponse {
pub(super) username: String,
pub(super) used_bytes: u64,
pub(super) last_reset_epoch_secs: u64,
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub(super) struct CreateUserRequest { pub(super) struct CreateUserRequest {
pub(super) username: String, pub(super) username: String,
+10 -6
View File
@@ -465,12 +465,7 @@ pub(super) async fn users_from_config(
.map(|secret| { .map(|secret| {
build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6) build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6)
}) })
.unwrap_or(UserLinks { .unwrap_or_else(empty_user_links);
classic: Vec::new(),
secure: Vec::new(),
tls: Vec::new(),
tls_domains: Vec::new(),
});
users.push(UserInfo { users.push(UserInfo {
in_runtime: runtime_cfg in_runtime: runtime_cfg
.map(|runtime| runtime.access.users.contains_key(&username)) .map(|runtime| runtime.access.users.contains_key(&username))
@@ -511,6 +506,15 @@ pub(super) async fn users_from_config(
users users
} }
fn empty_user_links() -> UserLinks {
UserLinks {
classic: Vec::new(),
secure: Vec::new(),
tls: Vec::new(),
tls_domains: Vec::new(),
}
}
fn build_user_links( fn build_user_links(
cfg: &ProxyConfig, cfg: &ProxyConfig,
secret: &str, secret: &str,
+681 -3
View File
@@ -22,6 +22,672 @@ const MAX_ME_ROUTE_CHANNEL_CAPACITY: usize = 8_192;
const MAX_ME_C2ME_CHANNEL_CAPACITY: usize = 8_192; const MAX_ME_C2ME_CHANNEL_CAPACITY: usize = 8_192;
const MIN_MAX_CLIENT_FRAME_BYTES: usize = 4 * 1024; const MIN_MAX_CLIENT_FRAME_BYTES: usize = 4 * 1024;
const MAX_MAX_CLIENT_FRAME_BYTES: usize = 16 * 1024 * 1024; const MAX_MAX_CLIENT_FRAME_BYTES: usize = 16 * 1024 * 1024;
const MAX_API_REQUEST_BODY_LIMIT_BYTES: usize = 1024 * 1024;
fn is_valid_tls_domain_name(domain: &str) -> bool {
!domain.is_empty()
&& !domain
.chars()
.any(|ch| ch.is_whitespace() || matches!(ch, '/' | '\\'))
}
const TOP_LEVEL_CONFIG_KEYS: &[&str] = &[
"general",
"network",
"server",
"timeouts",
"censorship",
"access",
"upstreams",
"show_link",
"dc_overrides",
"default_dc",
"beobachten",
"beobachten_minutes",
"beobachten_flush_secs",
"beobachten_file",
"include",
];
const GENERAL_CONFIG_KEYS: &[&str] = &[
"data_path",
"quota_state_path",
"config_strict",
"modes",
"prefer_ipv6",
"fast_mode",
"use_middle_proxy",
"proxy_secret_path",
"proxy_secret_url",
"proxy_config_v4_cache_path",
"proxy_config_v4_url",
"proxy_config_v6_cache_path",
"proxy_config_v6_url",
"ad_tag",
"middle_proxy_nat_ip",
"middle_proxy_nat_probe",
"middle_proxy_nat_stun",
"middle_proxy_nat_stun_servers",
"stun_nat_probe_concurrency",
"middle_proxy_pool_size",
"middle_proxy_warm_standby",
"me_init_retry_attempts",
"me2dc_fallback",
"me2dc_fast",
"me_keepalive_enabled",
"me_keepalive_interval_secs",
"me_keepalive_jitter_secs",
"me_keepalive_payload_random",
"rpc_proxy_req_every",
"me_writer_cmd_channel_capacity",
"me_route_channel_capacity",
"me_c2me_channel_capacity",
"me_c2me_send_timeout_ms",
"me_reader_route_data_wait_ms",
"me_d2c_flush_batch_max_frames",
"me_d2c_flush_batch_max_bytes",
"me_d2c_flush_batch_max_delay_us",
"me_d2c_ack_flush_immediate",
"me_quota_soft_overshoot_bytes",
"me_d2c_frame_buf_shrink_threshold_bytes",
"direct_relay_copy_buf_c2s_bytes",
"direct_relay_copy_buf_s2c_bytes",
"crypto_pending_buffer",
"max_client_frame",
"desync_all_full",
"beobachten",
"beobachten_minutes",
"beobachten_flush_secs",
"beobachten_file",
"hardswap",
"me_warmup_stagger_enabled",
"me_warmup_step_delay_ms",
"me_warmup_step_jitter_ms",
"me_reconnect_max_concurrent_per_dc",
"me_reconnect_backoff_base_ms",
"me_reconnect_backoff_cap_ms",
"me_reconnect_fast_retry_count",
"me_single_endpoint_shadow_writers",
"me_single_endpoint_outage_mode_enabled",
"me_single_endpoint_outage_disable_quarantine",
"me_single_endpoint_outage_backoff_min_ms",
"me_single_endpoint_outage_backoff_max_ms",
"me_single_endpoint_shadow_rotate_every_secs",
"me_floor_mode",
"me_adaptive_floor_idle_secs",
"me_adaptive_floor_min_writers_single_endpoint",
"me_adaptive_floor_min_writers_multi_endpoint",
"me_adaptive_floor_recover_grace_secs",
"me_adaptive_floor_writers_per_core_total",
"me_adaptive_floor_cpu_cores_override",
"me_adaptive_floor_max_extra_writers_single_per_core",
"me_adaptive_floor_max_extra_writers_multi_per_core",
"me_adaptive_floor_max_active_writers_per_core",
"me_adaptive_floor_max_warm_writers_per_core",
"me_adaptive_floor_max_active_writers_global",
"me_adaptive_floor_max_warm_writers_global",
"upstream_connect_retry_attempts",
"upstream_connect_retry_backoff_ms",
"upstream_connect_budget_ms",
"tg_connect",
"upstream_unhealthy_fail_threshold",
"upstream_connect_failfast_hard_errors",
"stun_iface_mismatch_ignore",
"unknown_dc_log_path",
"unknown_dc_file_log_enabled",
"log_level",
"disable_colors",
"telemetry",
"me_socks_kdf_policy",
"me_route_backpressure_enabled",
"me_route_fairshare_enabled",
"me_route_backpressure_base_timeout_ms",
"me_route_backpressure_high_timeout_ms",
"me_route_backpressure_high_watermark_pct",
"me_health_interval_ms_unhealthy",
"me_health_interval_ms_healthy",
"me_admission_poll_ms",
"me_warn_rate_limit_ms",
"me_route_no_writer_mode",
"me_route_no_writer_wait_ms",
"me_route_hybrid_max_wait_ms",
"me_route_blocking_send_timeout_ms",
"me_route_inline_recovery_attempts",
"me_route_inline_recovery_wait_ms",
"links",
"fast_mode_min_tls_record",
"update_every",
"me_reinit_every_secs",
"me_hardswap_warmup_delay_min_ms",
"me_hardswap_warmup_delay_max_ms",
"me_hardswap_warmup_extra_passes",
"me_hardswap_warmup_pass_backoff_base_ms",
"me_config_stable_snapshots",
"me_config_apply_cooldown_secs",
"me_snapshot_require_http_2xx",
"me_snapshot_reject_empty_map",
"me_snapshot_min_proxy_for_lines",
"proxy_secret_stable_snapshots",
"proxy_secret_rotate_runtime",
"me_secret_atomic_snapshot",
"proxy_secret_len_max",
"me_pool_drain_ttl_secs",
"me_instadrain",
"me_pool_drain_threshold",
"me_pool_drain_soft_evict_enabled",
"me_pool_drain_soft_evict_grace_secs",
"me_pool_drain_soft_evict_per_writer",
"me_pool_drain_soft_evict_budget_per_core",
"me_pool_drain_soft_evict_cooldown_ms",
"me_bind_stale_mode",
"me_bind_stale_ttl_secs",
"me_pool_min_fresh_ratio",
"me_reinit_drain_timeout_secs",
"proxy_secret_auto_reload_secs",
"proxy_config_auto_reload_secs",
"me_reinit_singleflight",
"me_reinit_trigger_channel",
"me_reinit_coalesce_window_ms",
"me_deterministic_writer_sort",
"me_writer_pick_mode",
"me_writer_pick_sample_size",
"ntp_check",
"ntp_servers",
"auto_degradation_enabled",
"degradation_min_unavailable_dc_groups",
"rst_on_close",
];
const NETWORK_CONFIG_KEYS: &[&str] = &[
"ipv4",
"ipv6",
"prefer",
"multipath",
"stun_use",
"stun_servers",
"stun_tcp_fallback",
"http_ip_detect_urls",
"cache_public_ip_path",
"dns_overrides",
];
const SERVER_CONFIG_KEYS: &[&str] = &[
"port",
"listen_addr_ipv4",
"listen_addr_ipv6",
"listen_unix_sock",
"listen_unix_sock_perm",
"listen_tcp",
"proxy_protocol",
"proxy_protocol_header_timeout_ms",
"proxy_protocol_trusted_cidrs",
"metrics_port",
"metrics_listen",
"metrics_whitelist",
"api",
"admin_api",
"listeners",
"listen_backlog",
"max_connections",
"accept_permit_timeout_ms",
"conntrack_control",
];
const API_CONFIG_KEYS: &[&str] = &[
"enabled",
"listen",
"whitelist",
"gray_action",
"auth_header",
"request_body_limit_bytes",
"minimal_runtime_enabled",
"minimal_runtime_cache_ttl_ms",
"runtime_edge_enabled",
"runtime_edge_cache_ttl_ms",
"runtime_edge_top_n",
"runtime_edge_events_capacity",
"read_only",
];
const CONNTRACK_CONTROL_CONFIG_KEYS: &[&str] = &[
"inline_conntrack_control",
"mode",
"backend",
"profile",
"hybrid_listener_ips",
"pressure_high_watermark_pct",
"pressure_low_watermark_pct",
"delete_budget_per_sec",
];
const LISTENER_CONFIG_KEYS: &[&str] = &[
"ip",
"port",
"announce",
"announce_ip",
"proxy_protocol",
"reuse_allow",
];
const TIMEOUTS_CONFIG_KEYS: &[&str] = &[
"client_first_byte_idle_secs",
"client_handshake",
"relay_idle_policy_v2_enabled",
"relay_client_idle_soft_secs",
"relay_client_idle_hard_secs",
"relay_idle_grace_after_downstream_activity_secs",
"client_keepalive",
"client_ack",
"me_one_retry",
"me_one_timeout_ms",
];
const CENSORSHIP_CONFIG_KEYS: &[&str] = &[
"tls_domain",
"tls_domains",
"unknown_sni_action",
"tls_fetch_scope",
"tls_fetch",
"mask",
"mask_host",
"mask_port",
"mask_unix_sock",
"fake_cert_len",
"tls_emulation",
"tls_front_dir",
"server_hello_delay_min_ms",
"server_hello_delay_max_ms",
"tls_new_session_tickets",
"serverhello_compact",
"tls_full_cert_ttl_secs",
"alpn_enforce",
"mask_proxy_protocol",
"mask_shape_hardening",
"mask_shape_hardening_aggressive_mode",
"mask_shape_bucket_floor_bytes",
"mask_shape_bucket_cap_bytes",
"mask_shape_above_cap_blur",
"mask_shape_above_cap_blur_max_bytes",
"mask_relay_max_bytes",
"mask_relay_timeout_ms",
"mask_relay_idle_timeout_ms",
"mask_classifier_prefetch_timeout_ms",
"mask_timing_normalization_enabled",
"mask_timing_normalization_floor_ms",
"mask_timing_normalization_ceiling_ms",
];
const TLS_FETCH_CONFIG_KEYS: &[&str] = &[
"profiles",
"strict_route",
"attempt_timeout_ms",
"total_budget_ms",
"grease_enabled",
"deterministic",
"profile_cache_ttl_secs",
];
const ACCESS_CONFIG_KEYS: &[&str] = &[
"users",
"user_ad_tags",
"user_max_tcp_conns",
"user_max_tcp_conns_global_each",
"user_expirations",
"user_data_quota",
"user_rate_limits",
"cidr_rate_limits",
"user_max_unique_ips",
"user_max_unique_ips_global_each",
"user_max_unique_ips_mode",
"user_max_unique_ips_window_secs",
"replay_check_len",
"replay_window_secs",
"ignore_time_skew",
];
const RATE_LIMIT_BPS_CONFIG_KEYS: &[&str] = &["up_bps", "down_bps"];
const UPSTREAM_CONFIG_KEYS: &[&str] = &[
"type",
"interface",
"bind_addresses",
"bindtodevice",
"force_bind",
"address",
"user_id",
"username",
"password",
"url",
"weight",
"enabled",
"scopes",
"ipv4",
"ipv6",
];
const PROXY_MODES_CONFIG_KEYS: &[&str] = &["classic", "secure", "tls"];
const TELEMETRY_CONFIG_KEYS: &[&str] = &["core_enabled", "user_enabled", "me_level"];
const LINKS_CONFIG_KEYS: &[&str] = &["show", "public_host", "public_port"];
#[derive(Debug)]
struct UnknownConfigKey {
path: String,
suggestion: Option<String>,
}
fn table_at<'a>(value: &'a toml::Value, path: &[&str]) -> Option<&'a toml::Table> {
let mut current = value;
for segment in path {
current = current.get(*segment)?;
}
current.as_table()
}
fn is_strict_config(parsed_toml: &toml::Value) -> bool {
table_at(parsed_toml, &["general"])
.and_then(|table| table.get("config_strict"))
.and_then(toml::Value::as_bool)
.unwrap_or(false)
}
fn known_config_keys_for_suggestion() -> Vec<&'static str> {
let mut keys = Vec::new();
for group in [
TOP_LEVEL_CONFIG_KEYS,
GENERAL_CONFIG_KEYS,
NETWORK_CONFIG_KEYS,
SERVER_CONFIG_KEYS,
API_CONFIG_KEYS,
CONNTRACK_CONTROL_CONFIG_KEYS,
LISTENER_CONFIG_KEYS,
TIMEOUTS_CONFIG_KEYS,
CENSORSHIP_CONFIG_KEYS,
TLS_FETCH_CONFIG_KEYS,
ACCESS_CONFIG_KEYS,
RATE_LIMIT_BPS_CONFIG_KEYS,
UPSTREAM_CONFIG_KEYS,
PROXY_MODES_CONFIG_KEYS,
TELEMETRY_CONFIG_KEYS,
LINKS_CONFIG_KEYS,
] {
keys.extend_from_slice(group);
}
keys
}
fn levenshtein_distance(a: &str, b: &str) -> usize {
let b_chars: Vec<char> = b.chars().collect();
let mut prev: Vec<usize> = (0..=b_chars.len()).collect();
let mut curr = vec![0usize; b_chars.len() + 1];
for (i, ca) in a.chars().enumerate() {
curr[0] = i + 1;
for (j, cb) in b_chars.iter().enumerate() {
let replace = if ca == *cb { prev[j] } else { prev[j] + 1 };
curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(replace);
}
std::mem::swap(&mut prev, &mut curr);
}
prev[b_chars.len()]
}
fn unknown_key_suggestion(key: &str, known_keys: &[&'static str]) -> Option<String> {
let normalized = key.to_ascii_lowercase();
let mut best: Option<(&str, usize)> = None;
for known in known_keys {
let distance = levenshtein_distance(&normalized, known);
let is_better = match best {
Some((_, best_distance)) => distance < best_distance,
None => true,
};
if distance <= 4 && is_better {
best = Some((known, distance));
}
}
best.map(|(known, _)| known.to_string())
}
fn push_unknown_keys(
unknown: &mut Vec<UnknownConfigKey>,
known_for_suggestion: &[&'static str],
path: &str,
table: &toml::Table,
allowed: &[&str],
) {
for key in table.keys() {
if !allowed.contains(&key.as_str()) {
let full_path = if path.is_empty() {
key.clone()
} else {
format!("{path}.{key}")
};
unknown.push(UnknownConfigKey {
path: full_path,
suggestion: unknown_key_suggestion(key, known_for_suggestion),
});
}
}
}
fn check_known_table(
parsed_toml: &toml::Value,
unknown: &mut Vec<UnknownConfigKey>,
known_for_suggestion: &[&'static str],
path: &[&str],
allowed: &[&str],
) {
if let Some(table) = table_at(parsed_toml, path) {
push_unknown_keys(
unknown,
known_for_suggestion,
&path.join("."),
table,
allowed,
);
}
}
fn check_nested_table_value(
unknown: &mut Vec<UnknownConfigKey>,
known_for_suggestion: &[&'static str],
path: String,
value: &toml::Value,
allowed: &[&str],
) {
if let Some(table) = value.as_table() {
push_unknown_keys(unknown, known_for_suggestion, &path, table, allowed);
}
}
fn collect_unknown_config_keys(parsed_toml: &toml::Value) -> Vec<UnknownConfigKey> {
let known_for_suggestion = known_config_keys_for_suggestion();
let mut unknown = Vec::new();
if let Some(root) = parsed_toml.as_table() {
push_unknown_keys(
&mut unknown,
&known_for_suggestion,
"",
root,
TOP_LEVEL_CONFIG_KEYS,
);
}
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["general"],
GENERAL_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["general", "modes"],
PROXY_MODES_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["general", "telemetry"],
TELEMETRY_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["general", "links"],
LINKS_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["network"],
NETWORK_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["server"],
SERVER_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["server", "api"],
API_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["server", "admin_api"],
API_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["server", "conntrack_control"],
CONNTRACK_CONTROL_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["timeouts"],
TIMEOUTS_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["censorship"],
CENSORSHIP_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["censorship", "tls_fetch"],
TLS_FETCH_CONFIG_KEYS,
);
check_known_table(
parsed_toml,
&mut unknown,
&known_for_suggestion,
&["access"],
ACCESS_CONFIG_KEYS,
);
if let Some(listeners) = table_at(parsed_toml, &["server"])
.and_then(|table| table.get("listeners"))
.and_then(toml::Value::as_array)
{
for (idx, listener) in listeners.iter().enumerate() {
check_nested_table_value(
&mut unknown,
&known_for_suggestion,
format!("server.listeners[{idx}]"),
listener,
LISTENER_CONFIG_KEYS,
);
}
}
if let Some(upstreams) = parsed_toml.get("upstreams").and_then(toml::Value::as_array) {
for (idx, upstream) in upstreams.iter().enumerate() {
check_nested_table_value(
&mut unknown,
&known_for_suggestion,
format!("upstreams[{idx}]"),
upstream,
UPSTREAM_CONFIG_KEYS,
);
}
}
for access_map in ["user_rate_limits", "cidr_rate_limits"] {
if let Some(table) = table_at(parsed_toml, &["access"])
.and_then(|access| access.get(access_map))
.and_then(toml::Value::as_table)
{
for (entry_name, value) in table {
check_nested_table_value(
&mut unknown,
&known_for_suggestion,
format!("access.{access_map}.{entry_name}"),
value,
RATE_LIMIT_BPS_CONFIG_KEYS,
);
}
}
}
unknown
}
fn handle_unknown_config_keys(parsed_toml: &toml::Value) -> Result<()> {
let unknown = collect_unknown_config_keys(parsed_toml);
if unknown.is_empty() {
return Ok(());
}
for item in &unknown {
if let Some(suggestion) = item.suggestion.as_deref() {
warn!(
key = %item.path,
suggestion = %suggestion,
"Unknown config key ignored; did you mean the suggested key?"
);
} else {
warn!(key = %item.path, "Unknown config key ignored");
}
}
if is_strict_config(parsed_toml) {
let mut paths = Vec::with_capacity(unknown.len());
for item in unknown {
if let Some(suggestion) = item.suggestion {
paths.push(format!("{} (did you mean `{}`?)", item.path, suggestion));
} else {
paths.push(item.path);
}
}
return Err(ProxyError::Config(format!(
"unknown config keys are not allowed when general.config_strict=true: {}",
paths.join(", ")
)));
}
Ok(())
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct LoadedConfig { pub(crate) struct LoadedConfig {
@@ -337,6 +1003,7 @@ impl ProxyConfig {
let parsed_toml: toml::Value = let parsed_toml: toml::Value =
toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?; toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
handle_unknown_config_keys(&parsed_toml)?;
let general_table = parsed_toml let general_table = parsed_toml
.get("general") .get("general")
.and_then(|value| value.as_table()); .and_then(|value| value.as_table());
@@ -1111,9 +1778,11 @@ impl ProxyConfig {
)); ));
} }
if config.server.api.request_body_limit_bytes == 0 { if !(1..=MAX_API_REQUEST_BODY_LIMIT_BYTES)
.contains(&config.server.api.request_body_limit_bytes)
{
return Err(ProxyError::Config( return Err(ProxyError::Config(
"server.api.request_body_limit_bytes must be > 0".to_string(), "server.api.request_body_limit_bytes must be within [1, 1048576]".to_string(),
)); ));
} }
@@ -1441,13 +2110,22 @@ impl ProxyConfig {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { if !is_valid_tls_domain_name(&self.censorship.tls_domain) {
return Err(ProxyError::Config(format!( return Err(ProxyError::Config(format!(
"Invalid tls_domain: '{}'. Must be a valid domain name", "Invalid tls_domain: '{}'. Must be a valid domain name",
self.censorship.tls_domain self.censorship.tls_domain
))); )));
} }
for domain in &self.censorship.tls_domains {
if !is_valid_tls_domain_name(domain) {
return Err(ProxyError::Config(format!(
"Invalid tls_domains entry: '{}'. Must be a valid domain name",
domain
)));
}
}
for (user, tag) in &self.access.user_ad_tags { for (user, tag) in &self.access.user_ad_tags {
let zeros = "00000000000000000000000000000000"; let zeros = "00000000000000000000000000000000";
if !is_valid_ad_tag(tag) { if !is_valid_ad_tag(tag) {
+43 -2
View File
@@ -26,6 +26,10 @@ pub enum LogLevel {
Silent, Silent,
} }
fn default_quota_state_path() -> PathBuf {
PathBuf::from("telemt.limit.json")
}
impl LogLevel { impl LogLevel {
/// Convert to tracing EnvFilter directive string. /// Convert to tracing EnvFilter directive string.
pub fn to_filter_str(&self) -> &'static str { pub fn to_filter_str(&self) -> &'static str {
@@ -375,6 +379,15 @@ pub struct GeneralConfig {
#[serde(default)] #[serde(default)]
pub data_path: Option<PathBuf>, pub data_path: Option<PathBuf>,
/// JSON state file for runtime per-user quota consumption.
#[serde(default = "default_quota_state_path")]
pub quota_state_path: PathBuf,
/// Reject unknown TOML config keys during load.
/// Startup fails fast; hot-reload rejects the new snapshot and keeps the current config.
#[serde(default)]
pub config_strict: bool,
#[serde(default)] #[serde(default)]
pub modes: ProxyModes, pub modes: ProxyModes,
@@ -530,10 +543,17 @@ pub struct GeneralConfig {
pub me_d2c_frame_buf_shrink_threshold_bytes: usize, pub me_d2c_frame_buf_shrink_threshold_bytes: usize,
/// Copy buffer size for client->DC direction in direct relay. /// Copy buffer size for client->DC direction in direct relay.
///
/// This is also the upper bound for one amortized upload rate-limit burst:
/// upload debt is settled before the next relay read instead of blocking
/// inside the completed read path.
#[serde(default = "default_direct_relay_copy_buf_c2s_bytes")] #[serde(default = "default_direct_relay_copy_buf_c2s_bytes")]
pub direct_relay_copy_buf_c2s_bytes: usize, pub direct_relay_copy_buf_c2s_bytes: usize,
/// Copy buffer size for DC->client direction in direct relay. /// Copy buffer size for DC->client direction in direct relay.
///
/// This bounds one direct download rate-limit grant because writes are
/// clipped to the currently available shaper budget.
#[serde(default = "default_direct_relay_copy_buf_s2c_bytes")] #[serde(default = "default_direct_relay_copy_buf_s2c_bytes")]
pub direct_relay_copy_buf_s2c_bytes: usize, pub direct_relay_copy_buf_s2c_bytes: usize,
@@ -974,6 +994,8 @@ impl Default for GeneralConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
data_path: None, data_path: None,
quota_state_path: default_quota_state_path(),
config_strict: false,
modes: ProxyModes::default(), modes: ProxyModes::default(),
prefer_ipv6: false, prefer_ipv6: false,
fast_mode: default_true(), fast_mode: default_true(),
@@ -1876,17 +1898,26 @@ pub struct AccessConfig {
/// ///
/// Each entry supports independent upload (`up_bps`) and download /// Each entry supports independent upload (`up_bps`) and download
/// (`down_bps`) ceilings. A value of `0` in one direction means /// (`down_bps`) ceilings. A value of `0` in one direction means
/// "unlimited" for that direction. /// "unlimited" for that direction. Limits are amortized: a relay quantum
/// may pass as a bounded burst, and the limiter applies the resulting wait
/// before later traffic in the same direction proceeds.
#[serde(default)] #[serde(default)]
pub user_rate_limits: HashMap<String, RateLimitBps>, pub user_rate_limits: HashMap<String, RateLimitBps>,
/// Per-CIDR aggregate transport rate limits in bits-per-second. /// Per-CIDR aggregate transport rate limits in bits-per-second.
/// ///
/// Matching uses longest-prefix-wins semantics. A value of `0` in one /// Matching uses longest-prefix-wins semantics. A value of `0` in one
/// direction means "unlimited" for that direction. /// direction means "unlimited" for that direction. Limits are amortized
/// with the same bounded-burst contract as per-user rate limits.
#[serde(default)] #[serde(default)]
pub cidr_rate_limits: HashMap<IpNetwork, RateLimitBps>, pub cidr_rate_limits: HashMap<IpNetwork, RateLimitBps>,
/// Per-username client source IP/CIDR deny list. Checked after successful
/// authentication; matching IPs get the same rejection path as invalid auth
/// (handshake fails closed for that connection).
#[serde(default)]
pub user_source_deny: HashMap<String, Vec<IpNetwork>>,
#[serde(default)] #[serde(default)]
pub user_max_unique_ips: HashMap<String, usize>, pub user_max_unique_ips: HashMap<String, usize>,
@@ -1922,6 +1953,7 @@ impl Default for AccessConfig {
user_data_quota: HashMap::new(), user_data_quota: HashMap::new(),
user_rate_limits: HashMap::new(), user_rate_limits: HashMap::new(),
cidr_rate_limits: HashMap::new(), cidr_rate_limits: HashMap::new(),
user_source_deny: HashMap::new(),
user_max_unique_ips: HashMap::new(), user_max_unique_ips: HashMap::new(),
user_max_unique_ips_global_each: default_user_max_unique_ips_global_each(), user_max_unique_ips_global_each: default_user_max_unique_ips_global_each(),
user_max_unique_ips_mode: UserMaxUniqueIpsMode::default(), user_max_unique_ips_mode: UserMaxUniqueIpsMode::default(),
@@ -1933,6 +1965,15 @@ impl Default for AccessConfig {
} }
} }
impl AccessConfig {
/// Returns true if `ip` is contained in any CIDR listed for `username` under `user_source_deny`.
pub fn is_user_source_ip_denied(&self, username: &str, ip: IpAddr) -> bool {
self.user_source_deny
.get(username)
.is_some_and(|nets| nets.iter().any(|n| n.contains(ip)))
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RateLimitBps { pub struct RateLimitBps {
#[serde(default)] #[serde(default)]
+15
View File
@@ -222,6 +222,21 @@ pub enum ProxyError {
#[error("Proxy error: {0}")] #[error("Proxy error: {0}")]
Proxy(String), Proxy(String),
#[error("ME connection lost")]
MiddleConnectionLost,
#[error("Session terminated")]
RouteSwitched,
#[error("Traffic budget wait cancelled")]
TrafficBudgetWaitCancelled,
#[error("Traffic budget wait deadline exceeded")]
TrafficBudgetWaitDeadlineExceeded,
#[error("ME client writer cancelled")]
MiddleClientWriterCancelled,
// ============= Config Errors ============= // ============= Config Errors =============
#[error("Config error: {0}")] #[error("Config error: {0}")]
Config(String), Config(String),
+23 -5
View File
@@ -32,6 +32,7 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>, limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>, limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>, last_compact_epoch_secs: Arc<AtomicU64>,
cleanup_queue_len: Arc<AtomicU64>,
cleanup_queue: Arc<Mutex<HashMap<(String, IpAddr), usize>>>, cleanup_queue: Arc<Mutex<HashMap<(String, IpAddr), usize>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>, cleanup_drain_lock: Arc<AsyncMutex<()>>,
} }
@@ -72,6 +73,7 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue_len: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(HashMap::new())), cleanup_queue: Arc::new(Mutex::new(HashMap::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())), cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
} }
@@ -120,6 +122,9 @@ impl UserIpTracker {
match self.cleanup_queue.lock() { match self.cleanup_queue.lock() {
Ok(mut queue) => { Ok(mut queue) => {
let count = queue.entry((user, ip)).or_insert(0); let count = queue.entry((user, ip)).or_insert(0);
if *count == 0 {
self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed);
}
*count = count.saturating_add(1); *count = count.saturating_add(1);
self.cleanup_deferred_releases self.cleanup_deferred_releases
.fetch_add(1, Ordering::Relaxed); .fetch_add(1, Ordering::Relaxed);
@@ -127,6 +132,9 @@ impl UserIpTracker {
Err(poisoned) => { Err(poisoned) => {
let mut queue = poisoned.into_inner(); let mut queue = poisoned.into_inner();
let count = queue.entry((user.clone(), ip)).or_insert(0); let count = queue.entry((user.clone(), ip)).or_insert(0);
if *count == 0 {
self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed);
}
*count = count.saturating_add(1); *count = count.saturating_add(1);
self.cleanup_deferred_releases self.cleanup_deferred_releases
.fetch_add(1, Ordering::Relaxed); .fetch_add(1, Ordering::Relaxed);
@@ -156,6 +164,9 @@ impl UserIpTracker {
} }
pub(crate) async fn drain_cleanup_queue(&self) { pub(crate) async fn drain_cleanup_queue(&self) {
if self.cleanup_queue_len.load(Ordering::Relaxed) == 0 {
return;
}
let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else { let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else {
return; return;
}; };
@@ -173,6 +184,7 @@ impl UserIpTracker {
break; break;
}; };
if let Some(count) = queue.remove(&key) { if let Some(count) = queue.remove(&key) {
self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed);
drained.insert(key, count); drained.insert(key, count);
} }
} }
@@ -191,6 +203,7 @@ impl UserIpTracker {
break; break;
}; };
if let Some(count) = queue.remove(&key) { if let Some(count) = queue.remove(&key) {
self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed);
drained.insert(key, count); drained.insert(key, count);
} }
} }
@@ -294,12 +307,17 @@ impl UserIpTracker {
} }
} }
pub async fn run_periodic_maintenance(self: Arc<Self>) {
let mut interval = tokio::time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
self.drain_cleanup_queue().await;
self.maybe_compact_empty_users().await;
}
}
pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats { pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats {
let cleanup_queue_len = self let cleanup_queue_len = self.cleanup_queue_len.load(Ordering::Relaxed) as usize;
.cleanup_queue
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.len();
let active_ips = self.active_ips.read().await; let active_ips = self.active_ips.read().await;
let recent_ips = self.recent_ips.read().await; let recent_ips = self.recent_ips.read().await;
let active_entries = active_ips.values().map(HashMap::len).sum(); let active_entries = active_ips.values().map(HashMap::len).sum();
+7
View File
@@ -17,6 +17,7 @@ pub(crate) async fn configure_admission_gate(
route_runtime: Arc<RouteRuntimeController>, route_runtime: Arc<RouteRuntimeController>,
admission_tx: &watch::Sender<bool>, admission_tx: &watch::Sender<bool>,
config_rx: watch::Receiver<Arc<ProxyConfig>>, config_rx: watch::Receiver<Arc<ProxyConfig>>,
me_ready_rx: watch::Receiver<u64>,
) { ) {
if config.general.use_middle_proxy { if config.general.use_middle_proxy {
if let Some(pool) = me_pool.as_ref() { if let Some(pool) = me_pool.as_ref() {
@@ -52,6 +53,7 @@ pub(crate) async fn configure_admission_gate(
let admission_tx_gate = admission_tx.clone(); let admission_tx_gate = admission_tx.clone();
let route_runtime_gate = route_runtime.clone(); let route_runtime_gate = route_runtime.clone();
let mut config_rx_gate = config_rx.clone(); let mut config_rx_gate = config_rx.clone();
let mut me_ready_rx_gate = me_ready_rx;
let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1); let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1);
tokio::spawn(async move { tokio::spawn(async move {
let mut gate_open = initial_gate_open; let mut gate_open = initial_gate_open;
@@ -74,6 +76,11 @@ pub(crate) async fn configure_admission_gate(
fast_fallback_enabled = cfg.general.me2dc_fallback && cfg.general.me2dc_fast; fast_fallback_enabled = cfg.general.me2dc_fallback && cfg.general.me2dc_fast;
continue; continue;
} }
changed = me_ready_rx_gate.changed() => {
if changed.is_err() {
break;
}
}
_ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {} _ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {}
} }
let ready = pool_for_gate.admission_ready_conditional_cast().await; let ready = pool_for_gate.admission_ready_conditional_cast().await;
+5 -9
View File
@@ -13,7 +13,7 @@ use crate::config::{ProxyConfig, RstOnCloseMode};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::proxy::route_mode::RouteRuntimeController;
use crate::proxy::shared_state::ProxySharedState; use crate::proxy::shared_state::ProxySharedState;
use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
@@ -492,14 +492,10 @@ pub(crate) fn spawn_tcp_accept_loops(
let handshake_close_reason = let handshake_close_reason =
expected_handshake_close_description(&e); expected_handshake_close_description(&e);
let me_closed = matches!( let me_closed =
&e, matches!(&e, crate::error::ProxyError::MiddleConnectionLost);
crate::error::ProxyError::Proxy(msg) if msg == "ME connection lost" let route_switched =
); matches!(&e, crate::error::ProxyError::RouteSwitched);
let route_switched = matches!(
&e,
crate::error::ProxyError::Proxy(msg) if msg == ROUTE_SWITCH_ERROR_MSG
);
match (peer_close_reason, me_closed) { match (peer_close_reason, me_closed) {
(Some(reason), _) => { (Some(reason), _) => {
+9 -1
View File
@@ -3,7 +3,7 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::RwLock; use tokio::sync::{RwLock, watch};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@@ -29,6 +29,7 @@ pub(crate) async fn initialize_me_pool(
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
stats: Arc<Stats>, stats: Arc<Stats>,
api_me_pool: Arc<RwLock<Option<Arc<MePool>>>>, api_me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
me_ready_tx: watch::Sender<u64>,
) -> Option<Arc<MePool>> { ) -> Option<Arc<MePool>> {
if !use_middle_proxy { if !use_middle_proxy {
return None; return None;
@@ -314,6 +315,7 @@ pub(crate) async fn initialize_me_pool(
let pool_bg = pool.clone(); let pool_bg = pool.clone();
let rng_bg = rng.clone(); let rng_bg = rng.clone();
let startup_tracker_bg = startup_tracker.clone(); let startup_tracker_bg = startup_tracker.clone();
let me_ready_tx_bg = me_ready_tx.clone();
let retry_limit = if me_init_retry_attempts == 0 { let retry_limit = if me_init_retry_attempts == 0 {
String::from("unlimited") String::from("unlimited")
} else { } else {
@@ -347,6 +349,9 @@ pub(crate) async fn initialize_me_pool(
startup_tracker_bg startup_tracker_bg
.set_me_status(StartupMeStatus::Ready, "ready") .set_me_status(StartupMeStatus::Ready, "ready")
.await; .await;
me_ready_tx_bg.send_modify(|version| {
*version = version.saturating_add(1);
});
info!( info!(
attempt = init_attempt, attempt = init_attempt,
"Middle-End pool initialized successfully" "Middle-End pool initialized successfully"
@@ -474,6 +479,9 @@ pub(crate) async fn initialize_me_pool(
startup_tracker startup_tracker
.set_me_status(StartupMeStatus::Ready, "ready") .set_me_status(StartupMeStatus::Ready, "ready")
.await; .await;
me_ready_tx.send_modify(|version| {
*version = version.saturating_add(1);
});
info!( info!(
attempt = init_attempt, attempt = init_attempt,
"Middle-End pool initialized successfully" "Middle-End pool initialized successfully"
+11 -1
View File
@@ -417,6 +417,8 @@ async fn run_telemt_core(
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
stats.apply_telemetry_policy(TelemetryPolicy::from_config(&config.general.telemetry)); stats.apply_telemetry_policy(TelemetryPolicy::from_config(&config.general.telemetry));
let quota_state_path = config.general.quota_state_path.clone();
crate::quota_state::load_quota_state(&quota_state_path, stats.as_ref()).await;
let upstream_manager = Arc::new(UpstreamManager::new( let upstream_manager = Arc::new(UpstreamManager::new(
config.upstreams.clone(), config.upstreams.clone(),
@@ -496,6 +498,7 @@ async fn run_telemt_core(
let config_rx_api = api_config_rx.clone(); let config_rx_api = api_config_rx.clone();
let admission_rx_api = admission_rx.clone(); let admission_rx_api = admission_rx.clone();
let config_path_api = config_path.clone(); let config_path_api = config_path.clone();
let quota_state_path_api = quota_state_path.clone();
let startup_tracker_api = startup_tracker.clone(); let startup_tracker_api = startup_tracker.clone();
let detected_ips_rx_api = detected_ips_rx.clone(); let detected_ips_rx_api = detected_ips_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -509,6 +512,7 @@ async fn run_telemt_core(
config_rx_api, config_rx_api,
admission_rx_api, admission_rx_api,
config_path_api, config_path_api,
quota_state_path_api,
detected_ips_rx_api, detected_ips_rx_api,
process_started_at_epoch_secs, process_started_at_epoch_secs,
startup_tracker_api, startup_tracker_api,
@@ -660,6 +664,8 @@ async fn run_telemt_core(
.await; .await;
} }
let (me_ready_tx, me_ready_rx) = watch::channel(0_u64);
let me_pool: Option<Arc<MePool>> = me_startup::initialize_me_pool( let me_pool: Option<Arc<MePool>> = me_startup::initialize_me_pool(
use_middle_proxy, use_middle_proxy,
&config, &config,
@@ -670,6 +676,7 @@ async fn run_telemt_core(
rng.clone(), rng.clone(),
stats.clone(), stats.clone(),
api_me_pool.clone(), api_me_pool.clone(),
me_ready_tx.clone(),
) )
.await; .await;
@@ -743,6 +750,7 @@ async fn run_telemt_core(
api_config_tx.clone(), api_config_tx.clone(),
me_pool.clone(), me_pool.clone(),
shared_state.clone(), shared_state.clone(),
me_ready_tx.clone(),
) )
.await; .await;
let config_rx = runtime_watches.config_rx; let config_rx = runtime_watches.config_rx;
@@ -756,6 +764,7 @@ async fn run_telemt_core(
route_runtime.clone(), route_runtime.clone(),
&admission_tx, &admission_tx,
config_rx.clone(), config_rx.clone(),
me_ready_rx,
) )
.await; .await;
let _admission_tx_hold = admission_tx; let _admission_tx_hold = admission_tx;
@@ -814,6 +823,7 @@ async fn run_telemt_core(
beobachten.clone(), beobachten.clone(),
shared_state.clone(), shared_state.clone(),
ip_tracker.clone(), ip_tracker.clone(),
tls_cache.clone(),
config_rx.clone(), config_rx.clone(),
) )
.await; .await;
@@ -841,7 +851,7 @@ async fn run_telemt_core(
max_connections.clone(), max_connections.clone(),
); );
shutdown::wait_for_shutdown(process_started_at, me_pool, stats).await; shutdown::wait_for_shutdown(process_started_at, me_pool, stats, quota_state_path).await;
Ok(()) Ok(())
} }
+19
View File
@@ -21,6 +21,7 @@ use crate::startup::{
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::tls_front::TlsFrontCache;
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
@@ -52,6 +53,7 @@ pub(crate) async fn spawn_runtime_tasks(
api_config_tx: watch::Sender<Arc<ProxyConfig>>, api_config_tx: watch::Sender<Arc<ProxyConfig>>,
me_pool_for_policy: Option<Arc<MePool>>, me_pool_for_policy: Option<Arc<MePool>>,
shared_state: Arc<ProxySharedState>, shared_state: Arc<ProxySharedState>,
me_ready_tx: watch::Sender<u64>,
) -> RuntimeWatches { ) -> RuntimeWatches {
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
let dc_overrides_for_health = config.dc_overrides.clone(); let dc_overrides_for_health = config.dc_overrides.clone();
@@ -71,6 +73,18 @@ pub(crate) async fn spawn_runtime_tasks(
rc_clone.run_periodic_cleanup().await; rc_clone.run_periodic_cleanup().await;
}); });
let stats_maintenance = stats.clone();
tokio::spawn(async move {
stats_maintenance
.run_periodic_user_stats_maintenance()
.await;
});
let ip_tracker_maintenance = ip_tracker.clone();
tokio::spawn(async move {
ip_tracker_maintenance.run_periodic_maintenance().await;
});
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6); let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
debug!( debug!(
@@ -249,12 +263,14 @@ pub(crate) async fn spawn_runtime_tasks(
let pool_clone_sched = pool.clone(); let pool_clone_sched = pool.clone();
let rng_clone_sched = rng.clone(); let rng_clone_sched = rng.clone();
let config_rx_clone_sched = config_rx.clone(); let config_rx_clone_sched = config_rx.clone();
let me_ready_tx_sched = me_ready_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
crate::transport::middle_proxy::me_reinit_scheduler( crate::transport::middle_proxy::me_reinit_scheduler(
pool_clone_sched, pool_clone_sched,
rng_clone_sched, rng_clone_sched,
config_rx_clone_sched, config_rx_clone_sched,
reinit_rx, reinit_rx,
me_ready_tx_sched,
) )
.await; .await;
}); });
@@ -328,6 +344,7 @@ pub(crate) async fn spawn_metrics_if_configured(
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
shared_state: Arc<ProxySharedState>, shared_state: Arc<ProxySharedState>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
tls_cache: Option<Arc<TlsFrontCache>>,
config_rx: watch::Receiver<Arc<ProxyConfig>>, config_rx: watch::Receiver<Arc<ProxyConfig>>,
) { ) {
// metrics_listen takes precedence; fall back to metrics_port for backward compat. // metrics_listen takes precedence; fall back to metrics_port for backward compat.
@@ -363,6 +380,7 @@ pub(crate) async fn spawn_metrics_if_configured(
let shared_state = shared_state.clone(); let shared_state = shared_state.clone();
let config_rx_metrics = config_rx.clone(); let config_rx_metrics = config_rx.clone();
let ip_tracker_metrics = ip_tracker.clone(); let ip_tracker_metrics = ip_tracker.clone();
let tls_cache_metrics = tls_cache.clone();
let whitelist = config.server.metrics_whitelist.clone(); let whitelist = config.server.metrics_whitelist.clone();
let listen_backlog = config.server.listen_backlog; let listen_backlog = config.server.listen_backlog;
tokio::spawn(async move { tokio::spawn(async move {
@@ -374,6 +392,7 @@ pub(crate) async fn spawn_metrics_if_configured(
beobachten, beobachten,
shared_state, shared_state,
ip_tracker_metrics, ip_tracker_metrics,
tls_cache_metrics,
config_rx_metrics, config_rx_metrics,
whitelist, whitelist,
) )
+27 -1
View File
@@ -8,6 +8,7 @@
//! //!
//! SIGHUP is handled separately in config/hot_reload.rs for config reload. //! SIGHUP is handled separately in config/hot_reload.rs for config reload.
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@@ -48,9 +49,17 @@ pub(crate) async fn wait_for_shutdown(
process_started_at: Instant, process_started_at: Instant,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
stats: Arc<Stats>, stats: Arc<Stats>,
quota_state_path: PathBuf,
) { ) {
let signal = wait_for_shutdown_signal().await; let signal = wait_for_shutdown_signal().await;
perform_shutdown(signal, process_started_at, me_pool, &stats).await; perform_shutdown(
signal,
process_started_at,
me_pool,
&stats,
quota_state_path,
)
.await;
} }
/// Waits for any shutdown signal (SIGINT, SIGTERM, SIGQUIT). /// Waits for any shutdown signal (SIGINT, SIGTERM, SIGQUIT).
@@ -79,6 +88,7 @@ async fn perform_shutdown(
process_started_at: Instant, process_started_at: Instant,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
stats: &Stats, stats: &Stats,
quota_state_path: PathBuf,
) { ) {
let shutdown_started_at = Instant::now(); let shutdown_started_at = Instant::now();
info!(signal = %signal, "Received shutdown signal"); info!(signal = %signal, "Received shutdown signal");
@@ -109,6 +119,22 @@ async fn perform_shutdown(
} }
} }
match crate::quota_state::save_quota_state(&quota_state_path, stats).await {
Ok(()) => {
info!(
path = %quota_state_path.display(),
"Persisted per-user quota state"
);
}
Err(error) => {
warn!(
error = %error,
path = %quota_state_path.display(),
"Failed to persist per-user quota state"
);
}
}
let shutdown_secs = shutdown_started_at.elapsed().as_secs(); let shutdown_secs = shutdown_started_at.elapsed().as_secs();
info!( info!(
"Shutdown completed successfully in {} {}.", "Shutdown completed successfully in {} {}.",
+1
View File
@@ -25,6 +25,7 @@ mod metrics;
mod network; mod network;
mod protocol; mod protocol;
mod proxy; mod proxy;
mod quota_state;
mod service; mod service;
mod startup; mod startup;
mod stats; mod stats;
+479 -12
View File
@@ -11,6 +11,8 @@ use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@@ -18,12 +20,17 @@ use crate::ip_tracker::UserIpTracker;
use crate::proxy::shared_state::ProxySharedState; use crate::proxy::shared_state::ProxySharedState;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::tls_front::TlsFrontCache;
use crate::tls_front::cache; use crate::tls_front::cache;
use crate::tls_front::fetcher; use crate::tls_front::fetcher;
use crate::transport::{ListenOptions, create_listener}; use crate::transport::{ListenOptions, create_listener};
// Keeps `/metrics` response size bounded when per-user telemetry is enabled. // Keeps `/metrics` response size bounded when per-user telemetry is enabled.
const USER_LABELED_METRICS_MAX_USERS: usize = 4096; const USER_LABELED_METRICS_MAX_USERS: usize = 4096;
// Keeps TLS-front per-domain health series bounded for large generated configs.
const TLS_FRONT_PROFILE_HEALTH_MAX_DOMAINS: usize = 256;
const METRICS_MAX_CONTROL_CONNECTIONS: usize = 512;
const METRICS_HTTP_CONNECTION_TIMEOUT: Duration = Duration::from_secs(15);
pub async fn serve( pub async fn serve(
port: u16, port: u16,
@@ -33,6 +40,7 @@ pub async fn serve(
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
shared_state: Arc<ProxySharedState>, shared_state: Arc<ProxySharedState>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
tls_cache: Option<Arc<TlsFrontCache>>,
config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>, config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>,
whitelist: Vec<IpNetwork>, whitelist: Vec<IpNetwork>,
) { ) {
@@ -57,6 +65,7 @@ pub async fn serve(
beobachten, beobachten,
shared_state, shared_state,
ip_tracker, ip_tracker,
tls_cache,
config_rx, config_rx,
whitelist, whitelist,
) )
@@ -69,11 +78,11 @@ pub async fn serve(
return; return;
} }
// Fallback: bind on 0.0.0.0 and [::] using metrics_port. // Fallback: keep metrics local unless an explicit metrics_listen is configured.
let mut listener_v4 = None; let mut listener_v4 = None;
let mut listener_v6 = None; let mut listener_v6 = None;
let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); let addr_v4 = SocketAddr::from(([127, 0, 0, 1], port));
match bind_metrics_listener(addr_v4, false, listen_backlog) { match bind_metrics_listener(addr_v4, false, listen_backlog) {
Ok(listener) => { Ok(listener) => {
info!( info!(
@@ -87,11 +96,11 @@ pub async fn serve(
} }
} }
let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], port));
match bind_metrics_listener(addr_v6, true, listen_backlog) { match bind_metrics_listener(addr_v6, true, listen_backlog) {
Ok(listener) => { Ok(listener) => {
info!( info!(
"Metrics endpoint: http://[::]:{}/metrics and /beobachten", "Metrics endpoint: http://[::1]:{}/metrics and /beobachten",
port port
); );
listener_v6 = Some(listener); listener_v6 = Some(listener);
@@ -112,6 +121,7 @@ pub async fn serve(
beobachten, beobachten,
shared_state, shared_state,
ip_tracker, ip_tracker,
tls_cache,
config_rx, config_rx,
whitelist, whitelist,
) )
@@ -122,6 +132,7 @@ pub async fn serve(
let beobachten_v6 = beobachten.clone(); let beobachten_v6 = beobachten.clone();
let shared_state_v6 = shared_state.clone(); let shared_state_v6 = shared_state.clone();
let ip_tracker_v6 = ip_tracker.clone(); let ip_tracker_v6 = ip_tracker.clone();
let tls_cache_v6 = tls_cache.clone();
let config_rx_v6 = config_rx.clone(); let config_rx_v6 = config_rx.clone();
let whitelist_v6 = whitelist.clone(); let whitelist_v6 = whitelist.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -131,6 +142,7 @@ pub async fn serve(
beobachten_v6, beobachten_v6,
shared_state_v6, shared_state_v6,
ip_tracker_v6, ip_tracker_v6,
tls_cache_v6,
config_rx_v6, config_rx_v6,
whitelist_v6, whitelist_v6,
) )
@@ -142,6 +154,7 @@ pub async fn serve(
beobachten, beobachten,
shared_state, shared_state,
ip_tracker, ip_tracker,
tls_cache,
config_rx, config_rx,
whitelist, whitelist,
) )
@@ -171,9 +184,12 @@ async fn serve_listener(
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
shared_state: Arc<ProxySharedState>, shared_state: Arc<ProxySharedState>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
tls_cache: Option<Arc<TlsFrontCache>>,
config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>, config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>,
whitelist: Arc<Vec<IpNetwork>>, whitelist: Arc<Vec<IpNetwork>>,
) { ) {
let connection_permits = Arc::new(Semaphore::new(METRICS_MAX_CONTROL_CONNECTIONS));
loop { loop {
let (stream, peer) = match listener.accept().await { let (stream, peer) = match listener.accept().await {
Ok(v) => v, Ok(v) => v,
@@ -188,17 +204,32 @@ async fn serve_listener(
continue; continue;
} }
let connection_permit = match connection_permits.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
debug!(
peer = %peer,
max_connections = METRICS_MAX_CONTROL_CONNECTIONS,
"Dropping metrics connection: control-plane connection budget exhausted"
);
continue;
}
};
let stats = stats.clone(); let stats = stats.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
let shared_state = shared_state.clone(); let shared_state = shared_state.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let tls_cache = tls_cache.clone();
let config_rx_conn = config_rx.clone(); let config_rx_conn = config_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let _connection_permit = connection_permit;
let svc = service_fn(move |req| { let svc = service_fn(move |req| {
let stats = stats.clone(); let stats = stats.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
let shared_state = shared_state.clone(); let shared_state = shared_state.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let tls_cache = tls_cache.clone();
let config = config_rx_conn.borrow().clone(); let config = config_rx_conn.borrow().clone();
async move { async move {
handle( handle(
@@ -207,17 +238,30 @@ async fn serve_listener(
&beobachten, &beobachten,
&shared_state, &shared_state,
&ip_tracker, &ip_tracker,
tls_cache.as_deref(),
&config, &config,
) )
.await .await
} }
}); });
if let Err(e) = http1::Builder::new() match timeout(
.serve_connection(hyper_util::rt::TokioIo::new(stream), svc) METRICS_HTTP_CONNECTION_TIMEOUT,
http1::Builder::new().serve_connection(hyper_util::rt::TokioIo::new(stream), svc),
)
.await .await
{ {
Ok(Ok(())) => {}
Ok(Err(e)) => {
debug!(error = %e, "Metrics connection error"); debug!(error = %e, "Metrics connection error");
} }
Err(_) => {
debug!(
peer = %peer,
timeout_ms = METRICS_HTTP_CONNECTION_TIMEOUT.as_millis() as u64,
"Metrics connection timed out"
);
}
}
}); });
} }
} }
@@ -228,10 +272,11 @@ async fn handle<B>(
beobachten: &BeobachtenStore, beobachten: &BeobachtenStore,
shared_state: &ProxySharedState, shared_state: &ProxySharedState,
ip_tracker: &UserIpTracker, ip_tracker: &UserIpTracker,
tls_cache: Option<&TlsFrontCache>,
config: &ProxyConfig, config: &ProxyConfig,
) -> Result<Response<Full<Bytes>>, Infallible> { ) -> Result<Response<Full<Bytes>>, Infallible> {
if req.uri().path() == "/metrics" { if req.uri().path() == "/metrics" {
let body = render_metrics(stats, shared_state, config, ip_tracker).await; let body = render_metrics(stats, shared_state, config, ip_tracker, tls_cache).await;
let resp = Response::builder() let resp = Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
.header("content-type", "text/plain; version=0.0.4; charset=utf-8") .header("content-type", "text/plain; version=0.0.4; charset=utf-8")
@@ -266,11 +311,138 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri
beobachten.snapshot_text(ttl) beobachten.snapshot_text(ttl)
} }
fn tls_front_domains(config: &ProxyConfig) -> Vec<String> {
let mut domains = Vec::with_capacity(1 + config.censorship.tls_domains.len());
if !config.censorship.tls_domain.is_empty() {
domains.push(config.censorship.tls_domain.clone());
}
for domain in &config.censorship.tls_domains {
if !domain.is_empty() && !domains.contains(domain) {
domains.push(domain.clone());
}
}
domains
}
fn prometheus_label_value(value: &str) -> String {
value.replace('\\', "\\\\").replace('"', "\\\"")
}
async fn render_tls_front_profile_health(
out: &mut String,
config: &ProxyConfig,
tls_cache: Option<&TlsFrontCache>,
) {
use std::fmt::Write;
let domains = tls_front_domains(config);
let (health, suppressed) = match (config.censorship.tls_emulation, tls_cache) {
(true, Some(cache)) => {
cache
.profile_health_snapshot(&domains, TLS_FRONT_PROFILE_HEALTH_MAX_DOMAINS)
.await
}
_ => (Vec::new(), domains.len()),
};
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_domains TLS front configured profile domains by export status"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_domains gauge");
let _ = writeln!(
out,
"telemt_tls_front_profile_domains{{status=\"configured\"}} {}",
domains.len()
);
let _ = writeln!(
out,
"telemt_tls_front_profile_domains{{status=\"emitted\"}} {}",
health.len()
);
let _ = writeln!(
out,
"telemt_tls_front_profile_domains{{status=\"suppressed\"}} {}",
suppressed
);
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_info TLS front profile source and feature flags per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_info gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_age_seconds Age of cached TLS front profile data per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_age_seconds gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_app_data_records TLS front cached app-data record count per configured domain"
);
let _ = writeln!(
out,
"# TYPE telemt_tls_front_profile_app_data_records gauge"
);
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_ticket_records TLS front cached ticket-like tail record count per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_ticket_records gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_change_cipher_spec_records TLS front cached ChangeCipherSpec record count per configured domain"
);
let _ = writeln!(
out,
"# TYPE telemt_tls_front_profile_change_cipher_spec_records gauge"
);
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_app_data_bytes TLS front cached total app-data bytes per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_app_data_bytes gauge");
for item in health {
let domain = prometheus_label_value(&item.domain);
let _ = writeln!(
out,
"telemt_tls_front_profile_info{{domain=\"{}\",source=\"{}\",is_default=\"{}\",has_cert_info=\"{}\",has_cert_payload=\"{}\"}} 1",
domain, item.source, item.is_default, item.has_cert_info, item.has_cert_payload
);
let _ = writeln!(
out,
"telemt_tls_front_profile_age_seconds{{domain=\"{}\"}} {}",
domain, item.age_seconds
);
let _ = writeln!(
out,
"telemt_tls_front_profile_app_data_records{{domain=\"{}\"}} {}",
domain, item.app_data_records
);
let _ = writeln!(
out,
"telemt_tls_front_profile_ticket_records{{domain=\"{}\"}} {}",
domain, item.ticket_records
);
let _ = writeln!(
out,
"telemt_tls_front_profile_change_cipher_spec_records{{domain=\"{}\"}} {}",
domain, item.change_cipher_spec_count
);
let _ = writeln!(
out,
"telemt_tls_front_profile_app_data_bytes{{domain=\"{}\"}} {}",
domain, item.total_app_data_len
);
}
}
async fn render_metrics( async fn render_metrics(
stats: &Stats, stats: &Stats,
shared_state: &ProxySharedState, shared_state: &ProxySharedState,
config: &ProxyConfig, config: &ProxyConfig,
ip_tracker: &UserIpTracker, ip_tracker: &UserIpTracker,
tls_cache: Option<&TlsFrontCache>,
) -> String { ) -> String {
use std::fmt::Write; use std::fmt::Write;
let mut out = String::with_capacity(4096); let mut out = String::with_capacity(4096);
@@ -423,6 +595,7 @@ async fn render_metrics(
"telemt_tls_front_full_cert_budget_cap_drops_total {}", "telemt_tls_front_full_cert_budget_cap_drops_total {}",
cache::full_cert_sent_cap_drops_for_metrics() cache::full_cert_sent_cap_drops_for_metrics()
); );
render_tls_front_profile_health(&mut out, config, tls_cache).await;
let _ = writeln!( let _ = writeln!(
out, out,
@@ -454,6 +627,21 @@ async fn render_metrics(
} }
); );
let _ = writeln!(
out,
"# HELP telemt_connections_bad_by_class_total Bad/rejected connections by class"
);
let _ = writeln!(out, "# TYPE telemt_connections_bad_by_class_total counter");
if core_enabled {
for (class, total) in stats.get_connects_bad_class_counts() {
let _ = writeln!(
out,
"telemt_connections_bad_by_class_total{{class=\"{}\"}} {}",
class, total
);
}
}
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_handshake_timeouts_total Handshake timeouts" "# HELP telemt_handshake_timeouts_total Handshake timeouts"
@@ -469,6 +657,24 @@ async fn render_metrics(
} }
); );
let _ = writeln!(
out,
"# HELP telemt_handshake_failures_by_class_total Handshake failures by class"
);
let _ = writeln!(
out,
"# TYPE telemt_handshake_failures_by_class_total counter"
);
if core_enabled {
for (class, total) in stats.get_handshake_failure_class_counts() {
let _ = writeln!(
out,
"telemt_handshake_failures_by_class_total{{class=\"{}\"}} {}",
class, total
);
}
}
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_auth_expensive_checks_total Expensive authentication candidate checks executed during handshake validation" "# HELP telemt_auth_expensive_checks_total Expensive authentication candidate checks executed during handshake validation"
@@ -520,6 +726,63 @@ async fn render_metrics(
} }
); );
let _ = writeln!(
out,
"# HELP telemt_quota_refund_bytes_total Reserved quota bytes returned before commit"
);
let _ = writeln!(out, "# TYPE telemt_quota_refund_bytes_total counter");
let _ = writeln!(
out,
"telemt_quota_refund_bytes_total {}",
if core_enabled {
stats.get_quota_refund_bytes_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_quota_contention_total Quota reservation CAS contention events"
);
let _ = writeln!(out, "# TYPE telemt_quota_contention_total counter");
let _ = writeln!(
out,
"telemt_quota_contention_total {}",
if core_enabled {
stats.get_quota_contention_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_quota_contention_timeout_total Quota reservations that hit the bounded contention budget"
);
let _ = writeln!(out, "# TYPE telemt_quota_contention_timeout_total counter");
let _ = writeln!(
out,
"telemt_quota_contention_timeout_total {}",
if core_enabled {
stats.get_quota_contention_timeout_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_quota_acquire_cancelled_total Quota acquisitions cancelled before reservation completed"
);
let _ = writeln!(out, "# TYPE telemt_quota_acquire_cancelled_total counter");
let _ = writeln!(
out,
"telemt_quota_acquire_cancelled_total {}",
if core_enabled {
stats.get_quota_acquire_cancelled_total()
} else {
0
}
);
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_conntrack_control_state Runtime conntrack control state flags" "# HELP telemt_conntrack_control_state Runtime conntrack control state flags"
@@ -634,6 +897,29 @@ async fn render_metrics(
); );
let limiter_metrics = shared_state.traffic_limiter.metrics_snapshot(); let limiter_metrics = shared_state.traffic_limiter.metrics_snapshot();
let _ = writeln!(
out,
"# HELP telemt_rate_limiter_burst_bound_bytes Configured upper bound for one direct relay rate-limit burst"
);
let _ = writeln!(out, "# TYPE telemt_rate_limiter_burst_bound_bytes gauge");
let _ = writeln!(
out,
"telemt_rate_limiter_burst_bound_bytes{{direction=\"up\"}} {}",
if core_enabled {
config.general.direct_relay_copy_buf_c2s_bytes
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_burst_bound_bytes{{direction=\"down\"}} {}",
if core_enabled {
config.general.direct_relay_copy_buf_s2c_bytes
} else {
0
}
);
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_rate_limiter_throttle_total Traffic limiter throttle events by scope and direction" "# HELP telemt_rate_limiter_throttle_total Traffic limiter throttle events by scope and direction"
@@ -1736,6 +2022,85 @@ async fn render_metrics(
0 0
} }
); );
let _ = writeln!(
out,
"# HELP telemt_me_child_join_timeout_total Middle relay child tasks that did not join before cleanup deadline"
);
let _ = writeln!(out, "# TYPE telemt_me_child_join_timeout_total counter");
let _ = writeln!(
out,
"telemt_me_child_join_timeout_total {}",
if core_enabled {
stats.get_me_child_join_timeout_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_child_abort_total Middle relay child tasks aborted after bounded cleanup timeout"
);
let _ = writeln!(out, "# TYPE telemt_me_child_abort_total counter");
let _ = writeln!(
out,
"telemt_me_child_abort_total {}",
if core_enabled {
stats.get_me_child_abort_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_flow_wait_events_total Flow wait events by reason, direction, and outcome"
);
let _ = writeln!(out, "# TYPE telemt_flow_wait_events_total counter");
let _ = writeln!(
out,
"telemt_flow_wait_events_total{{reason=\"middle_rate_limit\",direction=\"down\",outcome=\"waited\"}} {}",
if core_enabled {
stats.get_flow_wait_middle_rate_limit_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_flow_wait_events_total{{reason=\"middle_rate_limit\",direction=\"down\",outcome=\"cancelled\"}} {}",
if core_enabled {
stats.get_flow_wait_middle_rate_limit_cancelled_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_flow_wait_ms_total Flow wait time in milliseconds by reason and direction"
);
let _ = writeln!(out, "# TYPE telemt_flow_wait_ms_total counter");
let _ = writeln!(
out,
"telemt_flow_wait_ms_total{{reason=\"middle_rate_limit\",direction=\"down\"}} {}",
if core_enabled {
stats.get_flow_wait_middle_rate_limit_ms_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_session_drop_fallback_total Session reservations cleaned by Drop instead of explicit async release"
);
let _ = writeln!(out, "# TYPE telemt_session_drop_fallback_total counter");
let _ = writeln!(
out,
"telemt_session_drop_fallback_total {}",
if core_enabled {
stats.get_session_drop_fallback_total()
} else {
0
}
);
let _ = writeln!( let _ = writeln!(
out, out,
@@ -3328,6 +3693,11 @@ mod tests {
use super::*; use super::*;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use std::net::IpAddr; use std::net::IpAddr;
use std::time::SystemTime;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
};
#[tokio::test] #[tokio::test]
async fn test_render_metrics_format() { async fn test_render_metrics_format() {
@@ -3342,8 +3712,9 @@ mod tests {
stats.increment_connects_all(); stats.increment_connects_all();
stats.increment_connects_all(); stats.increment_connects_all();
stats.increment_connects_bad(); stats.increment_connects_bad_with_class("tls_handshake_bad_client");
stats.increment_handshake_timeouts(); stats.increment_handshake_timeouts();
stats.increment_handshake_failure_class("timeout");
shared_state shared_state
.handshake .handshake
.auth_expensive_checks_total .auth_expensive_checks_total
@@ -3395,7 +3766,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let output = render_metrics(&stats, shared_state.as_ref(), &config, &tracker).await; let output = render_metrics(&stats, shared_state.as_ref(), &config, &tracker, None).await;
assert!(output.contains(&format!( assert!(output.contains(&format!(
"telemt_build_info{{version=\"{}\"}} 1", "telemt_build_info{{version=\"{}\"}} 1",
@@ -3403,7 +3774,11 @@ mod tests {
))); )));
assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_total 2"));
assert!(output.contains("telemt_connections_bad_total 1")); assert!(output.contains("telemt_connections_bad_total 1"));
assert!(output.contains(
"telemt_connections_bad_by_class_total{class=\"tls_handshake_bad_client\"} 1"
));
assert!(output.contains("telemt_handshake_timeouts_total 1")); assert!(output.contains("telemt_handshake_timeouts_total 1"));
assert!(output.contains("telemt_handshake_failures_by_class_total{class=\"timeout\"} 1"));
assert!(output.contains("telemt_auth_expensive_checks_total 9")); assert!(output.contains("telemt_auth_expensive_checks_total 9"));
assert!(output.contains("telemt_auth_budget_exhausted_total 2")); assert!(output.contains("telemt_auth_budget_exhausted_total 2"));
assert!(output.contains("telemt_upstream_connect_attempt_total 2")); assert!(output.contains("telemt_upstream_connect_attempt_total 2"));
@@ -3457,13 +3832,91 @@ mod tests {
assert!(output.contains("telemt_ip_tracker_cleanup_queue_len 0")); assert!(output.contains("telemt_ip_tracker_cleanup_queue_len 0"));
} }
#[tokio::test]
async fn test_render_tls_front_profile_health() {
let stats = Stats::new();
let shared_state = ProxySharedState::new();
let tracker = UserIpTracker::new();
let mut config = ProxyConfig::default();
config.censorship.tls_domain = "primary.example".to_string();
config.censorship.tls_domains = vec!["fallback.example".to_string()];
let cache = TlsFrontCache::new(
&[
"primary.example".to_string(),
"fallback.example".to_string(),
],
1024,
"tlsfront-profile-health-test",
);
cache
.set(
"primary.example",
CachedTlsData {
server_hello_template: ParsedServerHello {
version: [0x03, 0x03],
random: [0u8; 32],
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
},
cert_info: None,
cert_payload: Some(TlsCertPayload {
cert_chain_der: vec![vec![0x30, 0x01]],
certificate_message: vec![0x0b, 0x00, 0x00, 0x00],
}),
app_data_records_sizes: vec![1024, 512],
total_app_data_len: 1536,
behavior_profile: TlsBehaviorProfile {
change_cipher_spec_count: 1,
app_data_record_sizes: vec![1024, 512],
ticket_record_sizes: vec![69],
source: TlsProfileSource::Merged,
},
fetched_at: SystemTime::now(),
domain: "primary.example".to_string(),
},
)
.await;
let output = render_metrics(&stats, &shared_state, &config, &tracker, Some(&cache)).await;
assert!(output.contains("telemt_tls_front_profile_domains{status=\"configured\"} 2"));
assert!(output.contains("telemt_tls_front_profile_domains{status=\"emitted\"} 2"));
assert!(output.contains("telemt_tls_front_profile_domains{status=\"suppressed\"} 0"));
assert!(
output.contains("telemt_tls_front_profile_info{domain=\"primary.example\",source=\"merged\",is_default=\"false\",has_cert_info=\"false\",has_cert_payload=\"true\"} 1")
);
assert!(
output.contains("telemt_tls_front_profile_info{domain=\"fallback.example\",source=\"default\",is_default=\"true\",has_cert_info=\"false\",has_cert_payload=\"false\"} 1")
);
assert!(
output.contains(
"telemt_tls_front_profile_app_data_records{domain=\"primary.example\"} 2"
)
);
assert!(
output
.contains("telemt_tls_front_profile_ticket_records{domain=\"primary.example\"} 1")
);
assert!(output.contains(
"telemt_tls_front_profile_change_cipher_spec_records{domain=\"primary.example\"} 1"
));
assert!(
output.contains(
"telemt_tls_front_profile_app_data_bytes{domain=\"primary.example\"} 1536"
)
);
}
#[tokio::test] #[tokio::test]
async fn test_render_empty_stats() { async fn test_render_empty_stats() {
let stats = Stats::new(); let stats = Stats::new();
let shared_state = ProxySharedState::new(); let shared_state = ProxySharedState::new();
let tracker = UserIpTracker::new(); let tracker = UserIpTracker::new();
let config = ProxyConfig::default(); let config = ProxyConfig::default();
let output = render_metrics(&stats, &shared_state, &config, &tracker).await; let output = render_metrics(&stats, &shared_state, &config, &tracker, None).await;
assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_total 0"));
assert!(output.contains("telemt_connections_bad_total 0")); assert!(output.contains("telemt_connections_bad_total 0"));
assert!(output.contains("telemt_handshake_timeouts_total 0")); assert!(output.contains("telemt_handshake_timeouts_total 0"));
@@ -3487,7 +3940,7 @@ mod tests {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_unique_ips_global_each = 2; config.access.user_max_unique_ips_global_each = 2;
let output = render_metrics(&stats, &shared_state, &config, &tracker).await; let output = render_metrics(&stats, &shared_state, &config, &tracker, None).await;
assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 2")); assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 2"));
assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.500000")); assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.500000"));
@@ -3499,11 +3952,13 @@ mod tests {
let shared_state = ProxySharedState::new(); let shared_state = ProxySharedState::new();
let tracker = UserIpTracker::new(); let tracker = UserIpTracker::new();
let config = ProxyConfig::default(); let config = ProxyConfig::default();
let output = render_metrics(&stats, &shared_state, &config, &tracker).await; let output = render_metrics(&stats, &shared_state, &config, &tracker, None).await;
assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_uptime_seconds gauge"));
assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_total counter"));
assert!(output.contains("# TYPE telemt_connections_bad_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter"));
assert!(output.contains("# TYPE telemt_connections_bad_by_class_total counter"));
assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter"));
assert!(output.contains("# TYPE telemt_handshake_failures_by_class_total counter"));
assert!(output.contains("# TYPE telemt_auth_expensive_checks_total counter")); assert!(output.contains("# TYPE telemt_auth_expensive_checks_total counter"));
assert!(output.contains("# TYPE telemt_auth_budget_exhausted_total counter")); assert!(output.contains("# TYPE telemt_auth_budget_exhausted_total counter"));
assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter"));
@@ -3546,6 +4001,15 @@ mod tests {
assert!( assert!(
output.contains("# TYPE telemt_tls_front_full_cert_budget_cap_drops_total counter") output.contains("# TYPE telemt_tls_front_full_cert_budget_cap_drops_total counter")
); );
assert!(output.contains("# TYPE telemt_tls_front_profile_domains gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_info gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_age_seconds gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_app_data_records gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_ticket_records gauge"));
assert!(
output.contains("# TYPE telemt_tls_front_profile_change_cipher_spec_records gauge")
);
assert!(output.contains("# TYPE telemt_tls_front_profile_app_data_bytes gauge"));
} }
#[tokio::test] #[tokio::test]
@@ -3566,6 +4030,7 @@ mod tests {
&beobachten, &beobachten,
shared_state.as_ref(), shared_state.as_ref(),
&tracker, &tracker,
None,
&config, &config,
) )
.await .await
@@ -3600,6 +4065,7 @@ mod tests {
&beobachten, &beobachten,
shared_state.as_ref(), shared_state.as_ref(),
&tracker, &tracker,
None,
&config, &config,
) )
.await .await
@@ -3617,6 +4083,7 @@ mod tests {
&beobachten, &beobachten,
shared_state.as_ref(), shared_state.as_ref(),
&tracker, &tracker,
None,
&config, &config,
) )
.await .await
+41 -32
View File
@@ -32,7 +32,13 @@ struct UserConnectionReservation {
user: String, user: String,
ip: IpAddr, ip: IpAddr,
tracks_ip: bool, tracks_ip: bool,
active: bool, state: SessionReservationState,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum SessionReservationState {
Active,
Released,
} }
impl UserConnectionReservation { impl UserConnectionReservation {
@@ -49,28 +55,35 @@ impl UserConnectionReservation {
user, user,
ip, ip,
tracks_ip, tracks_ip,
active: true, state: SessionReservationState::Active,
} }
} }
fn mark_released(&mut self) -> bool {
if self.state != SessionReservationState::Active {
return false;
}
self.state = SessionReservationState::Released;
true
}
async fn release(mut self) { async fn release(mut self) {
if !self.active { if !self.mark_released() {
return; return;
} }
if self.tracks_ip { if self.tracks_ip {
self.ip_tracker.remove_ip(&self.user, self.ip).await; self.ip_tracker.remove_ip(&self.user, self.ip).await;
} }
self.active = false;
self.stats.decrement_user_curr_connects(&self.user); self.stats.decrement_user_curr_connects(&self.user);
} }
} }
impl Drop for UserConnectionReservation { impl Drop for UserConnectionReservation {
fn drop(&mut self) { fn drop(&mut self) {
if !self.active { if !self.mark_released() {
return; return;
} }
self.active = false; self.stats.increment_session_drop_fallback_total();
self.stats.decrement_user_curr_connects(&self.user); self.stats.decrement_user_curr_connects(&self.user);
if self.tracks_ip { if self.tracks_ip {
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
@@ -466,17 +479,7 @@ where
let mut local_addr = synthetic_local_addr(config.server.port); let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled { if proxy_protocol_enabled {
let proxy_header_timeout = if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) {
Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1));
match timeout(
proxy_header_timeout,
parse_proxy_protocol(&mut stream, peer),
)
.await
{
Ok(Ok(info)) => {
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
{
stats.increment_connects_bad_with_class("proxy_protocol_untrusted"); stats.increment_connects_bad_with_class("proxy_protocol_untrusted");
warn!( warn!(
peer = %peer, peer = %peer,
@@ -486,6 +489,16 @@ where
record_beobachten_class(&beobachten, &config, peer.ip(), "other"); record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(ProxyError::InvalidProxyProtocol); return Err(ProxyError::InvalidProxyProtocol);
} }
let proxy_header_timeout =
Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1));
match timeout(
proxy_header_timeout,
parse_proxy_protocol(&mut stream, peer),
)
.await
{
Ok(Ok(info)) => {
debug!( debug!(
peer = %peer, peer = %peer,
client = %info.src_addr, client = %info.src_addr,
@@ -978,15 +991,6 @@ impl RunningClientHandler {
let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
if self.proxy_protocol_enabled { if self.proxy_protocol_enabled {
let proxy_header_timeout =
Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1));
match timeout(
proxy_header_timeout,
parse_proxy_protocol(&mut self.stream, self.peer),
)
.await
{
Ok(Ok(info)) => {
if !is_trusted_proxy_source( if !is_trusted_proxy_source(
self.peer.ip(), self.peer.ip(),
&self.config.server.proxy_protocol_trusted_cidrs, &self.config.server.proxy_protocol_trusted_cidrs,
@@ -998,14 +1002,19 @@ impl RunningClientHandler {
trusted = ?self.config.server.proxy_protocol_trusted_cidrs, trusted = ?self.config.server.proxy_protocol_trusted_cidrs,
"Rejecting PROXY protocol header from untrusted source" "Rejecting PROXY protocol header from untrusted source"
); );
record_beobachten_class( record_beobachten_class(&self.beobachten, &self.config, self.peer.ip(), "other");
&self.beobachten,
&self.config,
self.peer.ip(),
"other",
);
return Err(ProxyError::InvalidProxyProtocol); return Err(ProxyError::InvalidProxyProtocol);
} }
let proxy_header_timeout =
Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1));
match timeout(
proxy_header_timeout,
parse_proxy_protocol(&mut self.stream, self.peer),
)
.await
{
Ok(Ok(info)) => {
debug!( debug!(
peer = %self.peer, peer = %self.peer,
client = %info.src_addr, client = %info.src_addr,
+2 -3
View File
@@ -18,8 +18,7 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
use crate::proxy::route_mode::{ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay,
cutover_stagger_delay,
}; };
use crate::proxy::shared_state::{ use crate::proxy::shared_state::{
ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState,
@@ -360,7 +359,7 @@ where
"Cutover affected direct session, closing client connection" "Cutover affected direct session, closing client connection"
); );
tokio::time::sleep(delay).await; tokio::time::sleep(delay).await;
break Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); break Err(ProxyError::RouteSwitched);
} }
tokio::select! { tokio::select! {
result = &mut relay_result => { result = &mut relay_result => {
+42
View File
@@ -1450,6 +1450,20 @@ where
validated_secret.copy_from_slice(secret); validated_secret.copy_from_slice(secret);
} }
if config
.access
.is_user_source_ip_denied(validated_user.as_str(), peer.ip())
{
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(
peer = %peer,
user = %validated_user,
"TLS handshake rejected: client source IP on per-user deny list (access.user_source_deny)"
);
return HandshakeResult::BadClient { reader, writer };
}
// Reject known replay digests before expensive cache/domain/ALPN policy work. // Reject known replay digests before expensive cache/domain/ALPN policy work.
let digest_half = &validation_digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &validation_digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
@@ -1795,6 +1809,20 @@ where
let validation = matched_validation.expect("validation must exist when matched"); let validation = matched_validation.expect("validation must exist when matched");
if config
.access
.is_user_source_ip_denied(matched_user.as_str(), peer.ip())
{
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(
peer = %peer,
user = %matched_user,
"MTProto handshake rejected: client source IP on per-user deny list (access.user_source_deny)"
);
return HandshakeResult::BadClient { reader, writer };
}
// Apply replay tracking only after successful authentication. // Apply replay tracking only after successful authentication.
// //
// This ordering prevents an attacker from producing invalid handshakes that // This ordering prevents an attacker from producing invalid handshakes that
@@ -1873,6 +1901,20 @@ where
.auth_expensive_checks_total .auth_expensive_checks_total
.fetch_add(validation_checks as u64, Ordering::Relaxed); .fetch_add(validation_checks as u64, Ordering::Relaxed);
if config
.access
.is_user_source_ip_denied(user.as_str(), peer.ip())
{
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(
peer = %peer,
user = %user,
"MTProto handshake rejected: client source IP on per-user deny list (access.user_source_deny)"
);
return HandshakeResult::BadClient { reader, writer };
}
// Apply replay tracking only after successful authentication. // Apply replay tracking only after successful authentication.
// //
// This ordering prevents an attacker from producing invalid handshakes that // This ordering prevents an attacker from producing invalid handshakes that
+116 -5
View File
@@ -2,6 +2,7 @@
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
use crate::protocol::tls;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
#[cfg(unix)] #[cfg(unix)]
@@ -328,6 +329,89 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
} }
} }
#[cfg(test)]
mod tls_domain_mask_host_tests {
use super::{mask_host_for_initial_data, matching_tls_domain_for_sni};
use crate::config::ProxyConfig;
fn client_hello_with_sni(sni_host: &str) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&[0x03, 0x03]);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let host_bytes = sni_host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
let mut extensions = Vec::new();
extensions.extend_from_slice(&0x0000u16.to_be_bytes());
extensions.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
extensions.extend_from_slice(&sni_payload);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(0x16);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
fn config_with_tls_domains() -> ProxyConfig {
let mut config = ProxyConfig::default();
config.censorship.tls_domain = "a.com".to_string();
config.censorship.tls_domains = vec!["b.com".to_string(), "c.com".to_string()];
config.censorship.mask_host = Some("a.com".to_string());
config
}
#[test]
fn matching_tls_domain_accepts_primary_and_extra_domains_case_insensitively() {
let config = config_with_tls_domains();
assert_eq!(matching_tls_domain_for_sni(&config, "A.COM"), Some("a.com"));
assert_eq!(matching_tls_domain_for_sni(&config, "B.COM"), Some("b.com"));
assert_eq!(matching_tls_domain_for_sni(&config, "unknown.com"), None);
}
#[test]
fn mask_host_preserves_explicit_non_primary_origin() {
let mut config = config_with_tls_domains();
config.censorship.mask_host = Some("origin.example".to_string());
let initial_data = client_hello_with_sni("b.com");
assert_eq!(
mask_host_for_initial_data(&config, &initial_data),
"origin.example"
);
}
#[test]
fn mask_host_uses_matching_tls_domain_when_mask_host_is_primary_default() {
let config = config_with_tls_domains();
let initial_data = client_hello_with_sni("b.com");
assert_eq!(mask_host_for_initial_data(&config, &initial_data), "b.com");
}
}
/// Detect client type based on initial data /// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str { fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request // Check for HTTP request
@@ -360,6 +444,37 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
host.parse::<IpAddr>().ok() host.parse::<IpAddr>().ok()
} }
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
return Some(config.censorship.tls_domain.as_str());
}
for domain in &config.censorship.tls_domains {
if domain.eq_ignore_ascii_case(sni) {
return Some(domain.as_str());
}
}
None
}
fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8]) -> &'a str {
let configured_mask_host = config
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain);
if !configured_mask_host.eq_ignore_ascii_case(&config.censorship.tls_domain) {
return configured_mask_host;
}
tls::extract_sni_from_client_hello(initial_data)
.as_deref()
.and_then(|sni| matching_tls_domain_for_sni(config, sni))
.unwrap_or(configured_mask_host)
}
fn canonical_ip(ip: IpAddr) -> IpAddr { fn canonical_ip(ip: IpAddr) -> IpAddr {
match ip { match ip {
IpAddr::V6(v6) => v6 IpAddr::V6(v6) => v6
@@ -734,11 +849,7 @@ pub async fn handle_bad_client<R, W>(
return; return;
} }
let mask_host = config let mask_host = mask_host_for_initial_data(config, initial_data);
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port; let mask_port = config.censorship.mask_port;
// Fail closed when fallback points at our own listener endpoint. // Fail closed when fallback points at our own listener endpoint.
+265 -86
View File
@@ -14,6 +14,7 @@ use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch};
use tokio::time::timeout; use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
use crate::config::{ConntrackPressureProfile, ProxyConfig}; use crate::config::{ConntrackPressureProfile, ProxyConfig};
@@ -22,8 +23,7 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::{secure_padding_len, *}; use crate::protocol::constants::{secure_padding_len, *};
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay,
cutover_stagger_delay,
}; };
use crate::proxy::shared_state::{ use crate::proxy::shared_state::{
ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState,
@@ -65,6 +65,15 @@ const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
const QUOTA_RESERVE_BACKOFF_MIN_MS: u64 = 1; const QUOTA_RESERVE_BACKOFF_MIN_MS: u64 = 1;
const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16; const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16;
const QUOTA_RESERVE_MAX_BACKOFF_ROUNDS: usize = 16;
const ME_CHILD_JOIN_TIMEOUT: Duration = Duration::from_secs(2);
enum MiddleQuotaReserveError {
LimitExceeded,
Contended,
Cancelled,
DeadlineExceeded,
}
#[derive(Default)] #[derive(Default)]
pub(crate) struct DesyncDedupRotationState { pub(crate) struct DesyncDedupRotationState {
@@ -622,21 +631,43 @@ async fn reserve_user_quota_with_yield(
user_stats: &UserStats, user_stats: &UserStats,
bytes: u64, bytes: u64,
limit: u64, limit: u64,
) -> std::result::Result<u64, QuotaReserveError> { stats: &Stats,
cancel: &CancellationToken,
deadline: Option<Instant>,
) -> std::result::Result<u64, MiddleQuotaReserveError> {
let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS;
let mut backoff_rounds = 0usize;
loop { loop {
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match user_stats.quota_try_reserve(bytes, limit) { match user_stats.quota_try_reserve(bytes, limit) {
Ok(total) => return Ok(total), Ok(total) => return Ok(total),
Err(QuotaReserveError::LimitExceeded) => { Err(QuotaReserveError::LimitExceeded) => {
return Err(QuotaReserveError::LimitExceeded); return Err(MiddleQuotaReserveError::LimitExceeded);
}
Err(QuotaReserveError::Contended) => {
stats.increment_quota_contention_total();
std::hint::spin_loop();
} }
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
} }
} }
tokio::task::yield_now().await; tokio::task::yield_now().await;
tokio::time::sleep(Duration::from_millis(backoff_ms)).await; if deadline.is_some_and(|deadline| Instant::now() >= deadline) {
stats.increment_quota_contention_timeout_total();
return Err(MiddleQuotaReserveError::DeadlineExceeded);
}
tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {}
_ = cancel.cancelled() => {
stats.increment_quota_acquire_cancelled_total();
return Err(MiddleQuotaReserveError::Cancelled);
}
}
backoff_rounds = backoff_rounds.saturating_add(1);
if backoff_rounds >= QUOTA_RESERVE_MAX_BACKOFF_ROUNDS {
stats.increment_quota_contention_timeout_total();
return Err(MiddleQuotaReserveError::Contended);
}
backoff_ms = backoff_ms backoff_ms = backoff_ms
.saturating_mul(2) .saturating_mul(2)
.min(QUOTA_RESERVE_BACKOFF_MAX_MS); .min(QUOTA_RESERVE_BACKOFF_MAX_MS);
@@ -647,12 +678,13 @@ async fn wait_for_traffic_budget(
lease: Option<&Arc<TrafficLease>>, lease: Option<&Arc<TrafficLease>>,
direction: RateDirection, direction: RateDirection,
bytes: u64, bytes: u64,
) { deadline: Option<Instant>,
) -> Result<()> {
if bytes == 0 { if bytes == 0 {
return; return Ok(());
} }
let Some(lease) = lease else { let Some(lease) = lease else {
return; return Ok(());
}; };
let mut remaining = bytes; let mut remaining = bytes;
@@ -664,6 +696,9 @@ async fn wait_for_traffic_budget(
} }
let wait_started_at = Instant::now(); let wait_started_at = Instant::now();
if deadline.is_some_and(|deadline| wait_started_at >= deadline) {
return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded);
}
tokio::time::sleep(next_refill_delay()).await; tokio::time::sleep(next_refill_delay()).await;
let wait_ms = wait_started_at let wait_ms = wait_started_at
.elapsed() .elapsed()
@@ -676,6 +711,59 @@ async fn wait_for_traffic_budget(
wait_ms, wait_ms,
); );
} }
Ok(())
}
async fn wait_for_traffic_budget_or_cancel(
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
bytes: u64,
cancel: &CancellationToken,
stats: &Stats,
deadline: Option<Instant>,
) -> Result<()> {
if bytes == 0 {
return Ok(());
}
let Some(lease) = lease else {
return Ok(());
};
let mut remaining = bytes;
while remaining > 0 {
let consume = lease.try_consume(direction, remaining);
if consume.granted > 0 {
remaining = remaining.saturating_sub(consume.granted);
continue;
}
let wait_started_at = Instant::now();
if deadline.is_some_and(|deadline| wait_started_at >= deadline) {
stats.increment_flow_wait_middle_rate_limit_cancelled_total();
return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded);
}
tokio::select! {
_ = tokio::time::sleep(next_refill_delay()) => {}
_ = cancel.cancelled() => {
stats.increment_flow_wait_middle_rate_limit_cancelled_total();
return Err(ProxyError::TrafficBudgetWaitCancelled);
}
}
let wait_ms = wait_started_at
.elapsed()
.as_millis()
.min(u128::from(u64::MAX)) as u64;
lease.observe_wait_ms(
direction,
consume.blocked_user,
consume.blocked_cidr,
wait_ms,
);
stats.observe_flow_wait_middle_rate_limit_ms(wait_ms);
}
Ok(())
} }
fn classify_me_d2c_flush_reason( fn classify_me_d2c_flush_reason(
@@ -1114,7 +1202,7 @@ where
tokio::time::sleep(delay).await; tokio::time::sleep(delay).await;
let _ = me_pool.send_close(conn_id).await; let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await; me_pool.registry().unregister(conn_id).await;
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); return Err(ProxyError::RouteSwitched);
} }
// Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable)
@@ -1169,7 +1257,7 @@ where
let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget));
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity); let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone(); let me_pool_c2me = me_pool.clone();
let c2me_sender = tokio::spawn(async move { let mut c2me_sender = tokio::spawn(async move {
let mut sent_since_yield = 0usize; let mut sent_since_yield = 0usize;
while let Some(cmd) = c2me_rx.recv().await { while let Some(cmd) = c2me_rx.recv().await {
match cmd { match cmd {
@@ -1205,16 +1293,18 @@ where
}); });
let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
let flow_cancel = CancellationToken::new();
let mut me_rx_task = me_rx; let mut me_rx_task = me_rx;
let stats_clone = stats.clone(); let stats_clone = stats.clone();
let rng_clone = rng.clone(); let rng_clone = rng.clone();
let user_clone = user.clone(); let user_clone = user.clone();
let quota_user_stats_me_writer = quota_user_stats.clone(); let quota_user_stats_me_writer = quota_user_stats.clone();
let traffic_lease_me_writer = traffic_lease.clone(); let traffic_lease_me_writer = traffic_lease.clone();
let flow_cancel_me_writer = flow_cancel.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone(); let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
let me_writer = tokio::spawn(async move { let mut me_writer = tokio::spawn(async move {
let mut writer = crypto_writer; let mut writer = crypto_writer;
let mut frame_buf = Vec::with_capacity(16 * 1024); let mut frame_buf = Vec::with_capacity(16 * 1024);
let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes;
@@ -1234,7 +1324,7 @@ where
let Some(first) = msg else { let Some(first) = msg else {
debug!(conn_id, "ME channel closed"); debug!(conn_id, "ME channel closed");
shrink_session_vec(&mut frame_buf, shrink_threshold); shrink_session_vec(&mut frame_buf, shrink_threshold);
return Err(ProxyError::Proxy("ME connection lost".into())); return Err(ProxyError::MiddleConnectionLost);
}; };
let mut batch_frames = 0usize; let mut batch_frames = 0usize;
@@ -1256,6 +1346,7 @@ where
quota_limit, quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes, d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(), traffic_lease_me_writer.as_ref(),
&flow_cancel_me_writer,
bytes_me2c_clone.as_ref(), bytes_me2c_clone.as_ref(),
conn_id, conn_id,
d2c_flush_policy.ack_flush_immediate, d2c_flush_policy.ack_flush_immediate,
@@ -1276,7 +1367,7 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1317,6 +1408,7 @@ where
quota_limit, quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes, d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(), traffic_lease_me_writer.as_ref(),
&flow_cancel_me_writer,
bytes_me2c_clone.as_ref(), bytes_me2c_clone.as_ref(),
conn_id, conn_id,
d2c_flush_policy.ack_flush_immediate, d2c_flush_policy.ack_flush_immediate,
@@ -1338,7 +1430,8 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ =
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1381,6 +1474,7 @@ where
quota_limit, quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes, d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(), traffic_lease_me_writer.as_ref(),
&flow_cancel_me_writer,
bytes_me2c_clone.as_ref(), bytes_me2c_clone.as_ref(),
conn_id, conn_id,
d2c_flush_policy.ack_flush_immediate, d2c_flush_policy.ack_flush_immediate,
@@ -1405,7 +1499,11 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1447,6 +1545,7 @@ where
quota_limit, quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes, d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(), traffic_lease_me_writer.as_ref(),
&flow_cancel_me_writer,
bytes_me2c_clone.as_ref(), bytes_me2c_clone.as_ref(),
conn_id, conn_id,
d2c_flush_policy.ack_flush_immediate, d2c_flush_policy.ack_flush_immediate,
@@ -1471,7 +1570,11 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1495,7 +1598,7 @@ where
Ok(None) => { Ok(None) => {
debug!(conn_id, "ME channel closed"); debug!(conn_id, "ME channel closed");
shrink_session_vec(&mut frame_buf, shrink_threshold); shrink_session_vec(&mut frame_buf, shrink_threshold);
return Err(ProxyError::Proxy("ME connection lost".into())); return Err(ProxyError::MiddleConnectionLost);
} }
Err(_) => { Err(_) => {
max_delay_fired = true; max_delay_fired = true;
@@ -1517,7 +1620,7 @@ where
} else { } else {
None None
}; };
writer.flush().await.map_err(ProxyError::Io)?; flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1610,7 +1713,7 @@ where
stats.as_ref(), stats.as_ref(),
) )
.await; .await;
main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); main_result = Err(ProxyError::RouteSwitched);
break; break;
} }
@@ -1641,27 +1744,51 @@ where
traffic_lease.as_ref(), traffic_lease.as_ref(),
RateDirection::Up, RateDirection::Up,
payload.len() as u64, payload.len() as u64,
None,
) )
.await; .await?;
forensics.bytes_c2me = forensics forensics.bytes_c2me = forensics
.bytes_c2me .bytes_c2me
.saturating_add(payload.len() as u64); .saturating_add(payload.len() as u64);
if let (Some(limit), Some(user_stats)) = if let (Some(limit), Some(user_stats)) =
(quota_limit, quota_user_stats.as_deref()) (quota_limit, quota_user_stats.as_deref())
{ {
if reserve_user_quota_with_yield( match reserve_user_quota_with_yield(
user_stats, user_stats,
payload.len() as u64, payload.len() as u64,
limit, limit,
stats.as_ref(),
&flow_cancel,
None,
) )
.await .await
.is_err()
{ {
Ok(_) => {}
Err(MiddleQuotaReserveError::LimitExceeded) => {
main_result = Err(ProxyError::DataQuotaExceeded { main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(), user: user.clone(),
}); });
break; break;
} }
Err(MiddleQuotaReserveError::Contended) => {
main_result = Err(ProxyError::Proxy(
"ME C->ME quota reservation contended".into(),
));
break;
}
Err(MiddleQuotaReserveError::Cancelled) => {
main_result = Err(ProxyError::Proxy(
"ME C->ME quota reservation cancelled".into(),
));
break;
}
Err(MiddleQuotaReserveError::DeadlineExceeded) => {
main_result = Err(ProxyError::Proxy(
"ME C->ME quota reservation deadline exceeded".into(),
));
break;
}
}
stats.add_user_octets_from_handle(user_stats, payload.len() as u64); stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
} else { } else {
stats.add_user_octets_from(&user, payload.len() as u64); stats.add_user_octets_from(&user, payload.len() as u64);
@@ -1729,22 +1856,34 @@ where
} }
drop(c2me_tx); drop(c2me_tx);
let c2me_result = c2me_sender let c2me_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut c2me_sender).await {
.await Ok(joined) => {
.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))); joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}"))))
}
Err(_) => {
stats.increment_me_child_join_timeout_total();
stats.increment_me_child_abort_total();
c2me_sender.abort();
Err(ProxyError::Proxy("ME sender join timeout".into()))
}
};
flow_cancel.cancel();
let _ = stop_tx.send(()); let _ = stop_tx.send(());
let mut writer_result = me_writer let mut writer_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut me_writer).await {
.await Ok(joined) => {
.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}")))); joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}"))))
}
Err(_) => {
stats.increment_me_child_join_timeout_total();
stats.increment_me_child_abort_total();
me_writer.abort();
Err(ProxyError::Proxy("ME writer join timeout".into()))
}
};
// When client closes, but ME channel stopped as unregistered - it isnt error // When client closes, but ME channel stopped as unregistered - it isnt error
if client_closed if client_closed && matches!(writer_result, Err(ProxyError::MiddleConnectionLost)) {
&& matches!(
writer_result,
Err(ProxyError::Proxy(ref msg)) if msg == "ME connection lost"
)
{
writer_result = Ok(()); writer_result = Ok(());
} }
@@ -2300,6 +2439,7 @@ where
quota_limit, quota_limit,
quota_soft_overshoot_bytes, quota_soft_overshoot_bytes,
None, None,
&CancellationToken::new(),
bytes_me2c, bytes_me2c,
conn_id, conn_id,
ack_flush_immediate, ack_flush_immediate,
@@ -2320,6 +2460,7 @@ async fn process_me_writer_response_with_traffic_lease<W>(
quota_limit: Option<u64>, quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64, quota_soft_overshoot_bytes: u64,
traffic_lease: Option<&Arc<TrafficLease>>, traffic_lease: Option<&Arc<TrafficLease>>,
cancel: &CancellationToken,
bytes_me2c: &AtomicU64, bytes_me2c: &AtomicU64,
conn_id: u64, conn_id: u64,
ack_flush_immediate: bool, ack_flush_immediate: bool,
@@ -2338,20 +2479,54 @@ where
let data_len = data.len() as u64; let data_len = data.len() as u64;
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit) match reserve_user_quota_with_yield(
user_stats, data_len, soft_limit, stats, cancel, None,
)
.await .await
.is_err()
{ {
Ok(_) => {}
Err(MiddleQuotaReserveError::LimitExceeded) => {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded { return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(), user: user.to_string(),
}); });
} }
Err(MiddleQuotaReserveError::Contended) => {
return Err(ProxyError::Proxy(
"ME D->C quota reservation contended".into(),
));
} }
wait_for_traffic_budget(traffic_lease, RateDirection::Down, data_len).await; Err(MiddleQuotaReserveError::Cancelled) => {
return Err(ProxyError::Proxy(
"ME D->C quota reservation cancelled".into(),
));
}
Err(MiddleQuotaReserveError::DeadlineExceeded) => {
return Err(ProxyError::Proxy(
"ME D->C quota reservation deadline exceeded".into(),
));
}
}
}
wait_for_traffic_budget_or_cancel(
traffic_lease,
RateDirection::Down,
data_len,
cancel,
stats,
None,
)
.await?;
let write_mode = let write_mode = match write_client_payload(
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
cancel,
)
.await .await
{ {
Ok(mode) => mode, Ok(mode) => mode,
@@ -2386,8 +2561,16 @@ where
} else { } else {
trace!(conn_id, confirm, "ME->C quickack"); trace!(conn_id, confirm, "ME->C quickack");
} }
wait_for_traffic_budget(traffic_lease, RateDirection::Down, 4).await; wait_for_traffic_budget_or_cancel(
write_client_ack(client_writer, proto_tag, confirm).await?; traffic_lease,
RateDirection::Down,
4,
cancel,
stats,
None,
)
.await?;
write_client_ack(client_writer, proto_tag, confirm, cancel).await?;
stats.increment_me_d2c_ack_frames_total(); stats.increment_me_d2c_ack_frames_total();
Ok(MeWriterResponseOutcome::Continue { Ok(MeWriterResponseOutcome::Continue {
@@ -2439,6 +2622,7 @@ async fn write_client_payload<W>(
data: &[u8], data: &[u8],
rng: &SecureRandom, rng: &SecureRandom,
frame_buf: &mut Vec<u8>, frame_buf: &mut Vec<u8>,
cancel: &CancellationToken,
) -> Result<MeD2cWriteMode> ) -> Result<MeD2cWriteMode>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@@ -2466,21 +2650,12 @@ where
frame_buf.reserve(wire_len); frame_buf.reserve(wire_len);
frame_buf.push(first); frame_buf.push(first);
frame_buf.extend_from_slice(data); frame_buf.extend_from_slice(data);
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = [first]; let header = [first];
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
} else if len_words < (1 << 24) { } else if len_words < (1 << 24) {
@@ -2495,21 +2670,12 @@ where
frame_buf.reserve(wire_len); frame_buf.reserve(wire_len);
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data); frame_buf.extend_from_slice(data);
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = [first, lw[0], lw[1], lw[2]]; let header = [first, lw[0], lw[1], lw[2]];
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
} else { } else {
@@ -2544,21 +2710,12 @@ where
frame_buf.resize(start + padding_len, 0); frame_buf.resize(start + padding_len, 0);
rng.fill(&mut frame_buf[start..]); rng.fill(&mut frame_buf[start..]);
} }
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = len_val.to_le_bytes(); let header = len_val.to_le_bytes();
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 { if padding_len > 0 {
frame_buf.clear(); frame_buf.clear();
if frame_buf.capacity() < padding_len { if frame_buf.capacity() < padding_len {
@@ -2566,10 +2723,7 @@ where
} }
frame_buf.resize(padding_len, 0); frame_buf.resize(padding_len, 0);
rng.fill(frame_buf.as_mut_slice()); rng.fill(frame_buf.as_mut_slice());
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
} }
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
@@ -2583,6 +2737,7 @@ async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>, client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
confirm: u32, confirm: u32,
cancel: &CancellationToken,
) -> Result<()> ) -> Result<()>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@@ -2592,10 +2747,34 @@ where
} else { } else {
confirm.to_le_bytes() confirm.to_le_bytes()
}; };
client_writer write_all_client_or_cancel(client_writer, &bytes, cancel).await
.write_all(&bytes) }
.await
.map_err(ProxyError::Io) async fn write_all_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
bytes: &[u8],
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled),
}
}
async fn flush_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.flush() => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled),
}
} }
#[cfg(test)] #[cfg(test)]
+99 -42
View File
@@ -215,6 +215,7 @@ struct StatsIo<S> {
c2s_rate_debt_bytes: u64, c2s_rate_debt_bytes: u64,
c2s_wait: RateWaitState, c2s_wait: RateWaitState,
s2c_wait: RateWaitState, s2c_wait: RateWaitState,
quota_wait: RateWaitState,
quota_limit: Option<u64>, quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>, quota_exceeded: Arc<AtomicBool>,
quota_bytes_since_check: u64, quota_bytes_since_check: u64,
@@ -275,6 +276,7 @@ impl<S> StatsIo<S> {
c2s_rate_debt_bytes: 0, c2s_rate_debt_bytes: 0,
c2s_wait: RateWaitState::default(), c2s_wait: RateWaitState::default(),
s2c_wait: RateWaitState::default(), s2c_wait: RateWaitState::default(),
quota_wait: RateWaitState::default(),
quota_limit, quota_limit,
quota_exceeded, quota_exceeded,
quota_bytes_since_check: 0, quota_bytes_since_check: 0,
@@ -353,6 +355,11 @@ impl<S> StatsIo<S> {
Poll::Ready(()) Poll::Ready(())
} }
fn arm_quota_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> {
Self::arm_wait(&mut self.quota_wait, false, false);
Self::poll_wait(&mut self.quota_wait, cx, None, RateDirection::Up)
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -430,8 +437,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
if this.settle_c2s_rate_debt(cx).is_pending() { if this.settle_c2s_rate_debt(cx).is_pending() {
return Poll::Pending; return Poll::Pending;
} }
if buf.remaining() == 0 {
return Pin::new(&mut this.inner).poll_read(cx, buf);
}
let mut remaining_before = None; let mut remaining_before = None;
let mut reserved_read_bytes = 0u64;
let mut read_limit = buf.remaining();
if let Some(limit) = this.quota_limit { if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used(); let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before); let remaining = limit.saturating_sub(used_before);
@@ -440,50 +452,79 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
return Poll::Ready(Err(quota_io_error())); return Poll::Ready(Err(quota_io_error()));
} }
remaining_before = Some(remaining); remaining_before = Some(remaining);
read_limit = read_limit.min(remaining as usize);
if read_limit == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
} }
let before = buf.filled().len(); let desired = read_limit as u64;
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
let n_to_charge = n as u64;
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
let mut reserved_total = None;
let mut reserve_rounds = 0usize; let mut reserve_rounds = 0usize;
while reserved_total.is_none() { while reserved_read_bytes == 0 {
let mut saw_contention = false;
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match this.user_stats.quota_try_reserve(n_to_charge, limit) { match this.user_stats.quota_try_reserve(desired, limit) {
Ok(total) => { Ok(_) => {
reserved_total = Some(total); reserved_read_bytes = desired;
break; break;
} }
Err(crate::stats::QuotaReserveError::LimitExceeded) => { Err(crate::stats::QuotaReserveError::LimitExceeded) => {
this.quota_exceeded.store(true, Ordering::Release); this.quota_exceeded.store(true, Ordering::Release);
buf.set_filled(before);
return Poll::Ready(Err(quota_io_error())); return Poll::Ready(Err(quota_io_error()));
} }
Err(crate::stats::QuotaReserveError::Contended) => { Err(crate::stats::QuotaReserveError::Contended) => {
saw_contention = true; this.stats.increment_quota_contention_total();
}
}
}
if reserved_total.is_none() {
reserve_rounds = reserve_rounds.saturating_add(1);
if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS {
this.quota_exceeded.store(true, Ordering::Release);
buf.set_filled(before);
return Poll::Ready(Err(quota_io_error()));
}
if saw_contention {
std::thread::yield_now();
} }
} }
} }
if reserved_read_bytes == 0 {
reserve_rounds = reserve_rounds.saturating_add(1);
if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS {
this.stats.increment_quota_contention_timeout_total();
if this.arm_quota_wait(cx).is_pending() {
return Poll::Pending;
}
reserve_rounds = 0;
}
}
}
}
let limited_read = read_limit < buf.remaining();
let read_result = if limited_read {
let mut limited_buf = ReadBuf::new(buf.initialize_unfilled_to(read_limit));
match Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf) {
Poll::Ready(Ok(())) => {
let n = limited_buf.filled().len();
buf.advance(n);
Poll::Ready(Ok(n))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending,
}
} else {
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
Poll::Ready(Ok(n))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending,
}
};
match read_result {
Poll::Ready(Ok(n)) => {
if reserved_read_bytes > n as u64 {
let refund_bytes = reserved_read_bytes - n as u64;
refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes);
this.stats.add_quota_refund_bytes_total(refund_bytes);
}
if n > 0 {
let n_to_charge = n as u64;
if let Some(remaining) = remaining_before {
if should_immediate_quota_check(remaining, n_to_charge) { if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0; this.quota_bytes_since_check = 0;
} else { } else {
@@ -494,10 +535,11 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
this.quota_bytes_since_check = 0; this.quota_bytes_since_check = 0;
} }
} }
if reserved_total.unwrap_or(0) >= limit {
this.quota_exceeded.store(true, Ordering::Release);
} }
if let Some(limit) = this.quota_limit
&& this.user_stats.quota_used() >= limit
{
this.quota_exceeded.store(true, Ordering::Release);
} }
// C→S: client sent data // C→S: client sent data
@@ -521,7 +563,20 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
} }
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
other => other, Poll::Pending => {
if reserved_read_bytes > 0 {
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes);
this.stats.add_quota_refund_bytes_total(reserved_read_bytes);
}
Poll::Pending
}
Poll::Ready(Err(err)) => {
if reserved_read_bytes > 0 {
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes);
this.stats.add_quota_refund_bytes_total(reserved_read_bytes);
}
Poll::Ready(Err(err))
}
} }
} }
} }
@@ -603,6 +658,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
break; break;
} }
Err(crate::stats::QuotaReserveError::Contended) => { Err(crate::stats::QuotaReserveError::Contended) => {
this.stats.increment_quota_contention_total();
saw_contention = true; saw_contention = true;
} }
} }
@@ -611,14 +667,14 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
if reserved_bytes == 0 { if reserved_bytes == 0 {
reserve_rounds = reserve_rounds.saturating_add(1); reserve_rounds = reserve_rounds.saturating_add(1);
if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS {
this.stats.increment_quota_contention_timeout_total();
if let Some(lease) = this.traffic_lease.as_ref() { if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes); lease.refund(RateDirection::Down, shaper_reserved_bytes);
} }
this.quota_exceeded.store(true, Ordering::Release); let _ = this.arm_quota_wait(cx);
return Poll::Ready(Err(quota_io_error())); return Poll::Pending;
} } else if saw_contention {
if saw_contention { std::hint::spin_loop();
std::thread::yield_now();
} }
} }
} }
@@ -639,10 +695,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
match Pin::new(&mut this.inner).poll_write(cx, write_buf) { match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if reserved_bytes > n as u64 { if reserved_bytes > n as u64 {
refund_reserved_quota_bytes( let refund_bytes = reserved_bytes - n as u64;
this.user_stats.as_ref(), refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes);
reserved_bytes - n as u64, this.stats.add_quota_refund_bytes_total(refund_bytes);
);
} }
if shaper_reserved_bytes > n as u64 if shaper_reserved_bytes > n as u64
&& let Some(lease) = this.traffic_lease.as_ref() && let Some(lease) = this.traffic_lease.as_ref()
@@ -693,6 +748,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
if reserved_bytes > 0 { if reserved_bytes > 0 {
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
this.stats.add_quota_refund_bytes_total(reserved_bytes);
} }
if shaper_reserved_bytes > 0 if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref() && let Some(lease) = this.traffic_lease.as_ref()
@@ -704,6 +760,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
Poll::Pending => { Poll::Pending => {
if reserved_bytes > 0 { if reserved_bytes > 0 {
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
this.stats.add_quota_refund_bytes_total(reserved_bytes);
} }
if shaper_reserved_bytes > 0 if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref() && let Some(lease) = this.traffic_lease.as_ref()
-2
View File
@@ -4,8 +4,6 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::watch; use tokio::sync::watch;
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
pub(crate) enum RelayRouteMode { pub(crate) enum RelayRouteMode {
+1 -1
View File
@@ -661,7 +661,7 @@ async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_s
assert!( assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
|| matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG), || matches!(relay_result, Err(ProxyError::RouteSwitched)),
"overlap race must fail closed via quota enforcement or generic cutover termination" "overlap race must fail closed via quota enforcement or generic cutover termination"
); );
+37 -16
View File
@@ -637,6 +637,22 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
"telemt-unknown-dc-parent-swap-{}", "telemt-unknown-dc-parent-swap-{}",
std::process::id() std::process::id()
)); ));
if let Ok(meta) = fs::symlink_metadata(&parent) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&parent).expect("stale parent-swap path must be removable");
} else {
fs::remove_dir_all(&parent).expect("stale parent-swap directory must be removable");
}
}
let moved = parent.with_extension("bak");
if let Ok(meta) = fs::symlink_metadata(&moved) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&moved).expect("stale parent-swap backup path must be removable");
} else {
fs::remove_dir_all(&moved)
.expect("stale parent-swap backup directory must be removable");
}
}
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@@ -646,8 +662,6 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
.expect("candidate must sanitize before parent swap"); .expect("candidate must sanitize before parent swap");
let moved = parent.with_extension("bak");
let _ = fs::remove_dir_all(&moved);
fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); fs::rename(&parent, &moved).expect("parent must be movable for swap simulation");
symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable");
@@ -720,6 +734,24 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
"telemt-unknown-dc-parent-swap-openat-{}", "telemt-unknown-dc-parent-swap-openat-{}",
std::process::id() std::process::id()
)); ));
if let Ok(meta) = fs::symlink_metadata(&base) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&base).expect("stale parent-swap-openat path must be removable");
} else {
fs::remove_dir_all(&base)
.expect("stale parent-swap-openat directory must be removable");
}
}
let moved = base.with_extension("bak");
if let Ok(meta) = fs::symlink_metadata(&moved) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&moved)
.expect("stale parent-swap-openat backup path must be removable");
} else {
fs::remove_dir_all(&moved)
.expect("stale parent-swap-openat backup directory must be removable");
}
}
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@@ -743,8 +775,6 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
let outside_target = outside_parent.join("unknown-dc.log"); let outside_target = outside_parent.join("unknown-dc.log");
let _ = fs::remove_file(&outside_target); let _ = fs::remove_file(&outside_target);
let moved = base.with_extension("bak");
let _ = fs::remove_dir_all(&moved);
fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");
@@ -1489,10 +1519,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() {
"cutover should terminate direct relay session" "cutover should terminate direct relay session"
); );
assert!( assert!(
matches!( matches!(relay_result, Err(ProxyError::RouteSwitched)),
relay_result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"client-visible cutover error must stay generic and avoid route-internal metadata" "client-visible cutover error must stay generic and avoid route-internal metadata"
); );
@@ -1629,10 +1656,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
.expect("direct relay task must not panic"); .expect("direct relay task must not panic");
assert!( assert!(
matches!( matches!(relay_result, Err(ProxyError::RouteSwitched)),
relay_result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"storm-cutover termination must remain generic for all direct sessions" "storm-cutover termination must remain generic for all direct sessions"
); );
} }
@@ -1935,10 +1959,7 @@ async fn adversarial_direct_relay_cutover_integrity() {
.expect("Session must not panic"); .expect("Session must not panic");
assert!( assert!(
matches!( matches!(result, Err(ProxyError::RouteSwitched)),
result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"Session must terminate with route switch error on cutover" "Session must terminate with route switch error on cutover"
); );
} }
@@ -13,6 +13,8 @@ struct CountedWriter {
fail_writes: bool, fail_writes: bool,
} }
struct StalledWriter;
impl CountedWriter { impl CountedWriter {
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self { fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
Self { Self {
@@ -49,12 +51,36 @@ impl AsyncWrite for CountedWriter {
} }
} }
impl AsyncWrite for StalledWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Pending
}
}
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> { fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = 0u128; let iv = 0u128;
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
} }
fn make_stalled_crypto_writer() -> CryptoWriter<StalledWriter> {
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(StalledWriter, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test] #[tokio::test]
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
let stats = Stats::new(); let stats = Stats::new();
@@ -189,3 +215,53 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
); );
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
} }
#[tokio::test]
async fn me_writer_data_write_obeys_flow_cancellation() {
let stats = Stats::new();
let user = "middle-me-writer-cancel-user";
let mut writer = make_stalled_crypto_writer();
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let cancel = CancellationToken::new();
cancel.cancel();
let result = process_me_writer_response_with_traffic_lease(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x31, 0x32, 0x33, 0x34]),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
None,
None,
0,
None,
&cancel,
&bytes_me2c,
13,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::MiddleClientWriterCancelled)),
"cancelled middle writer must return a bounded cancellation error"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"cancelled write must not advance committed ME->C bytes"
);
assert_eq!(
stats.get_user_total_octets(user),
0,
"cancelled write must not advance user output telemetry"
);
}
@@ -4,10 +4,67 @@ use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::task::{Context, Poll}; use std::task::{Context, Poll, Wake};
use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::time::Instant; use tokio::time::Instant;
enum ReadStep {
Data(Vec<u8>),
Pending,
Eof,
Error,
}
struct ScriptedReader {
scripted_reads: Arc<Mutex<VecDeque<ReadStep>>>,
read_calls: Arc<AtomicUsize>,
}
impl ScriptedReader {
fn new(script: Vec<ReadStep>, read_calls: Arc<AtomicUsize>) -> Self {
Self {
scripted_reads: Arc::new(Mutex::new(script.into())),
read_calls,
}
}
}
impl AsyncRead for ScriptedReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
this.read_calls.fetch_add(1, Ordering::Relaxed);
let step = this
.scripted_reads
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.pop_front()
.unwrap_or(ReadStep::Eof);
match step {
ReadStep::Data(data) => {
let n = data.len().min(buf.remaining());
buf.put_slice(&data[..n]);
Poll::Ready(Ok(()))
}
ReadStep::Pending => Poll::Pending,
ReadStep::Eof => Poll::Ready(Ok(())),
ReadStep::Error => Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"forced read failure",
))),
}
}
}
struct NoopWake;
impl Wake for NoopWake {
fn wake(self: Arc<Self>) {}
}
struct ScriptedWriter { struct ScriptedWriter {
scripted_writes: Arc<Mutex<VecDeque<usize>>>, scripted_writes: Arc<Mutex<VecDeque<usize>>>,
write_calls: Arc<AtomicUsize>, write_calls: Arc<AtomicUsize>,
@@ -80,6 +137,127 @@ fn make_stats_io_with_script(
(io, stats, write_calls, quota_exceeded) (io, stats, write_calls, quota_exceeded)
} }
fn make_stats_io_with_read_script(
user: &str,
quota_limit: u64,
precharged_quota: u64,
script: Vec<ReadStep>,
) -> (
StatsIo<ScriptedReader>,
Arc<Stats>,
Arc<AtomicUsize>,
Arc<AtomicBool>,
) {
let stats = Arc::new(Stats::new());
if precharged_quota > 0 {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
}
let read_calls = Arc::new(AtomicUsize::new(0));
let quota_exceeded = Arc::new(AtomicBool::new(false));
let io = StatsIo::new(
ScriptedReader::new(script, read_calls.clone()),
Arc::new(SharedCounters::new()),
stats.clone(),
user.to_string(),
Some(quota_limit),
quota_exceeded.clone(),
Instant::now(),
);
(io, stats, read_calls, quota_exceeded)
}
fn poll_read_once<R: AsyncRead + Unpin>(
io: &mut StatsIo<R>,
storage: &mut [u8],
) -> Poll<io::Result<usize>> {
let waker = Arc::new(NoopWake).into();
let mut cx = Context::from_waker(&waker);
let mut read_buf = ReadBuf::new(storage);
let before = read_buf.filled().len();
match Pin::new(io).poll_read(&mut cx, &mut read_buf) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len() - before)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
#[test]
fn direct_c2s_quota_refunds_unused_on_short_read() {
let user = "direct-c2s-short-read-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Data(vec![0x11; 5])]);
let mut storage = [0u8; 16];
let n = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Ok(n)) => n,
other => panic!("short read must complete, got {other:?}"),
};
assert_eq!(n, 5);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 5);
assert_eq!(stats.get_quota_refund_bytes_total(), 11);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_pending() {
let user = "direct-c2s-pending-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Pending]);
let mut storage = [0u8; 16];
assert!(matches!(
poll_read_once(&mut io, &mut storage),
Poll::Pending
));
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_eof() {
let user = "direct-c2s-eof-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Eof]);
let mut storage = [0u8; 16];
let n = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Ok(n)) => n,
other => panic!("EOF read must complete with zero bytes, got {other:?}"),
};
assert_eq!(n, 0);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_error() {
let user = "direct-c2s-error-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Error]);
let mut storage = [0u8; 16];
let error = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Err(error)) => error,
other => panic!("error read must return error, got {other:?}"),
};
assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[tokio::test] #[tokio::test]
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
let user = "direct-partial-charge-user"; let user = "direct-partial-charge-user";
+114
View File
@@ -0,0 +1,114 @@
use std::collections::BTreeMap;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use crate::stats::{Stats, UserQuotaSnapshot};
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct QuotaStateFile {
pub(crate) last_reset_epoch_secs: u64,
pub(crate) users: BTreeMap<String, QuotaUserState>,
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct QuotaUserState {
pub(crate) used_bytes: u64,
pub(crate) last_reset_epoch_secs: u64,
}
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
pub(crate) async fn load_quota_state(path: &Path, stats: &Stats) {
let bytes = match tokio::fs::read(path).await {
Ok(bytes) => bytes,
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return,
Err(error) => {
warn!(
error = %error,
path = %path.display(),
"Failed to read quota state file"
);
return;
}
};
let state = match serde_json::from_slice::<QuotaStateFile>(&bytes) {
Ok(state) => state,
Err(error) => {
warn!(
error = %error,
path = %path.display(),
"Failed to parse quota state file"
);
return;
}
};
let loaded_users = state.users.len();
for (user, quota) in state.users {
stats.load_user_quota_state(&user, quota.used_bytes, quota.last_reset_epoch_secs);
}
info!(
path = %path.display(),
loaded_users,
"Loaded per-user quota state"
);
}
pub(crate) async fn save_quota_state(path: &Path, stats: &Stats) -> std::io::Result<()> {
let mut users = BTreeMap::new();
let mut last_reset_epoch_secs = 0;
for (user, quota) in stats.user_quota_snapshot() {
last_reset_epoch_secs = last_reset_epoch_secs.max(quota.last_reset_epoch_secs);
users.insert(user, quota_user_state(quota));
}
let state = QuotaStateFile {
last_reset_epoch_secs,
users,
};
write_state_file(path, &state).await
}
pub(crate) async fn reset_user_quota(
path: &Path,
stats: &Stats,
user: &str,
) -> std::io::Result<UserQuotaSnapshot> {
let snapshot = stats.reset_user_quota(user);
save_quota_state(path, stats).await?;
Ok(snapshot)
}
async fn write_state_file(path: &Path, state: &QuotaStateFile) -> std::io::Result<()> {
if let Some(parent) = path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await?;
}
let tmp_path = path.with_extension(format!("tmp.{}", now_epoch_secs()));
let payload = serde_json::to_vec_pretty(state)?;
let mut file = tokio::fs::File::create(&tmp_path).await?;
file.write_all(&payload).await?;
file.write_all(b"\n").await?;
file.sync_all().await?;
drop(file);
tokio::fs::rename(&tmp_path, path).await
}
fn quota_user_state(quota: UserQuotaSnapshot) -> QuotaUserState {
QuotaUserState {
used_bytes: quota.used_bytes,
last_reset_epoch_secs: quota.last_reset_epoch_secs,
}
}
+185 -17
View File
@@ -8,8 +8,8 @@ pub mod telemetry;
use dashmap::DashMap; use dashmap::DashMap;
use lru::LruCache; use lru::LruCache;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::VecDeque;
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::sync::Arc; use std::sync::Arc;
@@ -274,11 +274,22 @@ pub struct Stats {
me_inline_recovery_total: AtomicU64, me_inline_recovery_total: AtomicU64,
ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64,
ip_reservation_rollback_quota_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64,
quota_refund_bytes_total: AtomicU64,
quota_contention_total: AtomicU64,
quota_contention_timeout_total: AtomicU64,
quota_acquire_cancelled_total: AtomicU64,
quota_write_fail_bytes_total: AtomicU64, quota_write_fail_bytes_total: AtomicU64,
quota_write_fail_events_total: AtomicU64, quota_write_fail_events_total: AtomicU64,
me_child_join_timeout_total: AtomicU64,
me_child_abort_total: AtomicU64,
flow_wait_middle_rate_limit_total: AtomicU64,
flow_wait_middle_rate_limit_cancelled_total: AtomicU64,
flow_wait_middle_rate_limit_ms_total: AtomicU64,
session_drop_fallback_total: AtomicU64,
telemetry_core_enabled: AtomicBool, telemetry_core_enabled: AtomicBool,
telemetry_user_enabled: AtomicBool, telemetry_user_enabled: AtomicBool,
telemetry_me_level: AtomicU8, telemetry_me_level: AtomicU8,
cached_epoch_secs: AtomicU64,
user_stats: DashMap<String, Arc<UserStats>>, user_stats: DashMap<String, Arc<UserStats>>,
user_stats_last_cleanup_epoch_secs: AtomicU64, user_stats_last_cleanup_epoch_secs: AtomicU64,
start_time: parking_lot::RwLock<Option<Instant>>, start_time: parking_lot::RwLock<Option<Instant>>,
@@ -297,9 +308,16 @@ pub struct UserStats {
/// This counter is the single source of truth for quota enforcement and /// This counter is the single source of truth for quota enforcement and
/// intentionally tracks attempted traffic, not guaranteed delivery. /// intentionally tracks attempted traffic, not guaranteed delivery.
pub quota_used: AtomicU64, pub quota_used: AtomicU64,
pub quota_last_reset_epoch_secs: AtomicU64,
pub last_seen_epoch_secs: AtomicU64, pub last_seen_epoch_secs: AtomicU64,
} }
#[derive(Debug, Clone)]
pub struct UserQuotaSnapshot {
pub used_bytes: u64,
pub last_reset_epoch_secs: u64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuotaReserveError { pub enum QuotaReserveError {
LimitExceeded, LimitExceeded,
@@ -341,6 +359,7 @@ impl Stats {
pub fn new() -> Self { pub fn new() -> Self {
let stats = Self::default(); let stats = Self::default();
stats.apply_telemetry_policy(TelemetryPolicy::default()); stats.apply_telemetry_policy(TelemetryPolicy::default());
stats.refresh_cached_epoch_secs();
*stats.start_time.write() = Some(Instant::now()); *stats.start_time.write() = Some(Instant::now());
stats stats
} }
@@ -390,33 +409,55 @@ impl Stats {
.as_secs() .as_secs()
} }
fn touch_user_stats(stats: &UserStats) { fn refresh_cached_epoch_secs(&self) -> u64 {
let now_epoch_secs = Self::now_epoch_secs();
self.cached_epoch_secs
.store(now_epoch_secs, Ordering::Relaxed);
now_epoch_secs
}
fn cached_epoch_secs(&self) -> u64 {
let cached = self.cached_epoch_secs.load(Ordering::Relaxed);
if cached != 0 {
return cached;
}
self.refresh_cached_epoch_secs()
}
fn touch_user_stats(&self, stats: &UserStats) {
stats stats
.last_seen_epoch_secs .last_seen_epoch_secs
.store(Self::now_epoch_secs(), Ordering::Relaxed); .store(self.cached_epoch_secs(), Ordering::Relaxed);
} }
pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> { pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> {
self.maybe_cleanup_user_stats();
if let Some(existing) = self.user_stats.get(user) { if let Some(existing) = self.user_stats.get(user) {
let handle = Arc::clone(existing.value()); let handle = Arc::clone(existing.value());
Self::touch_user_stats(handle.as_ref()); self.touch_user_stats(handle.as_ref());
return handle; return handle;
} }
let entry = self.user_stats.entry(user.to_string()).or_default(); let entry = self.user_stats.entry(user.to_string()).or_default();
if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 {
Self::touch_user_stats(entry.value().as_ref()); self.touch_user_stats(entry.value().as_ref());
} }
Arc::clone(entry.value()) Arc::clone(entry.value())
} }
pub(crate) async fn run_periodic_user_stats_maintenance(self: Arc<Self>) {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
self.maybe_cleanup_user_stats();
}
}
#[inline] #[inline]
pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
Self::touch_user_stats(user_stats); self.touch_user_stats(user_stats);
user_stats user_stats
.octets_from_client .octets_from_client
.fetch_add(bytes, Ordering::Relaxed); .fetch_add(bytes, Ordering::Relaxed);
@@ -427,7 +468,7 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
Self::touch_user_stats(user_stats); self.touch_user_stats(user_stats);
user_stats user_stats
.octets_to_client .octets_to_client
.fetch_add(bytes, Ordering::Relaxed); .fetch_add(bytes, Ordering::Relaxed);
@@ -438,7 +479,7 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
Self::touch_user_stats(user_stats); self.touch_user_stats(user_stats);
user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
} }
@@ -447,7 +488,7 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
Self::touch_user_stats(user_stats); self.touch_user_stats(user_stats);
user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
} }
@@ -457,7 +498,7 @@ impl Stats {
/// mixing reserve and post-charge on a single I/O event. /// mixing reserve and post-charge on a single I/O event.
#[inline] #[inline]
pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 {
Self::touch_user_stats(user_stats); self.touch_user_stats(user_stats);
user_stats user_stats
.quota_used .quota_used
.fetch_add(bytes, Ordering::Relaxed) .fetch_add(bytes, Ordering::Relaxed)
@@ -468,7 +509,7 @@ impl Stats {
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
let now_epoch_secs = Self::now_epoch_secs(); let now_epoch_secs = self.refresh_cached_epoch_secs();
let last_cleanup_epoch_secs = self let last_cleanup_epoch_secs = self
.user_stats_last_cleanup_epoch_secs .user_stats_last_cleanup_epoch_secs
.load(Ordering::Relaxed); .load(Ordering::Relaxed);
@@ -1430,6 +1471,29 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed); .fetch_add(1, Ordering::Relaxed);
} }
} }
pub fn add_quota_refund_bytes_total(&self, bytes: u64) {
if self.telemetry_core_enabled() {
self.quota_refund_bytes_total
.fetch_add(bytes, Ordering::Relaxed);
}
}
pub fn increment_quota_contention_total(&self) {
if self.telemetry_core_enabled() {
self.quota_contention_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_quota_contention_timeout_total(&self) {
if self.telemetry_core_enabled() {
self.quota_contention_timeout_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_quota_acquire_cancelled_total(&self) {
if self.telemetry_core_enabled() {
self.quota_acquire_cancelled_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) {
if self.telemetry_core_enabled() { if self.telemetry_core_enabled() {
self.quota_write_fail_bytes_total self.quota_write_fail_bytes_total
@@ -1442,6 +1506,37 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed); .fetch_add(1, Ordering::Relaxed);
} }
} }
pub fn increment_me_child_join_timeout_total(&self) {
if self.telemetry_core_enabled() {
self.me_child_join_timeout_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_child_abort_total(&self) {
if self.telemetry_core_enabled() {
self.me_child_abort_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn observe_flow_wait_middle_rate_limit_ms(&self, wait_ms: u64) {
if self.telemetry_core_enabled() {
self.flow_wait_middle_rate_limit_total
.fetch_add(1, Ordering::Relaxed);
self.flow_wait_middle_rate_limit_ms_total
.fetch_add(wait_ms, Ordering::Relaxed);
}
}
pub fn increment_flow_wait_middle_rate_limit_cancelled_total(&self) {
if self.telemetry_core_enabled() {
self.flow_wait_middle_rate_limit_cancelled_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_session_drop_fallback_total(&self) {
if self.telemetry_core_enabled() {
self.session_drop_fallback_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_endpoint_quarantine_total(&self) { pub fn increment_me_endpoint_quarantine_total(&self) {
if self.telemetry_me_allows_normal() { if self.telemetry_me_allows_normal() {
self.me_endpoint_quarantine_total self.me_endpoint_quarantine_total
@@ -2276,19 +2371,52 @@ impl Stats {
self.ip_reservation_rollback_quota_limit_total self.ip_reservation_rollback_quota_limit_total
.load(Ordering::Relaxed) .load(Ordering::Relaxed)
} }
pub fn get_quota_refund_bytes_total(&self) -> u64 {
self.quota_refund_bytes_total.load(Ordering::Relaxed)
}
pub fn get_quota_contention_total(&self) -> u64 {
self.quota_contention_total.load(Ordering::Relaxed)
}
pub fn get_quota_contention_timeout_total(&self) -> u64 {
self.quota_contention_timeout_total.load(Ordering::Relaxed)
}
pub fn get_quota_acquire_cancelled_total(&self) -> u64 {
self.quota_acquire_cancelled_total.load(Ordering::Relaxed)
}
pub fn get_quota_write_fail_bytes_total(&self) -> u64 { pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
self.quota_write_fail_bytes_total.load(Ordering::Relaxed) self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
} }
pub fn get_quota_write_fail_events_total(&self) -> u64 { pub fn get_quota_write_fail_events_total(&self) -> u64 {
self.quota_write_fail_events_total.load(Ordering::Relaxed) self.quota_write_fail_events_total.load(Ordering::Relaxed)
} }
pub fn get_me_child_join_timeout_total(&self) -> u64 {
self.me_child_join_timeout_total.load(Ordering::Relaxed)
}
pub fn get_me_child_abort_total(&self) -> u64 {
self.me_child_abort_total.load(Ordering::Relaxed)
}
pub fn get_flow_wait_middle_rate_limit_total(&self) -> u64 {
self.flow_wait_middle_rate_limit_total
.load(Ordering::Relaxed)
}
pub fn get_flow_wait_middle_rate_limit_cancelled_total(&self) -> u64 {
self.flow_wait_middle_rate_limit_cancelled_total
.load(Ordering::Relaxed)
}
pub fn get_flow_wait_middle_rate_limit_ms_total(&self) -> u64 {
self.flow_wait_middle_rate_limit_ms_total
.load(Ordering::Relaxed)
}
pub fn get_session_drop_fallback_total(&self) -> u64 {
self.session_drop_fallback_total.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
let stats = self.get_or_create_user_stats_handle(user); let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref()); self.touch_user_stats(stats.as_ref());
stats.connects.fetch_add(1, Ordering::Relaxed); stats.connects.fetch_add(1, Ordering::Relaxed);
} }
@@ -2297,7 +2425,7 @@ impl Stats {
return; return;
} }
let stats = self.get_or_create_user_stats_handle(user); let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref()); self.touch_user_stats(stats.as_ref());
stats.curr_connects.fetch_add(1, Ordering::Relaxed); stats.curr_connects.fetch_add(1, Ordering::Relaxed);
} }
@@ -2307,7 +2435,7 @@ impl Stats {
} }
let stats = self.get_or_create_user_stats_handle(user); let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref()); self.touch_user_stats(stats.as_ref());
let counter = &stats.curr_connects; let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed); let mut current = counter.load(Ordering::Relaxed);
@@ -2330,9 +2458,8 @@ impl Stats {
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value().as_ref()); self.touch_user_stats(stats.value().as_ref());
let counter = &stats.curr_connects; let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed); let mut current = counter.load(Ordering::Relaxed);
loop { loop {
@@ -2408,6 +2535,47 @@ impl Stats {
.unwrap_or(0) .unwrap_or(0)
} }
pub fn load_user_quota_state(&self, user: &str, used_bytes: u64, last_reset_epoch_secs: u64) {
let stats = self.get_or_create_user_stats_handle(user);
stats.quota_used.store(used_bytes, Ordering::Relaxed);
stats
.quota_last_reset_epoch_secs
.store(last_reset_epoch_secs, Ordering::Relaxed);
}
pub fn reset_user_quota(&self, user: &str) -> UserQuotaSnapshot {
let stats = self.get_or_create_user_stats_handle(user);
let last_reset_epoch_secs = Self::now_epoch_secs();
stats.quota_used.store(0, Ordering::Relaxed);
stats
.quota_last_reset_epoch_secs
.store(last_reset_epoch_secs, Ordering::Relaxed);
UserQuotaSnapshot {
used_bytes: 0,
last_reset_epoch_secs,
}
}
pub fn user_quota_snapshot(&self) -> HashMap<String, UserQuotaSnapshot> {
let mut out = HashMap::new();
for entry in self.user_stats.iter() {
let stats = entry.value();
let used_bytes = stats.quota_used.load(Ordering::Relaxed);
let last_reset_epoch_secs = stats.quota_last_reset_epoch_secs.load(Ordering::Relaxed);
if used_bytes == 0 && last_reset_epoch_secs == 0 {
continue;
}
out.insert(
entry.key().clone(),
UserQuotaSnapshot {
used_bytes,
last_reset_epoch_secs,
},
);
}
out
}
pub fn get_handshake_timeouts(&self) -> u64 { pub fn get_handshake_timeouts(&self) -> u64 {
self.handshake_timeouts.load(Ordering::Relaxed) self.handshake_timeouts.load(Ordering::Relaxed)
} }
+71 -1
View File
@@ -12,7 +12,7 @@ use tokio::time::sleep;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::tls_front::types::{ use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsFetchResult, CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsFetchResult, TlsProfileSource,
}; };
const FULL_CERT_SENT_SWEEP_INTERVAL_SECS: u64 = 30; const FULL_CERT_SENT_SWEEP_INTERVAL_SECS: u64 = 30;
@@ -42,6 +42,30 @@ pub struct TlsFrontCache {
disk_path: PathBuf, disk_path: PathBuf,
} }
/// Read-only health view for one configured TLS front domain.
#[derive(Debug, Clone)]
pub(crate) struct TlsFrontProfileHealth {
pub(crate) domain: String,
pub(crate) source: &'static str,
pub(crate) age_seconds: u64,
pub(crate) is_default: bool,
pub(crate) has_cert_info: bool,
pub(crate) has_cert_payload: bool,
pub(crate) app_data_records: usize,
pub(crate) ticket_records: usize,
pub(crate) change_cipher_spec_count: u8,
pub(crate) total_app_data_len: usize,
}
fn profile_source_label(source: TlsProfileSource) -> &'static str {
match source {
TlsProfileSource::Default => "default",
TlsProfileSource::Raw => "raw",
TlsProfileSource::Rustls => "rustls",
TlsProfileSource::Merged => "merged",
}
}
#[allow(dead_code)] #[allow(dead_code)]
impl TlsFrontCache { impl TlsFrontCache {
pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef<Path>) -> Self { pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef<Path>) -> Self {
@@ -93,6 +117,52 @@ impl TlsFrontCache {
self.memory.read().await.contains_key(domain) self.memory.read().await.contains_key(domain)
} }
pub(crate) async fn profile_health_snapshot(
&self,
domains: &[String],
max_domains: usize,
) -> (Vec<TlsFrontProfileHealth>, usize) {
let guard = self.memory.read().await;
let now = SystemTime::now();
let mut snapshot = Vec::with_capacity(domains.len().min(max_domains));
let mut suppressed = 0usize;
for domain in domains {
if snapshot.len() >= max_domains {
suppressed = suppressed.saturating_add(1);
continue;
}
let cached = guard
.get(domain)
.cloned()
.unwrap_or_else(|| self.default.clone());
let behavior = &cached.behavior_profile;
let age_seconds = now
.duration_since(cached.fetched_at)
.map(|duration| duration.as_secs())
.unwrap_or(0);
snapshot.push(TlsFrontProfileHealth {
domain: domain.clone(),
source: profile_source_label(behavior.source),
age_seconds,
is_default: cached.domain == "default",
has_cert_info: cached.cert_info.is_some(),
has_cert_payload: cached.cert_payload.is_some(),
app_data_records: cached
.app_data_records_sizes
.len()
.max(behavior.app_data_record_sizes.len()),
ticket_records: behavior.ticket_record_sizes.len(),
change_cipher_spec_count: behavior.change_cipher_spec_count,
total_app_data_len: cached.total_app_data_len,
});
}
(snapshot, suppressed)
}
fn full_cert_sent_shard_index(client_ip: IpAddr) -> usize { fn full_cert_sent_shard_index(client_ip: IpAddr) -> usize {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
client_ip.hash(&mut hasher); client_ip.hash(&mut hasher);
+11 -7
View File
@@ -365,7 +365,10 @@ impl MePool {
} }
} }
pub async fn zero_downtime_reinit_after_map_change(self: &Arc<Self>, rng: &SecureRandom) { pub async fn zero_downtime_reinit_after_map_change(
self: &Arc<Self>,
rng: &SecureRandom,
) -> bool {
let desired_by_dc = self.desired_dc_endpoints().await; let desired_by_dc = self.desired_dc_endpoints().await;
let now_epoch_secs = Self::now_epoch_secs(); let now_epoch_secs = Self::now_epoch_secs();
let v4_suppressed = self.is_family_temporarily_suppressed(IpFamily::V4, now_epoch_secs); let v4_suppressed = self.is_family_temporarily_suppressed(IpFamily::V4, now_epoch_secs);
@@ -380,7 +383,7 @@ impl MePool {
MeDrainGateReason::CoverageQuorum MeDrainGateReason::CoverageQuorum
}; };
self.set_last_drain_gate(false, false, reason, now_epoch_secs); self.set_last_drain_gate(false, false, reason, now_epoch_secs);
return; return false;
} }
let desired_map_hash = Self::desired_map_hash(&desired_by_dc); let desired_map_hash = Self::desired_map_hash(&desired_by_dc);
@@ -490,7 +493,7 @@ impl MePool {
missing_dc = ?missing_dc, missing_dc = ?missing_dc,
"ME reinit coverage below threshold; keeping stale writers" "ME reinit coverage below threshold; keeping stale writers"
); );
return; return false;
} }
if hardswap { if hardswap {
@@ -520,7 +523,7 @@ impl MePool {
missing_dc = ?fresh_missing_dc, missing_dc = ?fresh_missing_dc,
"ME hardswap pending: fresh generation DC coverage incomplete" "ME hardswap pending: fresh generation DC coverage incomplete"
); );
return; return false;
} }
} }
@@ -567,7 +570,7 @@ impl MePool {
self.clear_pending_hardswap_state(); self.clear_pending_hardswap_state();
} }
debug!("ME reinit cycle completed with no stale writers"); debug!("ME reinit cycle completed with no stale writers");
return; return true;
} }
let drain_timeout = self.force_close_timeout(); let drain_timeout = self.force_close_timeout();
@@ -606,10 +609,11 @@ impl MePool {
if hardswap { if hardswap {
self.clear_pending_hardswap_state(); self.clear_pending_hardswap_state();
} }
true
} }
pub async fn zero_downtime_reinit_periodic(self: &Arc<Self>, rng: &SecureRandom) { pub async fn zero_downtime_reinit_periodic(self: &Arc<Self>, rng: &SecureRandom) -> bool {
self.zero_downtime_reinit_after_map_change(rng).await; self.zero_downtime_reinit_after_map_change(rng).await
} }
} }
+14 -3
View File
@@ -47,6 +47,7 @@ pub async fn me_reinit_scheduler(
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
config_rx: watch::Receiver<Arc<ProxyConfig>>, config_rx: watch::Receiver<Arc<ProxyConfig>>,
mut trigger_rx: mpsc::Receiver<MeReinitTrigger>, mut trigger_rx: mpsc::Receiver<MeReinitTrigger>,
me_ready_tx: watch::Sender<u64>,
) { ) {
info!("ME reinit scheduler started"); info!("ME reinit scheduler started");
loop { loop {
@@ -90,15 +91,25 @@ pub async fn me_reinit_scheduler(
if cfg.general.me_reinit_singleflight { if cfg.general.me_reinit_singleflight {
debug!(reason, "ME reinit scheduled (single-flight)"); debug!(reason, "ME reinit scheduled (single-flight)");
pool.zero_downtime_reinit_periodic(rng.as_ref()).await; if pool.zero_downtime_reinit_periodic(rng.as_ref()).await {
me_ready_tx.send_modify(|version| {
*version = version.saturating_add(1);
});
}
} else { } else {
debug!(reason, "ME reinit scheduled (concurrent mode)"); debug!(reason, "ME reinit scheduled (concurrent mode)");
let pool_clone = pool.clone(); let pool_clone = pool.clone();
let rng_clone = rng.clone(); let rng_clone = rng.clone();
let me_ready_tx_clone = me_ready_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
pool_clone if pool_clone
.zero_downtime_reinit_periodic(rng_clone.as_ref()) .zero_downtime_reinit_periodic(rng_clone.as_ref())
.await; .await
{
me_ready_tx_clone.send_modify(|version| {
*version = version.saturating_add(1);
});
}
}); });
} }
} }
+6
View File
@@ -18,6 +18,9 @@ const PROXY_V1_MIN_LEN: usize = 6;
/// Minimum length for v2 header /// Minimum length for v2 header
const PROXY_V2_MIN_LEN: usize = 16; const PROXY_V2_MIN_LEN: usize = 16;
/// Maximum accepted PROXY v2 address and TLV payload.
const PROXY_V2_MAX_ADDR_LEN: usize = 216;
/// Address families for v2 /// Address families for v2
mod address_family { mod address_family {
pub const UNSPEC: u8 = 0x0; pub const UNSPEC: u8 = 0x0;
@@ -169,6 +172,9 @@ async fn parse_v2<R: AsyncRead + Unpin>(
let family_protocol = header[13]; let family_protocol = header[13];
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize; let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
if addr_len > PROXY_V2_MAX_ADDR_LEN {
return Err(ProxyError::InvalidProxyProtocol);
}
// Read address data // Read address data
let mut addr_data = vec![0u8; addr_len]; let mut addr_data = vec![0u8; addr_len];
File diff suppressed because it is too large Load Diff