scufflecloud_core/middleware/
auth.rs1use std::str::FromStr;
2use std::sync::Arc;
3
4use axum::extract::Request;
5use axum::http::{HeaderMap, HeaderName, StatusCode};
6use axum::middleware::Next;
7use axum::response::Response;
8use base64::Engine;
9use core_db_types::models::{UserSession, UserSessionTokenId};
10use core_db_types::schema::user_sessions;
11use diesel::{BoolExpressionMethods, ExpressionMethods, SelectableHelper};
12use diesel_async::RunQueryDsl;
13use ext_traits::RequestExt;
14use fred::prelude::KeysInterface;
15use geo_ip::GeoIpRequestExt;
16use geo_ip::middleware::IpAddressInfo;
17use hmac::Mac;
18
19const TOKEN_ID_HEADER: HeaderName = HeaderName::from_static("scuf-token-id");
20const TIMESTAMP_HEADER: HeaderName = HeaderName::from_static("scuf-timestamp");
21const NONCE_HEADER: HeaderName = HeaderName::from_static("scuf-nonce");
22
23const AUTHENTICATION_METHOD_HEADER: HeaderName = HeaderName::from_static("scuf-auth-method");
24const AUTHENTICATION_HMAC_HEADER: HeaderName = HeaderName::from_static("scuf-auth-hmac");
25
26pub(crate) const fn auth_headers() -> [HeaderName; 5] {
27 [
28 TOKEN_ID_HEADER,
29 TIMESTAMP_HEADER,
30 NONCE_HEADER,
31 AUTHENTICATION_METHOD_HEADER,
32 AUTHENTICATION_HMAC_HEADER,
33 ]
34}
35
36#[derive(Clone, Debug)]
37pub(crate) struct ExpiredSession(pub UserSession);
38
39pub(crate) async fn auth<G: core_traits::Global>(mut req: Request, next: Next) -> Result<Response, StatusCode> {
40 let global = req
41 .extensions()
42 .global::<G>()
43 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
44 let ip_info = req
45 .extensions()
46 .ip_address_info()
47 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
48
49 let (session, expired_session) = get_and_update_active_session(&global, &ip_info, req.headers()).await?;
50 if let Some(session) = session {
51 req.extensions_mut().insert(session);
52 }
53 if let Some(expired_session) = expired_session {
54 req.extensions_mut().insert(expired_session);
55 }
56
57 Ok(next.run(req).await)
58}
59
60fn get_auth_header<'a, T>(headers: &'a HeaderMap, header_name: &HeaderName) -> Result<Option<T>, StatusCode>
61where
62 T: FromStr + 'a,
63 T::Err: std::fmt::Display,
64{
65 match headers.get(header_name) {
66 Some(h) => {
67 let s = h.to_str().map_err(|e| {
68 tracing::debug!(header = %header_name, error = %e, "invalid header value");
69 StatusCode::BAD_REQUEST
70 })?;
71 Ok(Some(s.parse().map_err(|e| {
72 tracing::debug!(header = %header_name, error = %e, "failed to parse header value");
73 StatusCode::BAD_REQUEST
74 })?))
75 }
76 None => Ok(None),
77 }
78}
79
80#[derive(Debug, thiserror::Error)]
81enum AuthenticationMethodParseError {
82 #[error("unknown authentication algorithm")]
83 UnknownAlgorithm,
84 #[error("invalid header format")]
85 InvalidHeaderFormat,
86}
87
88#[derive(Debug)]
89enum AuthenticationAlgorithm {
90 HmacSha256,
91}
92
93impl FromStr for AuthenticationAlgorithm {
94 type Err = AuthenticationMethodParseError;
95
96 fn from_str(s: &str) -> Result<Self, Self::Err> {
97 match s {
98 "HMAC-SHA256" => Ok(AuthenticationAlgorithm::HmacSha256),
99 _ => Err(AuthenticationMethodParseError::UnknownAlgorithm),
100 }
101 }
102}
103
104#[derive(Debug)]
105struct AuthenticationMethod {
106 pub algorithm: AuthenticationAlgorithm,
107 pub headers: Vec<HeaderName>,
108}
109
110impl FromStr for AuthenticationMethod {
111 type Err = AuthenticationMethodParseError;
112
113 fn from_str(s: &str) -> Result<Self, Self::Err> {
114 let parts: Vec<&str> = s.splitn(2, ';').collect();
115 if parts.len() != 2 {
116 return Err(AuthenticationMethodParseError::InvalidHeaderFormat);
117 }
118
119 let algorithm: AuthenticationAlgorithm = parts[0].parse()?;
120 let headers: Vec<HeaderName> = parts[1]
121 .split(',')
122 .map(|h| HeaderName::from_str(h.trim()).map_err(|_| AuthenticationMethodParseError::InvalidHeaderFormat))
123 .collect::<Result<_, _>>()?;
124
125 Ok(AuthenticationMethod { algorithm, headers })
126 }
127}
128
129#[derive(thiserror::Error, Debug)]
130enum NonceParseError {
131 #[error("failed to decode: {0}")]
132 Base64(#[from] base64::DecodeError),
133 #[error("invalid nonce length {0}, must be 32 bytes")]
134 InvalidLength(usize),
135}
136
137#[derive(Debug)]
138struct Nonce(Vec<u8>);
139
140impl FromStr for Nonce {
141 type Err = NonceParseError;
142
143 fn from_str(s: &str) -> Result<Self, Self::Err> {
144 let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
145 if bytes.len() != 32 {
146 return Err(NonceParseError::InvalidLength(bytes.len()));
147 }
148 Ok(Nonce(bytes))
149 }
150}
151
152#[derive(Debug)]
153struct AuthenticationHmac(Vec<u8>);
154
155impl FromStr for AuthenticationHmac {
156 type Err = base64::DecodeError;
157
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
160 Ok(AuthenticationHmac(bytes))
161 }
162}
163
164async fn get_and_update_active_session<G: core_traits::Global>(
165 global: &Arc<G>,
166 ip_info: &IpAddressInfo,
167 headers: &HeaderMap,
168) -> Result<(Option<UserSession>, Option<ExpiredSession>), StatusCode> {
169 let Some(session_token_id) = get_auth_header::<UserSessionTokenId>(headers, &TOKEN_ID_HEADER)? else {
170 return Ok((None, None));
171 };
172 let Some(timestamp) =
173 get_auth_header::<u64>(headers, &TIMESTAMP_HEADER)?.and_then(|t| chrono::DateTime::from_timestamp_millis(t as i64))
174 else {
175 return Ok((None, None));
176 };
177 let Some(nonce) = get_auth_header::<Nonce>(headers, &NONCE_HEADER)? else {
178 return Ok((None, None));
179 };
180
181 let Some(auth_method) = get_auth_header::<AuthenticationMethod>(headers, &AUTHENTICATION_METHOD_HEADER)? else {
182 return Ok((None, None));
183 };
184 let Some(auth_hmac) = get_auth_header::<AuthenticationHmac>(headers, &AUTHENTICATION_HMAC_HEADER)? else {
185 return Ok((None, None));
186 };
187
188 if (chrono::Utc::now() - timestamp).abs()
189 > chrono::TimeDelta::from_std(global.timeout_config().max_request_diff).expect("invalid config")
190 {
191 tracing::debug!(timestamp = %timestamp, "invalid request timestamp");
192 return Err(StatusCode::UNAUTHORIZED);
193 }
194
195 if !auth_method.headers.contains(&TOKEN_ID_HEADER)
196 || !auth_method.headers.contains(&TIMESTAMP_HEADER)
197 || !auth_method.headers.contains(&NONCE_HEADER)
198 {
199 tracing::debug!("missing required headers in authentication method");
200 return Err(StatusCode::BAD_REQUEST);
201 }
202
203 let mut db = global.db().await.map_err(|e| {
204 tracing::error!(error = %e, "failed to connect to database");
205 StatusCode::INTERNAL_SERVER_ERROR
206 })?;
207
208 let Some(session) = diesel::update(user_sessions::dsl::user_sessions)
209 .set((
210 user_sessions::dsl::last_ip.eq(ip_info.to_network()),
211 user_sessions::dsl::last_used_at.eq(chrono::Utc::now()),
212 ))
213 .filter(
214 user_sessions::dsl::token_id
215 .eq(session_token_id)
216 .and(user_sessions::dsl::token.is_not_null())
217 .and(user_sessions::dsl::expires_at.gt(chrono::Utc::now())),
218 )
219 .returning(UserSession::as_select())
220 .get_results::<UserSession>(&mut db)
221 .await
222 .map_err(|e| {
223 tracing::error!(error = %e, "failed to update user session");
224 StatusCode::INTERNAL_SERVER_ERROR
225 })?
226 .into_iter()
227 .next()
228 else {
229 tracing::debug!(token_id = %session_token_id, "no active session found");
230 return Err(StatusCode::UNAUTHORIZED);
231 };
232
233 let token = session.token.as_ref().expect("known to be not null due to filter");
234
235 match auth_method.algorithm {
237 AuthenticationAlgorithm::HmacSha256 => {
238 let mut mac = hmac::Hmac::<sha2::Sha256>::new_from_slice(token).map_err(|e| {
239 tracing::error!(error = %e, "failed to create HMAC instance");
240 StatusCode::INTERNAL_SERVER_ERROR
241 })?;
242
243 for header_name in &auth_method.headers {
244 if let Some(value) = headers.get(header_name) {
245 mac.update(value.as_bytes());
246 } else {
247 tracing::debug!(header = %header_name, "missing header");
248 return Err(StatusCode::BAD_REQUEST);
249 }
250 }
251
252 mac.verify_slice(&auth_hmac.0).map_err(|e| {
253 tracing::debug!(error = %e, "HMAC verification failed");
254 StatusCode::UNAUTHORIZED
255 })?;
256 }
257 }
258
259 let mut key = "nonces:".as_bytes().to_vec();
260 key.extend_from_slice(&nonce.0);
261 let value: Option<bool> = global
262 .redis()
263 .set(
264 key.as_slice(),
265 true,
266 Some(fred::types::Expiration::PX(
267 global.timeout_config().max_request_diff.as_millis() as i64,
268 )),
269 Some(fred::types::SetOptions::NX),
270 true,
271 )
272 .await
273 .map_err(|e| {
274 tracing::error!(error = %e, "failed to set nonce in redis");
275 StatusCode::INTERNAL_SERVER_ERROR
276 })?;
277
278 if value.is_some() {
279 tracing::debug!("replayed nonce detected");
280 return Err(StatusCode::UNAUTHORIZED);
281 }
282
283 if session.token_expires_at.is_some_and(|t| t <= chrono::Utc::now()) {
284 return Ok((None, Some(ExpiredSession(session))));
285 }
286
287 Ok((Some(session), None))
288}