// scuffle_http/backend/hyper/stream.rs

1use tokio::io::{AsyncRead, AsyncWrite};
2
3/// A stream that can be either a TCP stream or a TLS stream.
4///
5/// Implements [`AsyncRead`] and [`AsyncWrite`] by delegating to the inner stream.
6pub(crate) enum Stream {
7    Tcp(tokio::net::TcpStream),
8    #[cfg(feature = "tls-rustls")]
9    Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
10}
11
12impl Stream {
13    /// Try to upgrade the stream to a TLS stream by using a TLS acceptor.
14    ///
15    /// If the stream is already a TLS stream, this function will return the stream unchanged.
16    #[cfg(feature = "tls-rustls")]
17    pub(crate) async fn try_accept_tls(self, tls_acceptor: &tokio_rustls::TlsAcceptor) -> std::io::Result<Self> {
18        match self {
19            Stream::Tcp(stream) => {
20                let stream = tls_acceptor.accept(stream).await?;
21                Ok(Self::Tls(Box::new(stream)))
22            }
23            Stream::Tls(_) => Ok(self),
24        }
25    }
26}
27
28impl AsyncRead for Stream {
29    fn poll_read(
30        self: std::pin::Pin<&mut Self>,
31        cx: &mut std::task::Context<'_>,
32        buf: &mut tokio::io::ReadBuf<'_>,
33    ) -> std::task::Poll<std::io::Result<()>> {
34        match self.get_mut() {
35            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
36            #[cfg(feature = "tls-rustls")]
37            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
38        }
39    }
40}
41
42impl AsyncWrite for Stream {
43    fn poll_write(
44        self: std::pin::Pin<&mut Self>,
45        cx: &mut std::task::Context<'_>,
46        buf: &[u8],
47    ) -> std::task::Poll<Result<usize, std::io::Error>> {
48        match self.get_mut() {
49            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
50            #[cfg(feature = "tls-rustls")]
51            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
52        }
53    }
54
55    fn poll_flush(
56        self: std::pin::Pin<&mut Self>,
57        cx: &mut std::task::Context<'_>,
58    ) -> std::task::Poll<Result<(), std::io::Error>> {
59        match self.get_mut() {
60            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
61            #[cfg(feature = "tls-rustls")]
62            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
63        }
64    }
65
66    fn poll_shutdown(
67        self: std::pin::Pin<&mut Self>,
68        cx: &mut std::task::Context<'_>,
69    ) -> std::task::Poll<Result<(), std::io::Error>> {
70        match self.get_mut() {
71            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
72            #[cfg(feature = "tls-rustls")]
73            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
74        }
75    }
76
77    fn poll_write_vectored(
78        self: std::pin::Pin<&mut Self>,
79        cx: &mut std::task::Context<'_>,
80        bufs: &[std::io::IoSlice<'_>],
81    ) -> std::task::Poll<Result<usize, std::io::Error>> {
82        match self.get_mut() {
83            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
84            #[cfg(feature = "tls-rustls")]
85            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
86        }
87    }
88
89    fn is_write_vectored(&self) -> bool {
90        match self {
91            Stream::Tcp(stream) => stream.is_write_vectored(),
92            #[cfg(feature = "tls-rustls")]
93            Stream::Tls(stream) => stream.is_write_vectored(),
94        }
95    }
96}