scuffle_http/backend/
h3.rs

1//! HTTP3 backend.
2use std::fmt::Debug;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::Arc;
6
7use body::QuicIncomingBody;
8use scuffle_context::ContextFutExt;
9#[cfg(feature = "tracing")]
10use tracing::Instrument;
11use utils::copy_response_body;
12
13use crate::error::HttpError;
14use crate::service::{HttpService, HttpServiceFactory};
15
16pub mod body;
17mod utils;
18
19/// A backend that handles incoming HTTP3 connections.
20///
21/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
22///
23/// Call [`run`](Http3Backend::run) to start the server.
24#[derive(bon::Builder, Debug, Clone)]
25pub struct Http3Backend<F> {
26    /// The [`scuffle_context::Context`] this server will live by.
27    #[builder(default = scuffle_context::Context::global())]
28    ctx: scuffle_context::Context,
29    /// The number of worker tasks to spawn for each server backend.
30    #[builder(default = 1)]
31    worker_tasks: usize,
32    /// The service factory that will be used to create new services.
33    service_factory: F,
34    /// The address to bind to.
35    ///
36    /// Use `[::]` for a dual-stack listener.
37    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
38    bind: SocketAddr,
39    /// rustls config.
40    ///
41    /// Use this field to set the server into TLS mode.
42    /// It will only accept TLS connections when this is set.
43    rustls_config: tokio_rustls::rustls::ServerConfig,
44}
45
46impl<F> Http3Backend<F>
47where
48    F: HttpServiceFactory + Clone + Send + 'static,
49    F::Error: std::error::Error + Send,
50    F::Service: Clone + Send + 'static,
51    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
52    <F::Service as HttpService>::ResBody: Send,
53    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
54    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
55{
56    /// Run the HTTP3 server
57    ///
58    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
59    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
60    pub async fn run(mut self) -> Result<(), HttpError<F>> {
61        #[cfg(feature = "tracing")]
62        tracing::debug!("starting server");
63
64        // not quite sure why this is necessary but it is
65        self.rustls_config.max_early_data_size = u32::MAX;
66        let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
67        let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
68
69        // Bind the UDP socket
70        let socket = std::net::UdpSocket::bind(self.bind)?;
71
72        // Runtime for the quinn endpoint
73        let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
74
75        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
76        let (worker_ctx, worker_handler) = self.ctx.new_child();
77
78        let workers = (0..self.worker_tasks).map(|_n| {
79            let ctx = worker_ctx.clone();
80            let service_factory = self.service_factory.clone();
81            let server_config = server_config.clone();
82            let socket = socket.try_clone().expect("failed to clone socket");
83            let runtime = Arc::clone(&runtime);
84
85            let worker_fut = async move {
86                let endpoint = h3_quinn::quinn::Endpoint::new(
87                    h3_quinn::quinn::EndpointConfig::default(),
88                    Some(server_config),
89                    socket,
90                    runtime,
91                )?;
92
93                #[cfg(feature = "tracing")]
94                tracing::trace!("waiting for connections");
95
96                while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
97                    let mut service_factory = service_factory.clone();
98                    let ctx = ctx.clone();
99
100                    tokio::spawn(async move {
101                        let _res: Result<_, HttpError<F>> = async move {
102                            let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
103                                #[cfg(feature = "tracing")]
104                                tracing::trace!("context done while accepting connection");
105                                return Ok(());
106                            };
107                            let addr = conn.remote_address();
108                            let client_certs = conn
109                                .peer_identity()
110                                .and_then(|any| any.downcast::<Vec<tokio_rustls::rustls::pki_types::CertificateDer>>().ok());
111
112                            #[cfg(feature = "tracing")]
113                            tracing::debug!(addr = %addr, "accepted quic connection");
114
115                            let connection_fut = async move {
116                                let Some(mut h3_conn) = h3::server::Connection::new(h3_quinn::Connection::new(conn))
117                                    .with_context(&ctx)
118                                    .await
119                                    .transpose()?
120                                else {
121                                    #[cfg(feature = "tracing")]
122                                    tracing::trace!("context done while establishing connection");
123                                    return Ok(());
124                                };
125
126                                let mut extra_extensions = http::Extensions::new();
127                                extra_extensions.insert(crate::extensions::ClientAddr(addr));
128                                if let Some(certs) = client_certs {
129                                    extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(*certs)));
130                                }
131
132                                // make a new service for this connection
133                                let http_service = service_factory
134                                    .new_service(addr)
135                                    .await
136                                    .map_err(|e| HttpError::ServiceFactoryError(e))?;
137
138                                loop {
139                                    match h3_conn.accept().with_context(&ctx).await {
140                                        Some(Ok(Some(resolver))) => {
141                                            let (req, stream) = match resolver.resolve_request().await {
142                                                Ok(r) => r,
143                                                Err(_err) => {
144                                                    #[cfg(feature = "tracing")]
145                                                    tracing::warn!("error on accept: {}", _err);
146                                                    continue;
147                                                }
148                                            };
149
150                                            #[cfg(feature = "tracing")]
151                                            tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
152
153                                            let (mut send, recv) = stream.split();
154
155                                            let size_hint = req
156                                                .headers()
157                                                .get(http::header::CONTENT_LENGTH)
158                                                .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
159                                            let body = QuicIncomingBody::new(recv, size_hint);
160                                            let mut req = req.map(|_| crate::body::IncomingBody::from(body));
161
162                                            req.extensions_mut().extend(extra_extensions.clone());
163
164                                            let ctx = ctx.clone();
165                                            let mut http_service = http_service.clone();
166                                            tokio::spawn(async move {
167                                                let _res: Result<_, HttpError<F>> = async move {
168                                                    let resp = http_service
169                                                        .call(req)
170                                                        .await
171                                                        .map_err(|e| HttpError::ServiceError(e))?;
172                                                    let (parts, body) = resp.into_parts();
173
174                                                    send.send_response(http::Response::from_parts(parts, ())).await?;
175                                                    copy_response_body(send, body).await?;
176
177                                                    Ok(())
178                                                }
179                                                .await;
180
181                                                #[cfg(feature = "tracing")]
182                                                if let Err(e) = _res {
183                                                    tracing::warn!(err = %e, "error handling request");
184                                                }
185
186                                                // This moves the context into the async block because it is dropped here
187                                                drop(ctx);
188                                            });
189                                        }
190                                        // indicating no more streams to be received
191                                        Some(Ok(None)) => {
192                                            break;
193                                        }
194                                        Some(Err(err)) => return Err(err.into()),
195                                        // context is done
196                                        None => {
197                                            #[cfg(feature = "tracing")]
198                                            tracing::trace!("context done, stopping connection loop");
199                                            break;
200                                        }
201                                    }
202                                }
203
204                                #[cfg(feature = "tracing")]
205                                tracing::trace!("connection closed");
206
207                                Ok(())
208                            };
209
210                            #[cfg(feature = "tracing")]
211                            let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
212
213                            connection_fut.await
214                        }
215                        .await;
216
217                        #[cfg(feature = "tracing")]
218                        if let Err(err) = _res {
219                            tracing::warn!(err = %err, "error handling connection");
220                        }
221                    });
222                }
223
224                // shut down gracefully
225                // wait for connections to be closed before exiting
226                endpoint.wait_idle().await;
227
228                Ok::<_, crate::error::HttpError<F>>(())
229            };
230
231            #[cfg(feature = "tracing")]
232            let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
233
234            tokio::spawn(worker_fut)
235        });
236
237        if let Err(_e) = futures::future::try_join_all(workers).await {
238            #[cfg(feature = "tracing")]
239            tracing::error!(err = %_e, "error running workers");
240        }
241
242        drop(worker_ctx);
243        worker_handler.shutdown().await;
244
245        #[cfg(feature = "tracing")]
246        tracing::debug!("all workers finished");
247
248        Ok(())
249    }
250}