// scuffle_http/backend/hyper/stream.rs
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite};

3pub(crate) enum Stream {
7 Tcp(tokio::net::TcpStream),
8 #[cfg(feature = "tls-rustls")]
9 Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
10}
11
12impl Stream {
13 #[cfg(feature = "tls-rustls")]
17 pub(crate) async fn try_accept_tls(self, tls_acceptor: &tokio_rustls::TlsAcceptor) -> std::io::Result<Self> {
18 match self {
19 Stream::Tcp(stream) => {
20 let stream = tls_acceptor.accept(stream).await?;
21 Ok(Self::Tls(Box::new(stream)))
22 }
23 Stream::Tls(_) => Ok(self),
24 }
25 }
26}
27
28impl AsyncRead for Stream {
29 fn poll_read(
30 self: std::pin::Pin<&mut Self>,
31 cx: &mut std::task::Context<'_>,
32 buf: &mut tokio::io::ReadBuf<'_>,
33 ) -> std::task::Poll<std::io::Result<()>> {
34 match self.get_mut() {
35 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
36 #[cfg(feature = "tls-rustls")]
37 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
38 }
39 }
40}
41
42impl AsyncWrite for Stream {
43 fn poll_write(
44 self: std::pin::Pin<&mut Self>,
45 cx: &mut std::task::Context<'_>,
46 buf: &[u8],
47 ) -> std::task::Poll<Result<usize, std::io::Error>> {
48 match self.get_mut() {
49 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
50 #[cfg(feature = "tls-rustls")]
51 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
52 }
53 }
54
55 fn poll_flush(
56 self: std::pin::Pin<&mut Self>,
57 cx: &mut std::task::Context<'_>,
58 ) -> std::task::Poll<Result<(), std::io::Error>> {
59 match self.get_mut() {
60 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
61 #[cfg(feature = "tls-rustls")]
62 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
63 }
64 }
65
66 fn poll_shutdown(
67 self: std::pin::Pin<&mut Self>,
68 cx: &mut std::task::Context<'_>,
69 ) -> std::task::Poll<Result<(), std::io::Error>> {
70 match self.get_mut() {
71 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
72 #[cfg(feature = "tls-rustls")]
73 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
74 }
75 }
76
77 fn poll_write_vectored(
78 self: std::pin::Pin<&mut Self>,
79 cx: &mut std::task::Context<'_>,
80 bufs: &[std::io::IoSlice<'_>],
81 ) -> std::task::Poll<Result<usize, std::io::Error>> {
82 match self.get_mut() {
83 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
84 #[cfg(feature = "tls-rustls")]
85 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
86 }
87 }
88
89 fn is_write_vectored(&self) -> bool {
90 match self {
91 Stream::Tcp(stream) => stream.is_write_vectored(),
92 #[cfg(feature = "tls-rustls")]
93 Stream::Tls(stream) => stream.is_write_vectored(),
94 }
95 }
96}