scuffle_http/backend/
h3.rs1use std::fmt::Debug;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use body::QuicIncomingBody;
8use scuffle_context::ContextFutExt;
9#[cfg(feature = "tracing")]
10use tracing::Instrument;
11use utils::copy_response_body;
12
13use crate::error::HttpError;
14use crate::service::{HttpService, HttpServiceFactory};
15
16pub mod body;
17mod utils;
18
19#[derive(bon::Builder, Debug, Clone)]
25pub struct Http3Backend<F> {
26 #[builder(default = scuffle_context::Context::global())]
28 ctx: scuffle_context::Context,
29 #[builder(default = 1)]
31 worker_tasks: usize,
32 service_factory: F,
34 bind: SocketAddr,
39 rustls_config: tokio_rustls::rustls::ServerConfig,
44}
45
46impl<F> Http3Backend<F>
47where
48 F: HttpServiceFactory + Clone + Send + 'static,
49 F::Error: std::error::Error + Send,
50 F::Service: Clone + Send + 'static,
51 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
52 <F::Service as HttpService>::ResBody: Send,
53 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
54 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
55{
56 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
60 pub async fn run(mut self) -> Result<(), HttpError<F>> {
61 #[cfg(feature = "tracing")]
62 tracing::debug!("starting server");
63
64 self.rustls_config.max_early_data_size = u32::MAX;
66 let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
67 let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
68
69 let socket = std::net::UdpSocket::bind(self.bind)?;
71
72 let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
74
75 let (worker_ctx, worker_handler) = self.ctx.new_child();
77
78 let workers = (0..self.worker_tasks).map(|_n| {
79 let ctx = worker_ctx.clone();
80 let service_factory = self.service_factory.clone();
81 let server_config = server_config.clone();
82 let socket = socket.try_clone().expect("failed to clone socket");
83 let runtime = Arc::clone(&runtime);
84
85 let worker_fut = async move {
86 let endpoint = h3_quinn::quinn::Endpoint::new(
87 h3_quinn::quinn::EndpointConfig::default(),
88 Some(server_config),
89 socket,
90 runtime,
91 )?;
92
93 #[cfg(feature = "tracing")]
94 tracing::trace!("waiting for connections");
95
96 while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
97 let mut service_factory = service_factory.clone();
98 let ctx = ctx.clone();
99
100 tokio::spawn(async move {
101 let _res: Result<_, HttpError<F>> = async move {
102 let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
103 #[cfg(feature = "tracing")]
104 tracing::trace!("context done while accepting connection");
105 return Ok(());
106 };
107 let addr = conn.remote_address();
108 let client_certs = conn
109 .peer_identity()
110 .and_then(|any| any.downcast::<Vec<tokio_rustls::rustls::pki_types::CertificateDer>>().ok());
111
112 #[cfg(feature = "tracing")]
113 tracing::debug!(addr = %addr, "accepted quic connection");
114
115 let connection_fut = async move {
116 let Some(mut h3_conn) = h3::server::Connection::new(h3_quinn::Connection::new(conn))
117 .with_context(&ctx)
118 .await
119 .transpose()?
120 else {
121 #[cfg(feature = "tracing")]
122 tracing::trace!("context done while establishing connection");
123 return Ok(());
124 };
125
126 let mut extra_extensions = http::Extensions::new();
127 extra_extensions.insert(crate::extensions::ClientAddr(addr));
128 if let Some(certs) = client_certs {
129 extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(*certs)));
130 }
131
132 let http_service = service_factory
134 .new_service(addr)
135 .await
136 .map_err(|e| HttpError::ServiceFactoryError(e))?;
137
138 loop {
139 match h3_conn.accept().with_context(&ctx).await {
140 Some(Ok(Some(resolver))) => {
141 let (req, stream) = match resolver.resolve_request().await {
142 Ok(r) => r,
143 Err(_err) => {
144 #[cfg(feature = "tracing")]
145 tracing::warn!("error on accept: {}", _err);
146 continue;
147 }
148 };
149
150 #[cfg(feature = "tracing")]
151 tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
152
153 let (mut send, recv) = stream.split();
154
155 let size_hint = req
156 .headers()
157 .get(http::header::CONTENT_LENGTH)
158 .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
159 let body = QuicIncomingBody::new(recv, size_hint);
160 let mut req = req.map(|_| crate::body::IncomingBody::from(body));
161
162 req.extensions_mut().extend(extra_extensions.clone());
163
164 let ctx = ctx.clone();
165 let mut http_service = http_service.clone();
166 tokio::spawn(async move {
167 let _res: Result<_, HttpError<F>> = async move {
168 let resp = http_service
169 .call(req)
170 .await
171 .map_err(|e| HttpError::ServiceError(e))?;
172 let (parts, body) = resp.into_parts();
173
174 send.send_response(http::Response::from_parts(parts, ())).await?;
175 copy_response_body(send, body).await?;
176
177 Ok(())
178 }
179 .await;
180
181 #[cfg(feature = "tracing")]
182 if let Err(e) = _res {
183 tracing::warn!(err = %e, "error handling request");
184 }
185
186 drop(ctx);
188 });
189 }
190 Some(Ok(None)) => {
192 break;
193 }
194 Some(Err(err)) => return Err(err.into()),
195 None => {
197 #[cfg(feature = "tracing")]
198 tracing::trace!("context done, stopping connection loop");
199 break;
200 }
201 }
202 }
203
204 #[cfg(feature = "tracing")]
205 tracing::trace!("connection closed");
206
207 Ok(())
208 };
209
210 #[cfg(feature = "tracing")]
211 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
212
213 connection_fut.await
214 }
215 .await;
216
217 #[cfg(feature = "tracing")]
218 if let Err(err) = _res {
219 tracing::warn!(err = %err, "error handling connection");
220 }
221 });
222 }
223
224 endpoint.wait_idle().await;
227
228 Ok::<_, crate::error::HttpError<F>>(())
229 };
230
231 #[cfg(feature = "tracing")]
232 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
233
234 tokio::spawn(worker_fut)
235 });
236
237 if let Err(_e) = futures::future::try_join_all(workers).await {
238 #[cfg(feature = "tracing")]
239 tracing::error!(err = %_e, "error running workers");
240 }
241
242 drop(worker_ctx);
243 worker_handler.shutdown().await;
244
245 #[cfg(feature = "tracing")]
246 tracing::debug!("all workers finished");
247
248 Ok(())
249 }
250}