mirror of
https://github.com/telemt/telemt.git
synced 2026-06-22 02:00:10 +07:00
Harden ME Writer Cancellation paths
This commit is contained in:
+65
-52
@@ -1367,7 +1367,7 @@ where
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = writer.flush().await;
|
||||
let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
@@ -1430,7 +1430,8 @@ where
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = writer.flush().await;
|
||||
let _ =
|
||||
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
@@ -1498,7 +1499,11 @@ where
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = writer.flush().await;
|
||||
let _ = flush_client_or_cancel(
|
||||
&mut writer,
|
||||
&flow_cancel_me_writer,
|
||||
)
|
||||
.await;
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
@@ -1565,7 +1570,11 @@ where
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = writer.flush().await;
|
||||
let _ = flush_client_or_cancel(
|
||||
&mut writer,
|
||||
&flow_cancel_me_writer,
|
||||
)
|
||||
.await;
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
@@ -1611,7 +1620,7 @@ where
|
||||
} else {
|
||||
None
|
||||
};
|
||||
writer.flush().await.map_err(ProxyError::Io)?;
|
||||
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
@@ -2512,8 +2521,16 @@ where
|
||||
.await?;
|
||||
|
||||
let write_mode =
|
||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await
|
||||
match write_client_payload(
|
||||
client_writer,
|
||||
proto_tag,
|
||||
flags,
|
||||
&data,
|
||||
rng,
|
||||
frame_buf,
|
||||
cancel,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
@@ -2556,7 +2573,7 @@ where
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
write_client_ack(client_writer, proto_tag, confirm).await?;
|
||||
write_client_ack(client_writer, proto_tag, confirm, cancel).await?;
|
||||
stats.increment_me_d2c_ack_frames_total();
|
||||
|
||||
Ok(MeWriterResponseOutcome::Continue {
|
||||
@@ -2608,6 +2625,7 @@ async fn write_client_payload<W>(
|
||||
data: &[u8],
|
||||
rng: &SecureRandom,
|
||||
frame_buf: &mut Vec<u8>,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<MeD2cWriteMode>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
@@ -2635,21 +2653,12 @@ where
|
||||
frame_buf.reserve(wire_len);
|
||||
frame_buf.push(first);
|
||||
frame_buf.extend_from_slice(data);
|
||||
client_writer
|
||||
.write_all(frame_buf.as_slice())
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = [first];
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||
MeD2cWriteMode::Split
|
||||
}
|
||||
} else if len_words < (1 << 24) {
|
||||
@@ -2664,21 +2673,12 @@ where
|
||||
frame_buf.reserve(wire_len);
|
||||
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
|
||||
frame_buf.extend_from_slice(data);
|
||||
client_writer
|
||||
.write_all(frame_buf.as_slice())
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = [first, lw[0], lw[1], lw[2]];
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||
MeD2cWriteMode::Split
|
||||
}
|
||||
} else {
|
||||
@@ -2713,21 +2713,12 @@ where
|
||||
frame_buf.resize(start + padding_len, 0);
|
||||
rng.fill(&mut frame_buf[start..]);
|
||||
}
|
||||
client_writer
|
||||
.write_all(frame_buf.as_slice())
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = len_val.to_le_bytes();
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||
if padding_len > 0 {
|
||||
frame_buf.clear();
|
||||
if frame_buf.capacity() < padding_len {
|
||||
@@ -2735,10 +2726,7 @@ where
|
||||
}
|
||||
frame_buf.resize(padding_len, 0);
|
||||
rng.fill(frame_buf.as_mut_slice());
|
||||
client_writer
|
||||
.write_all(frame_buf.as_slice())
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||
}
|
||||
MeD2cWriteMode::Split
|
||||
}
|
||||
@@ -2752,6 +2740,7 @@ async fn write_client_ack<W>(
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
proto_tag: ProtoTag,
|
||||
confirm: u32,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<()>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
@@ -2761,10 +2750,34 @@ where
|
||||
} else {
|
||||
confirm.to_le_bytes()
|
||||
};
|
||||
client_writer
|
||||
.write_all(&bytes)
|
||||
.await
|
||||
.map_err(ProxyError::Io)
|
||||
write_all_client_or_cancel(client_writer, &bytes, cancel).await
|
||||
}
|
||||
|
||||
async fn write_all_client_or_cancel<W>(
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
bytes: &[u8],
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<()>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
tokio::select! {
|
||||
result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io),
|
||||
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush_client_or_cancel<W>(
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<()>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
tokio::select! {
|
||||
result = client_writer.flush() => result.map_err(ProxyError::Io),
|
||||
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user