Extract bearer token verification into router middleware
Move auth from per-handler to a `verify_bearer` middleware applied to all
routes nested under `/runs/{run_id}`. The middleware extracts the run_id
via axum path params and verifies the bearer token against the DB before
any handler runs, so new endpoints under that prefix are authenticated
automatically.
change
commit bbbf7bc956dfcca0abc4d8921e102be21a6205c7
author Claude <noreply@anthropic.com>
date
parent fe711120
diff --git a/quire-server/src/quire/web/api.rs b/quire-server/src/quire/web/api.rs
index babdb9a..2f43367 100644
--- a/quire-server/src/quire/web/api.rs
+++ b/quire-server/src/quire/web/api.rs
@@ -4,25 +4,30 @@
 //! header auth used by the web UI). Each token is minted when the run
 //! is created and scoped to that run's ID.
 
-use axum::extract::{Path as AxumPath, State};
+use std::collections::HashMap;
+
+use axum::extract::{FromRequestParts, Path as AxumPath, State};
 use axum::http::StatusCode;
+use axum::middleware::Next;
 use axum::response::{IntoResponse, Response, Result};
-use axum_extra::TypedHeader;
-use axum_extra::headers::Authorization;
-use axum_extra::headers::authorization::Bearer;
 
 use crate::Quire;
 
-/// Build the API router. Routes are not wrapped in web-UI auth
-/// middleware; each handler verifies its own bearer token.
+/// Build the API router. Routes under `/runs/{run_id}` are wrapped in
+/// [`verify_bearer`] middleware which authenticates the bearer token against
+/// the run's stored token before any handler runs.
 ///
 /// Intended to be mounted at `/api` via `Router::nest`.
 pub fn router(quire: Quire) -> axum::Router {
+    let run_routes = axum::Router::new()
+        .route("/secrets/{name}", axum::routing::get(get_secret))
+        .layer(axum::middleware::from_fn_with_state(
+            quire.clone(),
+            verify_bearer,
+        ));
+
     axum::Router::new()
-        .route(
-            "/runs/{run_id}/secrets/{name}",
-            axum::routing::get(get_secret),
-        )
+        .nest("/runs/{run_id}", run_routes)
         .with_state(quire)
 }
 
@@ -62,44 +67,79 @@ impl IntoResponse for ApiError {
     }
 }
 
-/// Verify the bearer token against the stored `auth_token` for `run_id`.
-/// Returns `Err(NotFound)` if the run doesn't exist, `Err(Unauthorized)` if
-/// the token doesn't match (including a null token for filesystem-mode runs).
-fn verify_token(db: &rusqlite::Connection, run_id: &str, token: &str) -> Result<(), ApiError> {
-    let stored: Option<String> = db
-        .query_row(
-            "SELECT auth_token FROM runs WHERE id = ?1",
-            rusqlite::params![run_id],
-            |row| row.get(0),
+/// Middleware that authenticates requests under `/runs/{run_id}` by verifying
+/// the `Authorization: Bearer <token>` header against `runs.auth_token` in the
+/// DB. Returns 401 if the header is absent or the token doesn't match, 404 if
+/// the run doesn't exist.
+async fn verify_bearer(
+    State(quire): State<Quire>,
+    req: axum::extract::Request,
+    next: Next,
+) -> Response {
+    let token = req
+        .headers()
+        .get(axum::http::header::AUTHORIZATION)
+        .and_then(|v| v.to_str().ok())
+        .and_then(|s| s.strip_prefix("Bearer "))
+        .map(|s| s.to_string());
+
+    let Some(token) = token else {
+        return StatusCode::UNAUTHORIZED.into_response();
+    };
+
+    // Extract the run_id path param set by the enclosing nest("/runs/{run_id}", ...).
+    let (mut parts, body) = req.into_parts();
+    let run_id =
+        <AxumPath<HashMap<String, String>> as FromRequestParts<()>>::from_request_parts(
+            &mut parts,
+            &(),
         )
-        .map_err(ApiError::from)?;
-    match stored {
-        Some(ref t) if t == token => Ok(()),
-        _ => Err(ApiError::Unauthorized),
+        .await
+        .ok()
+        .and_then(|mut p| p.0.remove("run_id"));
+    let req = axum::extract::Request::from_parts(parts, body);
+
+    let Some(run_id) = run_id else {
+        return StatusCode::NOT_FOUND.into_response();
+    };
+
+    let result = tokio::task::spawn_blocking(move || -> Result<(), ApiError> {
+        let db = quire
+            .db_pool()
+            .lock()
+            .map_err(|_| crate::Error::Io(std::io::Error::other("db mutex poisoned")))?;
+        let stored: Option<String> = db
+            .query_row(
+                "SELECT auth_token FROM runs WHERE id = ?1",
+                rusqlite::params![run_id],
+                |row| row.get(0),
+            )
+            .map_err(ApiError::from)?;
+        match stored {
+            Some(ref t) if t == &token => Ok(()),
+            _ => Err(ApiError::Unauthorized),
+        }
+    })
+    .await
+    .expect("blocking task panicked");
+
+    match result {
+        Ok(()) => next.run(req).await,
+        Err(e) => e.into_response(),
     }
 }
 
 /// `GET /api/runs/:run_id/secrets/:name`
 ///
 /// Returns the plain-text value of a named secret from the global config.
-/// Auth: `Authorization: Bearer <token>` matching `runs.auth_token`.
-/// Returns 404 if the run is unknown or the secret is not declared in config.
+/// Auth is handled by [`verify_bearer`] middleware before this handler runs.
+/// Returns 404 if the secret is not declared in config.
 async fn get_secret(
     State(quire): State<Quire>,
-    AxumPath((run_id, name)): AxumPath<(String, String)>,
-    bearer: Option<TypedHeader<Authorization<Bearer>>>,
+    AxumPath(params): AxumPath<HashMap<String, String>>,
 ) -> Result<axum::Json<serde_json::Value>, ApiError> {
-    let Some(TypedHeader(Authorization(bearer))) = bearer else {
-        return Err(ApiError::Unauthorized);
-    };
-    let token = bearer.token().to_string();
-
+    let name = params.get("name").cloned().unwrap_or_default();
     let value = tokio::task::spawn_blocking(move || -> std::result::Result<String, ApiError> {
-        let db = quire
-            .db_pool()
-            .lock()
-            .map_err(|_| crate::Error::Io(std::io::Error::other("db mutex poisoned")))?;
-        verify_token(&db, &run_id, &token)?;
         let config = quire.global_config()?;
         match config.secrets.get(&name) {
             Some(s) => Ok(s.reveal()?.to_string()),