Harden ME Writer Cancellation paths

This commit is contained in:
Alexey
2026-05-10 14:09:10 +03:00
parent beed6b4679
commit 900b574fb8
4 changed files with 359 additions and 58 deletions
+65 -52
View File
@@ -1367,7 +1367,7 @@ where
} else {
None
};
let _ = writer.flush().await;
let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
@@ -1430,7 +1430,8 @@ where
} else {
None
};
let _ = writer.flush().await;
let _ =
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
@@ -1498,7 +1499,11 @@ where
} else {
None
};
let _ = writer.flush().await;
let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
@@ -1565,7 +1570,11 @@ where
} else {
None
};
let _ = writer.flush().await;
let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
@@ -1611,7 +1620,7 @@ where
} else {
None
};
writer.flush().await.map_err(ProxyError::Io)?;
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
@@ -2512,8 +2521,16 @@ where
.await?;
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
match write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
cancel,
)
.await
{
Ok(mode) => mode,
Err(err) => {
@@ -2556,7 +2573,7 @@ where
None,
)
.await?;
write_client_ack(client_writer, proto_tag, confirm).await?;
write_client_ack(client_writer, proto_tag, confirm, cancel).await?;
stats.increment_me_d2c_ack_frames_total();
Ok(MeWriterResponseOutcome::Continue {
@@ -2608,6 +2625,7 @@ async fn write_client_payload<W>(
data: &[u8],
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
cancel: &CancellationToken,
) -> Result<MeD2cWriteMode>
where
W: AsyncWrite + Unpin + Send + 'static,
@@ -2635,21 +2653,12 @@ where
frame_buf.reserve(wire_len);
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
MeD2cWriteMode::Coalesced
} else {
let header = [first];
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, &header, cancel).await?;
write_all_client_or_cancel(client_writer, data, cancel).await?;
MeD2cWriteMode::Split
}
} else if len_words < (1 << 24) {
@@ -2664,21 +2673,12 @@ where
frame_buf.reserve(wire_len);
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
MeD2cWriteMode::Coalesced
} else {
let header = [first, lw[0], lw[1], lw[2]];
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, &header, cancel).await?;
write_all_client_or_cancel(client_writer, data, cancel).await?;
MeD2cWriteMode::Split
}
} else {
@@ -2713,21 +2713,12 @@ where
frame_buf.resize(start + padding_len, 0);
rng.fill(&mut frame_buf[start..]);
}
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
MeD2cWriteMode::Coalesced
} else {
let header = len_val.to_le_bytes();
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, &header, cancel).await?;
write_all_client_or_cancel(client_writer, data, cancel).await?;
if padding_len > 0 {
frame_buf.clear();
if frame_buf.capacity() < padding_len {
@@ -2735,10 +2726,7 @@ where
}
frame_buf.resize(padding_len, 0);
rng.fill(frame_buf.as_mut_slice());
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
}
MeD2cWriteMode::Split
}
@@ -2752,6 +2740,7 @@ async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
confirm: u32,
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
@@ -2761,10 +2750,34 @@ where
} else {
confirm.to_le_bytes()
};
client_writer
.write_all(&bytes)
.await
.map_err(ProxyError::Io)
write_all_client_or_cancel(client_writer, &bytes, cancel).await
}
async fn write_all_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
bytes: &[u8],
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
}
}
async fn flush_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.flush() => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
}
}
#[cfg(test)]