91 lines
3.1 KiB
Rust
91 lines
3.1 KiB
Rust
//! Wrapper that ensures TLS `close_notify` is sent when the stream is dropped.
//!
//! When hyper drops an HTTP connection (backend error, timeout, normal H2 close),
//! the underlying TLS stream is dropped WITHOUT a call to `shutdown()`. tokio-rustls
//! cannot send `close_notify` from `Drop`, because doing so requires async I/O.
//! This wrapper tracks whether `poll_shutdown` was called and, if not, spawns a
//! background task to send it.
|
|
|
|
use std::io;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
|
|
/// Wraps an AsyncRead+AsyncWrite stream and ensures `shutdown()` is called when
|
|
/// dropped, even if the caller (e.g. hyper) doesn't explicitly shut down.
|
|
///
|
|
/// This guarantees TLS `close_notify` is sent for TLS-wrapped streams, preventing
|
|
/// "GnuTLS recv error (-110): The TLS connection was non-properly terminated" errors.
|
|
pub struct ShutdownOnDrop<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> {
|
|
inner: Option<S>,
|
|
shutdown_called: bool,
|
|
}
|
|
|
|
impl<S> ShutdownOnDrop<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Wraps `stream`. Shutdown has not happened yet, so dropping the result
    /// without calling `poll_shutdown` triggers the background shutdown path.
    pub fn new(stream: S) -> Self {
        let inner = Some(stream);
        Self {
            shutdown_called: false,
            inner,
        }
    }
}
|
|
|
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncRead for ShutdownOnDrop<S> {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<io::Result<()>> {
|
|
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for ShutdownOnDrop<S> {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<io::Result<()>> {
|
|
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<io::Result<()>> {
|
|
let this = self.get_mut();
|
|
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
|
|
if result.is_ready() {
|
|
this.shutdown_called = true;
|
|
}
|
|
result
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop<S> {
|
|
fn drop(&mut self) {
|
|
// If shutdown was already called (hyper closed properly), nothing to do.
|
|
// If not (hyper dropped without shutdown — e.g. H2 close, error, timeout),
|
|
// spawn a background task to send close_notify / TCP FIN.
|
|
if !self.shutdown_called {
|
|
if let Some(mut stream) = self.inner.take() {
|
|
tokio::spawn(async move {
|
|
let _ = tokio::time::timeout(
|
|
std::time::Duration::from_secs(2),
|
|
tokio::io::AsyncWriteExt::shutdown(&mut stream),
|
|
).await;
|
|
// stream is dropped here — all resources freed
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|