scuffle_http/backend/hyper/
stream.rs

1use tokio::io::{AsyncRead, AsyncWrite};
2
3/// A stream that can be either a TCP stream or a TLS stream.
4///
5/// Implements [`AsyncRead`] and [`AsyncWrite`] by delegating to the inner stream.
6pub(crate) enum Stream {
7    Tcp(tokio::net::TcpStream),
8    #[cfg(feature = "tls-rustls")]
9    Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
10}
11
12impl Stream {
13    /// Try to upgrade the stream to a TLS stream by using a TLS acceptor.
14    ///
15    /// If the stream is already a TLS stream, this function will return the stream unchanged.
16    #[cfg(feature = "tls-rustls")]
17    pub(crate) async fn try_accept_tls(self, tls_acceptor: &tokio_rustls::TlsAcceptor) -> std::io::Result<Self> {
18        match self {
19            Stream::Tcp(stream) => {
20                let stream = tls_acceptor.accept(stream).await?;
21                Ok(Self::Tls(Box::new(stream)))
22            }
23            Stream::Tls(_) => Ok(self),
24        }
25    }
26
27    /// Get the client certificates from the TLS stream.
28    ///
29    /// Returns `None` if the stream is not a TLS stream or if no client certificates are present.
30    #[cfg(feature = "tls-rustls")]
31    pub(crate) fn get_client_certs(&self) -> Option<&[tokio_rustls::rustls::pki_types::CertificateDer<'static>]> {
32        match self {
33            Stream::Tcp(_) => None,
34            Stream::Tls(stream) => stream.get_ref().1.peer_certificates(),
35        }
36    }
37}
38
39impl AsyncRead for Stream {
40    fn poll_read(
41        self: std::pin::Pin<&mut Self>,
42        cx: &mut std::task::Context<'_>,
43        buf: &mut tokio::io::ReadBuf<'_>,
44    ) -> std::task::Poll<std::io::Result<()>> {
45        match self.get_mut() {
46            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
47            #[cfg(feature = "tls-rustls")]
48            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
49        }
50    }
51}
52
53impl AsyncWrite for Stream {
54    fn poll_write(
55        self: std::pin::Pin<&mut Self>,
56        cx: &mut std::task::Context<'_>,
57        buf: &[u8],
58    ) -> std::task::Poll<Result<usize, std::io::Error>> {
59        match self.get_mut() {
60            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
61            #[cfg(feature = "tls-rustls")]
62            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
63        }
64    }
65
66    fn poll_flush(
67        self: std::pin::Pin<&mut Self>,
68        cx: &mut std::task::Context<'_>,
69    ) -> std::task::Poll<Result<(), std::io::Error>> {
70        match self.get_mut() {
71            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
72            #[cfg(feature = "tls-rustls")]
73            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
74        }
75    }
76
77    fn poll_shutdown(
78        self: std::pin::Pin<&mut Self>,
79        cx: &mut std::task::Context<'_>,
80    ) -> std::task::Poll<Result<(), std::io::Error>> {
81        match self.get_mut() {
82            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
83            #[cfg(feature = "tls-rustls")]
84            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
85        }
86    }
87
88    fn poll_write_vectored(
89        self: std::pin::Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91        bufs: &[std::io::IoSlice<'_>],
92    ) -> std::task::Poll<Result<usize, std::io::Error>> {
93        match self.get_mut() {
94            Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
95            #[cfg(feature = "tls-rustls")]
96            Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
97        }
98    }
99
100    fn is_write_vectored(&self) -> bool {
101        match self {
102            Stream::Tcp(stream) => stream.is_write_vectored(),
103            #[cfg(feature = "tls-rustls")]
104            Stream::Tls(stream) => stream.is_write_vectored(),
105        }
106    }
107}