scuffle_http/backend/hyper.rs

1//! Hyper backend.
2use std::fmt::Debug;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use scuffle_context::ContextFutExt;
7#[cfg(feature = "tracing")]
8use tracing::Instrument;
9
10use crate::error::HttpError;
11use crate::service::{HttpService, HttpServiceFactory};
12
13mod handler;
14mod stream;
15mod utils;
16
17/// A backend that handles incoming HTTP connections using a hyper backend.
18///
19/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
20///
21/// Call [`run`](HyperBackend::run) to start the server.
22#[derive(Debug, Clone, bon::Builder)]
23pub struct HyperBackend<F> {
24    /// The [`scuffle_context::Context`] this server will live by.
25    #[builder(default = scuffle_context::Context::global())]
26    ctx: scuffle_context::Context,
27    /// The number of worker tasks to spawn for each server backend.
28    #[builder(default = 1)]
29    worker_tasks: usize,
30    /// The service factory that will be used to create new services.
31    service_factory: F,
32    /// The address to bind to.
33    ///
34    /// Use `[::]` for a dual-stack listener.
35    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
36    bind: SocketAddr,
37    /// rustls config.
38    ///
39    /// Use this field to set the server into TLS mode.
40    /// It will only accept TLS connections when this is set.
41    #[cfg(feature = "tls-rustls")]
42    rustls_config: Option<tokio_rustls::rustls::ServerConfig>,
43    /// Enable HTTP/1.1.
44    #[cfg(feature = "http1")]
45    #[builder(default = true)]
46    http1_enabled: bool,
47    /// Enable HTTP/2.
48    #[cfg(feature = "http2")]
49    #[builder(default = true)]
50    http2_enabled: bool,
51}
52
53impl<F> HyperBackend<F>
54where
55    F: HttpServiceFactory + Clone + Send + 'static,
56    F::Error: std::error::Error + Send,
57    F::Service: Clone + Send + 'static,
58    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
59    <F::Service as HttpService>::ResBody: Send,
60    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
61    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
62{
63    /// Run the HTTP server
64    ///
65    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
66    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
67    #[allow(unused_mut)] // allow the unused `mut self`
68    pub async fn run(mut self) -> Result<(), HttpError<F>> {
69        #[cfg(feature = "tracing")]
70        tracing::debug!("starting server");
71
72        // reset to 0 because everything explodes if it's not
73        // https://github.com/hyperium/hyper/issues/3841
74        #[cfg(feature = "tls-rustls")]
75        if let Some(rustls_config) = self.rustls_config.as_mut() {
76            rustls_config.max_early_data_size = 0;
77        }
78
79        // We have to create an std listener first because the tokio listener isn't clonable
80        let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
81
82        #[cfg(feature = "tls-rustls")]
83        let tls_acceptor = self
84            .rustls_config
85            .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
86
87        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
88        let (worker_ctx, worker_handler) = self.ctx.new_child();
89
90        let workers = (0..self.worker_tasks)
91            .map(|_n| {
92                let service_factory = self.service_factory.clone();
93                let ctx = worker_ctx.clone();
94                let std_listener = listener.try_clone()?;
95                let listener = tokio::net::TcpListener::from_std(std_listener)?;
96                #[cfg(feature = "tls-rustls")]
97                let tls_acceptor = tls_acceptor.clone();
98
99                let worker_fut = async move {
100                    loop {
101                        #[cfg(feature = "tracing")]
102                        tracing::trace!("waiting for connections");
103
104                        let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
105                            Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
106                            Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
107                                #[cfg(feature = "tracing")]
108                                tracing::error!(err = %e, "failed to accept tcp connection");
109                                return Err(HttpError::<F>::from(e));
110                            }
111                            Some(Err(_)) => continue,
112                            None => {
113                                #[cfg(feature = "tracing")]
114                                tracing::trace!("context done, stopping listener");
115                                break;
116                            }
117                        };
118
119                        #[cfg(feature = "tracing")]
120                        tracing::trace!(addr = %addr, "accepted tcp connection");
121
122                        let ctx = ctx.clone();
123                        #[cfg(feature = "tls-rustls")]
124                        let tls_acceptor = tls_acceptor.clone();
125                        let mut service_factory = service_factory.clone();
126
127                        let connection_fut = async move {
128                            // Perform the TLS handshake if the acceptor is set
129                            #[cfg(feature = "tls-rustls")]
130                            if let Some(tls_acceptor) = tls_acceptor {
131                                #[cfg(feature = "tracing")]
132                                tracing::trace!("accepting tls connection");
133
134                                stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
135                                    Some(Ok(stream)) => stream,
136                                    Some(Err(_err)) => {
137                                        #[cfg(feature = "tracing")]
138                                        tracing::warn!(err = %_err, "failed to accept tls connection");
139                                        return;
140                                    }
141                                    None => {
142                                        #[cfg(feature = "tracing")]
143                                        tracing::trace!("context done, stopping tls acceptor");
144                                        return;
145                                    }
146                                };
147
148                                #[cfg(feature = "tracing")]
149                                tracing::trace!("accepted tls connection");
150                            }
151
152                            let mut extra_extensions = http::Extensions::new();
153                            extra_extensions.insert(crate::extensions::ClientAddr(addr));
154
155                            #[cfg(feature = "tls-rustls")]
156                            if let Some(certs) = stream.get_client_certs() {
157                                extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(certs.to_vec())));
158                            }
159
160                            // make a new service
161                            let http_service = match service_factory.new_service(addr).await {
162                                Ok(service) => service,
163                                Err(_e) => {
164                                    #[cfg(feature = "tracing")]
165                                    tracing::warn!(err = %_e, "failed to create service");
166                                    return;
167                                }
168                            };
169
170                            #[cfg(feature = "tracing")]
171                            tracing::trace!("handling connection");
172
173                            #[cfg(feature = "http1")]
174                            let http1 = self.http1_enabled;
175                            #[cfg(not(feature = "http1"))]
176                            let http1 = false;
177
178                            #[cfg(feature = "http2")]
179                            let http2 = self.http2_enabled;
180                            #[cfg(not(feature = "http2"))]
181                            let http2 = false;
182
183                            let _res = handler::handle_connection::<F, _, _>(
184                                ctx,
185                                http_service,
186                                extra_extensions,
187                                stream,
188                                http1,
189                                http2,
190                            )
191                            .await;
192
193                            #[cfg(feature = "tracing")]
194                            if let Err(e) = _res {
195                                tracing::warn!(err = %e, "error handling connection");
196                            }
197
198                            #[cfg(feature = "tracing")]
199                            tracing::trace!("connection closed");
200                        };
201
202                        #[cfg(feature = "tracing")]
203                        let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
204
205                        tokio::spawn(connection_fut);
206                    }
207
208                    #[cfg(feature = "tracing")]
209                    tracing::trace!("listener closed");
210
211                    Ok(())
212                };
213
214                #[cfg(feature = "tracing")]
215                let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
216
217                Ok(tokio::spawn(worker_fut))
218            })
219            .collect::<std::io::Result<Vec<_>>>()?;
220
221        match futures::future::try_join_all(workers).await {
222            Ok(res) => {
223                for r in res {
224                    if let Err(e) = r {
225                        drop(worker_ctx);
226                        worker_handler.shutdown().await;
227                        return Err(e);
228                    }
229                }
230            }
231            Err(_e) => {
232                #[cfg(feature = "tracing")]
233                tracing::error!(err = %_e, "error running workers");
234            }
235        }
236
237        drop(worker_ctx);
238        worker_handler.shutdown().await;
239
240        #[cfg(feature = "tracing")]
241        tracing::debug!("all workers finished");
242
243        Ok(())
244    }
245}