scuffle_http/backend/hyper/
stream.rs1use tokio::io::{AsyncRead, AsyncWrite};
2
3pub(crate) enum Stream {
7 Tcp(tokio::net::TcpStream),
8 #[cfg(feature = "tls-rustls")]
9 Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
10}
11
12impl Stream {
13 #[cfg(feature = "tls-rustls")]
17 pub(crate) async fn try_accept_tls(self, tls_acceptor: &tokio_rustls::TlsAcceptor) -> std::io::Result<Self> {
18 match self {
19 Stream::Tcp(stream) => {
20 let stream = tls_acceptor.accept(stream).await?;
21 Ok(Self::Tls(Box::new(stream)))
22 }
23 Stream::Tls(_) => Ok(self),
24 }
25 }
26
27 #[cfg(feature = "tls-rustls")]
31 pub(crate) fn get_client_certs(&self) -> Option<&[tokio_rustls::rustls::pki_types::CertificateDer<'static>]> {
32 match self {
33 Stream::Tcp(_) => None,
34 Stream::Tls(stream) => stream.get_ref().1.peer_certificates(),
35 }
36 }
37}
38
39impl AsyncRead for Stream {
40 fn poll_read(
41 self: std::pin::Pin<&mut Self>,
42 cx: &mut std::task::Context<'_>,
43 buf: &mut tokio::io::ReadBuf<'_>,
44 ) -> std::task::Poll<std::io::Result<()>> {
45 match self.get_mut() {
46 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
47 #[cfg(feature = "tls-rustls")]
48 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
49 }
50 }
51}
52
53impl AsyncWrite for Stream {
54 fn poll_write(
55 self: std::pin::Pin<&mut Self>,
56 cx: &mut std::task::Context<'_>,
57 buf: &[u8],
58 ) -> std::task::Poll<Result<usize, std::io::Error>> {
59 match self.get_mut() {
60 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
61 #[cfg(feature = "tls-rustls")]
62 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
63 }
64 }
65
66 fn poll_flush(
67 self: std::pin::Pin<&mut Self>,
68 cx: &mut std::task::Context<'_>,
69 ) -> std::task::Poll<Result<(), std::io::Error>> {
70 match self.get_mut() {
71 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
72 #[cfg(feature = "tls-rustls")]
73 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
74 }
75 }
76
77 fn poll_shutdown(
78 self: std::pin::Pin<&mut Self>,
79 cx: &mut std::task::Context<'_>,
80 ) -> std::task::Poll<Result<(), std::io::Error>> {
81 match self.get_mut() {
82 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
83 #[cfg(feature = "tls-rustls")]
84 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
85 }
86 }
87
88 fn poll_write_vectored(
89 self: std::pin::Pin<&mut Self>,
90 cx: &mut std::task::Context<'_>,
91 bufs: &[std::io::IoSlice<'_>],
92 ) -> std::task::Poll<Result<usize, std::io::Error>> {
93 match self.get_mut() {
94 Stream::Tcp(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
95 #[cfg(feature = "tls-rustls")]
96 Stream::Tls(stream) => std::pin::Pin::new(stream).poll_write_vectored(cx, bufs),
97 }
98 }
99
100 fn is_write_vectored(&self) -> bool {
101 match self {
102 Stream::Tcp(stream) => stream.is_write_vectored(),
103 #[cfg(feature = "tls-rustls")]
104 Stream::Tls(stream) => stream.is_write_vectored(),
105 }
106 }
107}