// scuffle_http/backend/hyper.rs
use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;

use scuffle_context::ContextFutExt;
#[cfg(feature = "tracing")]
use tracing::Instrument;

use crate::error::HttpError;
use crate::service::{HttpService, HttpServiceFactory};

mod handler;
mod stream;
mod utils;

17#[derive(Debug, Clone, bon::Builder)]
23pub struct HyperBackend<F> {
24 #[builder(default = scuffle_context::Context::global())]
26 ctx: scuffle_context::Context,
27 #[builder(default = 1)]
29 worker_tasks: usize,
30 service_factory: F,
32 bind: SocketAddr,
37 #[cfg(feature = "tls-rustls")]
42 rustls_config: Option<tokio_rustls::rustls::ServerConfig>,
43 #[cfg(feature = "http1")]
45 #[builder(default = true)]
46 http1_enabled: bool,
47 #[cfg(feature = "http2")]
49 #[builder(default = true)]
50 http2_enabled: bool,
51}
52
53impl<F> HyperBackend<F>
54where
55 F: HttpServiceFactory + Clone + Send + 'static,
56 F::Error: std::error::Error + Send,
57 F::Service: Clone + Send + 'static,
58 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
59 <F::Service as HttpService>::ResBody: Send,
60 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
61 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
62{
63 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
67 #[allow(unused_mut)] pub async fn run(mut self) -> Result<(), HttpError<F>> {
69 #[cfg(feature = "tracing")]
70 tracing::debug!("starting server");
71
72 #[cfg(feature = "tls-rustls")]
75 if let Some(rustls_config) = self.rustls_config.as_mut() {
76 rustls_config.max_early_data_size = 0;
77 }
78
79 let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
81
82 #[cfg(feature = "tls-rustls")]
83 let tls_acceptor = self
84 .rustls_config
85 .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
86
87 let (worker_ctx, worker_handler) = self.ctx.new_child();
89
90 let workers = (0..self.worker_tasks)
91 .map(|_n| {
92 let service_factory = self.service_factory.clone();
93 let ctx = worker_ctx.clone();
94 let std_listener = listener.try_clone()?;
95 let listener = tokio::net::TcpListener::from_std(std_listener)?;
96 #[cfg(feature = "tls-rustls")]
97 let tls_acceptor = tls_acceptor.clone();
98
99 let worker_fut = async move {
100 loop {
101 #[cfg(feature = "tracing")]
102 tracing::trace!("waiting for connections");
103
104 let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
105 Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
106 Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
107 #[cfg(feature = "tracing")]
108 tracing::error!(err = %e, "failed to accept tcp connection");
109 return Err(HttpError::<F>::from(e));
110 }
111 Some(Err(_)) => continue,
112 None => {
113 #[cfg(feature = "tracing")]
114 tracing::trace!("context done, stopping listener");
115 break;
116 }
117 };
118
119 #[cfg(feature = "tracing")]
120 tracing::trace!(addr = %addr, "accepted tcp connection");
121
122 let ctx = ctx.clone();
123 #[cfg(feature = "tls-rustls")]
124 let tls_acceptor = tls_acceptor.clone();
125 let mut service_factory = service_factory.clone();
126
127 let connection_fut = async move {
128 #[cfg(feature = "tls-rustls")]
130 if let Some(tls_acceptor) = tls_acceptor {
131 #[cfg(feature = "tracing")]
132 tracing::trace!("accepting tls connection");
133
134 stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
135 Some(Ok(stream)) => stream,
136 Some(Err(_err)) => {
137 #[cfg(feature = "tracing")]
138 tracing::warn!(err = %_err, "failed to accept tls connection");
139 return;
140 }
141 None => {
142 #[cfg(feature = "tracing")]
143 tracing::trace!("context done, stopping tls acceptor");
144 return;
145 }
146 };
147
148 #[cfg(feature = "tracing")]
149 tracing::trace!("accepted tls connection");
150 }
151
152 let mut extra_extensions = http::Extensions::new();
153 extra_extensions.insert(crate::extensions::ClientAddr(addr));
154
155 #[cfg(feature = "tls-rustls")]
156 if let Some(certs) = stream.get_client_certs() {
157 extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(certs.to_vec())));
158 }
159
160 let http_service = match service_factory.new_service(addr).await {
162 Ok(service) => service,
163 Err(_e) => {
164 #[cfg(feature = "tracing")]
165 tracing::warn!(err = %_e, "failed to create service");
166 return;
167 }
168 };
169
170 #[cfg(feature = "tracing")]
171 tracing::trace!("handling connection");
172
173 #[cfg(feature = "http1")]
174 let http1 = self.http1_enabled;
175 #[cfg(not(feature = "http1"))]
176 let http1 = false;
177
178 #[cfg(feature = "http2")]
179 let http2 = self.http2_enabled;
180 #[cfg(not(feature = "http2"))]
181 let http2 = false;
182
183 let _res = handler::handle_connection::<F, _, _>(
184 ctx,
185 http_service,
186 extra_extensions,
187 stream,
188 http1,
189 http2,
190 )
191 .await;
192
193 #[cfg(feature = "tracing")]
194 if let Err(e) = _res {
195 tracing::warn!(err = %e, "error handling connection");
196 }
197
198 #[cfg(feature = "tracing")]
199 tracing::trace!("connection closed");
200 };
201
202 #[cfg(feature = "tracing")]
203 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
204
205 tokio::spawn(connection_fut);
206 }
207
208 #[cfg(feature = "tracing")]
209 tracing::trace!("listener closed");
210
211 Ok(())
212 };
213
214 #[cfg(feature = "tracing")]
215 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
216
217 Ok(tokio::spawn(worker_fut))
218 })
219 .collect::<std::io::Result<Vec<_>>>()?;
220
221 match futures::future::try_join_all(workers).await {
222 Ok(res) => {
223 for r in res {
224 if let Err(e) = r {
225 drop(worker_ctx);
226 worker_handler.shutdown().await;
227 return Err(e);
228 }
229 }
230 }
231 Err(_e) => {
232 #[cfg(feature = "tracing")]
233 tracing::error!(err = %_e, "error running workers");
234 }
235 }
236
237 drop(worker_ctx);
238 worker_handler.shutdown().await;
239
240 #[cfg(feature = "tracing")]
241 tracing::debug!("all workers finished");
242
243 Ok(())
244 }
245}