
Middleware & State

Every non-trivial web application needs shared state and cross-cutting concerns like logging, authentication, and CORS. Axum handles both through Tower middleware and a built-in state extraction system. Understanding these two mechanisms is essential for building production services.

Application State with State<T>

Axum lets you attach shared state to the router and extract it in handlers:

use axum::{Router, routing::get, extract::State};
use std::sync::Arc;

struct AppState {
    db_pool: sqlx::PgPool,
    config: AppConfig,
}

struct AppConfig {
    jwt_secret: String,
    max_upload_size: usize,
}

#[tokio::main]
async fn main() {
    let pool = sqlx::PgPool::connect("postgres://localhost/mydb")
        .await
        .unwrap();

    let state = Arc::new(AppState {
        db_pool: pool,
        config: AppConfig {
            jwt_secret: "secret".to_string(),
            max_upload_size: 10_485_760,
        },
    });

    let app = Router::new()
        .route("/users", get(list_users))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .unwrap();

    axum::serve(listener, app).await.unwrap();
}

async fn list_users(State(state): State<Arc<AppState>>) -> String {
    // Access state.db_pool, state.config, etc.
    format!("Max upload: {} bytes", state.config.max_upload_size)
}

State is an extractor like Path or Json. Axum clones the state for each request that extracts it, which is why the state is usually wrapped in Arc: cloning an Arc only increments a reference count. The state type must implement Clone.
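The std-only sketch below shows why the Arc matters. It needs no axum and mimics the per-request clone that the State extractor performs. AppState, its config_blob field and simulate_requests are hypothetical stand-ins. The point is that every clone of an Arc<AppState> shares one allocation.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a large application state.
pub struct AppState {
    pub config_blob: Vec<u8>,
}

/// Simulates what the State extractor does per request: clone the shared
/// value. For Arc<AppState> that is a reference-count bump, not a deep copy.
pub fn simulate_requests(state: &Arc<AppState>, requests: usize) -> Vec<Arc<AppState>> {
    (0..requests).map(|_| Arc::clone(state)).collect()
}
```

Suppose AppState derived Clone and were stored without Arc. Then each simulated request would copy the whole buffer instead of sharing it.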

Why Arc for State

If your state is just a database pool, you might not need Arc because PgPool is already internally reference-counted. But most applications have multiple fields:

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Clone)]
struct AppState {
    db: sqlx::PgPool,
    cache: Arc<RwLock<HashMap<String, String>>>,
}

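When you derive Clone on a struct whose fields are Arcs (or, like PgPool, reference-counted internally), the clone is shallow: each request handler gets its own handle to the same underlying data. Putting the mutable part behind Arc<RwLock<...>> is a common pattern for sharing mutable state safely across async handlers, since the lock admits many concurrent readers or one writer at a time.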
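Here is a std-only sketch of the same pattern that runs without tokio or sqlx. It makes three substitutions: std::sync::RwLock for tokio's lock, threads for async handlers, and no database field. AppState and run_requests are hypothetical names. Each thread gets its own clone of the state, yet every write lands in the one shared map.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

/// Std-only analogue of the AppState above: the cache sits behind
/// Arc<RwLock<..>>, so a derived Clone copies the Arc pointer, not the map.
#[derive(Clone, Default)]
pub struct AppState {
    pub cache: Arc<RwLock<HashMap<String, String>>>,
}

/// Spawns one thread per simulated request. Each thread clones the state and
/// writes one entry; returns how many entries are visible afterwards.
pub fn run_requests(state: &AppState, n: usize) -> usize {
    let workers: Vec<_> = (0..n)
        .map(|i| {
            let state = state.clone(); // shallow: bumps the Arc count
            thread::spawn(move || {
                state
                    .cache
                    .write()
                    .expect("lock poisoned")
                    .insert(format!("key{i}"), format!("value{i}"));
            })
        })
        .collect();
    for worker in workers {
        worker.join().expect("worker panicked");
    }
    state.cache.read().expect("lock poisoned").len()
}
```

In real async handlers, stick with tokio's RwLock as in the article's code whenever a guard must live across an .await. A std lock guard should not be held across await points.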

Tower Middleware

Axum is built on Tower, which means middleware is composable and reusable. Tower middleware wraps your service, intercepting requests and responses.

Logging with TraceLayer

TraceLayer is the most common middleware. It adds tracing to every request:

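use axum::{routing::get, Router};
use tower_http::trace::TraceLayer;

#[tokio::main]
async fn main() {
    // TraceLayer's default spans and events are emitted at DEBUG level,
    // so raise the subscriber's maximum level to see them.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();

    let app = Router::new()
        .route("/", get(root))
        .layer(TraceLayer::new_for_http());

    // Every request now logs method, URI, status, and latency.
    // Bind a listener and call axum::serve(listener, app) as shown earlier.
}

async fn root() -> &'static str {
    "Hello, world!"
}

Sample output (the exact fields and format depend on the tower-http version):

2024-01-15T10:30:00.000000Z DEBUG request{method=GET uri=/ version=HTTP/1.1}: tower_http::trace::on_response: finished processing request latency=2 ms status=200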
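Conceptually, a tracing layer sees the request on the way in and the response on the way out, and times everything in between. The std-only sketch below models just that. call_traced, Request and Response are hypothetical stand-ins, not tower-http's API, and the log format is made up for the example.

```rust
use std::time::{Duration, Instant};

/// Hypothetical minimal request and response types for the sketch.
pub struct Request {
    pub method: &'static str,
    pub uri: &'static str,
}

pub struct Response {
    pub status: u16,
}

/// Wraps one handler call the way a tracing layer wraps a service: it sees
/// the request first, lets the handler run, then records status and latency.
pub fn call_traced<H>(req: &Request, handler: H, log: &mut Vec<String>) -> Response
where
    H: Fn(&Request) -> Response,
{
    let start = Instant::now();
    let res = handler(req);
    let latency: Duration = start.elapsed();
    log.push(format!(
        "method={} uri={} status={} latency={}ms",
        req.method,
        req.uri,
        res.status,
        latency.as_millis()
    ));
    res
}
```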

CORS

Cross-Origin Resource Sharing headers:

use axum::{http::Method, routing::get, Router};
use tower_http::cors::{Any, CorsLayer};

let cors = CorsLayer::new()
    .allow_origin(Any)
    .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
    .allow_headers(Any);

let app = Router::new()
    .route("/api/data", get(get_data))
    .layer(cors);

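In production, replace Any in allow_origin with your frontend's actual origin, for example a HeaderValue parsed from "https://app.example.com". Note that tower-http panics when the layer is applied if a wildcard origin is combined with allow_credentials(true), because browsers reject that combination.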
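The difference between Any and an explicit origin comes down to what ends up in the Access-Control-Allow-Origin response header. This std-only sketch models that decision. OriginPolicy and allow_origin_header are hypothetical names, not tower-http types; tower-http's own list form is AllowOrigin::list.

```rust
/// Hypothetical model of two origin policies.
pub enum OriginPolicy<'a> {
    /// Like allow_origin(Any): answer every origin with "*".
    Any,
    /// Only echo back origins that appear in the allowlist.
    List(&'a [&'a str]),
}

/// Value for the Access-Control-Allow-Origin header, or None when the header
/// should be omitted (the browser then blocks the cross-origin response).
pub fn allow_origin_header(policy: &OriginPolicy<'_>, request_origin: Option<&str>) -> Option<String> {
    match policy {
        OriginPolicy::Any => Some("*".to_string()),
        OriginPolicy::List(allowed) => {
            let origin = request_origin?;
            if allowed.iter().any(|a| *a == origin) {
                Some(origin.to_string())
            } else {
                None
            }
        }
    }
}
```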

Rate Limiting with Tower

Tower provides a rate limiting layer:

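use axum::{error_handling::HandleErrorLayer, http::StatusCode, BoxError};
use std::time::Duration;
use tower::{buffer::BufferLayer, limit::RateLimitLayer, ServiceBuilder};
// RateLimit is not Clone (axum requires it), so it sits behind a Buffer.
let app = Router::new()
    .route("/api/expensive", get(expensive_operation))
    .layer(ServiceBuilder::new()
        .layer(HandleErrorLayer::new(|err: BoxError| async move {
            (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) // Buffer errors
        }))
        .layer(BufferLayer::new(1024))
        .layer(RateLimitLayer::new(100, Duration::from_secs(60))));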

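This caps the wrapped routes at 100 requests per 60-second window, shared by every client. Tower's rate limiter applies backpressure rather than rejecting requests. Once the window's budget is spent, further requests wait until it resets instead of receiving a 429. For per-user limits or 429 responses, you need a custom middleware or a crate like tower-governor.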
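Per-user limiting needs state keyed by user. Below is a minimal std-only sketch of a fixed-window counter keyed by a string such as a user id or client IP. FixedWindowLimiter is a hypothetical type. Unlike tower's RateLimit, it rejects excess requests, so a middleware built on it could return 429 instead of making requests wait. tower-governor uses a different algorithm (GCRA, via the governor crate).

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical per-key fixed-window limiter: each key may make `limit`
/// requests per `window`.
pub struct FixedWindowLimiter {
    limit: u32,
    window: Duration,
    // key -> (start of the current window, requests counted in it)
    counters: HashMap<String, (Instant, u32)>,
}

impl FixedWindowLimiter {
    pub fn new(limit: u32, window: Duration) -> Self {
        Self { limit, window, counters: HashMap::new() }
    }

    /// Returns true if the request is allowed, false if it should get a 429.
    /// `now` is passed in so the logic is deterministic and testable.
    pub fn check(&mut self, key: &str, now: Instant) -> bool {
        let entry = self.counters.entry(key.to_string()).or_insert((now, 0));
        if now.duration_since(entry.0) >= self.window {
            *entry = (now, 0); // window expired: start a fresh one
        }
        if entry.1 < self.limit {
            entry.1 += 1;
            true
        } else {
            false
        }
    }
}
```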

Layer Composition

Middleware is applied with .layer(). Each call wraps everything added before it, so the last layer added is the outermost and sees the request first:

use tower_http::trace::TraceLayer;
use tower_http::cors::CorsLayer;
use tower_http::compression::CompressionLayer;

let app = Router::new()
    .route("/", get(root))
    .layer(CompressionLayer::new())  // innermost: runs last
    .layer(CorsLayer::new())          // middle
    .layer(TraceLayer::new_for_http()); // outermost: runs first

Request flow: TraceLayer -> CorsLayer -> CompressionLayer -> Handler. Response flow reverses. The trace layer wraps everything, so it captures the full request lifecycle including time spent in other middleware.
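The ordering rule can be checked without axum. In this std-only sketch, a hypothetical layer function wraps an inner service and records when it sees the request and the response. Wrapping in the same order as the .layer() calls above reproduces the request and response flow just described.

```rust
/// A service in this sketch: takes a request path, appends what ran to
/// `trace`, and returns a status code.
pub type Service = Box<dyn Fn(&str, &mut Vec<String>) -> u16>;

/// Models one middleware layer: the returned service records the request,
/// calls the inner service, then records the response.
pub fn layer(name: &'static str, inner: Service) -> Service {
    Box::new(move |req: &str, trace: &mut Vec<String>| -> u16 {
        trace.push(format!("{name} request"));
        let status = inner(req, &mut *trace);
        trace.push(format!("{name} response"));
        status
    })
}

/// Mirrors .layer(Compression).layer(Cors).layer(Trace): each call wraps
/// everything added before it.
pub fn build_router() -> Service {
    let handler: Service = Box::new(|_req: &str, trace: &mut Vec<String>| -> u16 {
        trace.push("handler".to_string());
        200
    });
    let svc = layer("compression", handler);
    let svc = layer("cors", svc);
    layer("trace", svc)
}
```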

You can also use ServiceBuilder for a more readable stack:

use tower::ServiceBuilder;

let middleware_stack = ServiceBuilder::new()
    .layer(TraceLayer::new_for_http())
    .layer(CorsLayer::new())
    .layer(CompressionLayer::new());

let app = Router::new()
    .route("/", get(root))
    .layer(middleware_stack);

With ServiceBuilder, layers execute in the order listed — top to bottom for requests.
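A std-only sketch of why ServiceBuilder reads top to bottom: a hypothetical Builder records layers in the order listed, then applies them in reverse so that the first one listed ends up outermost. It mirrors the idea behind tower's ServiceBuilder, not its actual types.

```rust
/// A service in this sketch appends what ran to `trace` and returns a status.
pub type Service = Box<dyn Fn(&mut Vec<String>) -> u16>;
/// A layer turns one service into a wrapped service.
pub type Layer = Box<dyn Fn(Service) -> Service>;

/// Hypothetical builder that records layers in the order they are listed.
#[derive(Default)]
pub struct Builder {
    layers: Vec<Layer>,
}

impl Builder {
    pub fn layer(mut self, layer: Layer) -> Self {
        self.layers.push(layer);
        self
    }

    /// Applies the last-listed layer first, so the first-listed layer ends up
    /// outermost and sees each request first.
    pub fn service(self, inner: Service) -> Service {
        self.layers.into_iter().rev().fold(inner, |svc, wrap| wrap(svc))
    }
}

/// A named layer that records the request on the way in and the response
/// on the way out.
pub fn named(name: &'static str) -> Layer {
    Box::new(move |inner: Service| -> Service {
        Box::new(move |trace: &mut Vec<String>| -> u16 {
            trace.push(format!("{name} request"));
            let status = inner(&mut *trace);
            trace.push(format!("{name} response"));
            status
        })
    })
}
```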

Writing Custom Middleware

For simple cases, use axum::middleware::from_fn:

use axum::{
    http::{Request, StatusCode},
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};

async fn auth_middleware(
    request: Request<axum::body::Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let auth_header = request
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok());

    match auth_header {
        Some(token) if token.starts_with("Bearer ") => {
            Ok(next.run(request).await)
        }
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

let app = Router::new()
    .route("/protected", get(protected_handler))
    .layer(middleware::from_fn(auth_middleware));

The middleware receives the request, can inspect or modify it, and calls next.run(request) to pass control to the next layer. Returning an error short-circuits the chain.
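Stripped of HTTP types, from_fn middleware is a function that receives the request plus a handle to the rest of the chain. This std-only sketch uses hypothetical Request, Response and Next types to show both paths: calling next.run to continue, or returning a 401 without ever reaching the handler. Axum's real Next is async.

```rust
/// Hypothetical request and response types for a std-only model of from_fn.
pub struct Request {
    pub authorization: Option<String>,
}

#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

/// Models Next (synchronously): `run` hands the request to the rest of the chain.
pub struct Next<'a> {
    inner: &'a dyn Fn(Request) -> Response,
}

impl Next<'_> {
    pub fn run(self, req: Request) -> Response {
        (self.inner)(req)
    }
}

/// Same shape as auth_middleware above: check the header, then either
/// short-circuit with 401 or pass the request on.
pub fn auth_middleware(req: Request, next: Next<'_>) -> Response {
    let has_bearer = req
        .authorization
        .as_deref()
        .is_some_and(|h| h.starts_with("Bearer "));
    if has_bearer {
        next.run(req)
    } else {
        Response { status: 401, body: String::new() }
    }
}

/// Wires the middleware in front of a handler, like .layer(from_fn(..)).
pub fn handle(req: Request) -> Response {
    let handler = |_req: Request| Response { status: 200, body: "protected".to_string() };
    auth_middleware(req, Next { inner: &handler })
}
```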

Middleware with State

Custom middleware often needs access to application state:

use axum::extract::State;
use std::sync::Arc;

async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let auth_header = request
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok());

    match auth_header {
        Some(token) if validate_token(token, &state.config.jwt_secret) => {
            Ok(next.run(request).await)
        }
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

let app = Router::new()
    .route("/protected", get(protected_handler))
    .layer(middleware::from_fn_with_state(
        state.clone(),
        auth_middleware,
    ))
    .with_state(state);

Use from_fn_with_state to inject state into the middleware function. Extractors such as State come first; the request must be the second-to-last parameter and Next the last.
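What from_fn_with_state adds is a state value captured once when the layer is built and cloned into every call. The std-only sketch below models that with a hypothetical StatefulMiddleware wrapper. Its auth function compares the bearer token to the secret as a stand-in for validate_token. Real code would verify a JWT signature rather than compare strings.

```rust
use std::sync::Arc;

/// Hypothetical stand-ins for the article's AppState, request and response.
pub struct AppState {
    pub jwt_secret: String,
}

pub struct Request {
    pub authorization: Option<String>,
}

#[derive(Debug, PartialEq)]
pub enum Response {
    Success(String),
    Unauthorized,
}

/// Models from_fn_with_state: the layer owns one Arc<AppState> and hands a
/// cheap clone of it to the middleware function on every request.
pub struct StatefulMiddleware<F> {
    state: Arc<AppState>,
    func: F,
}

impl<F> StatefulMiddleware<F>
where
    F: Fn(Arc<AppState>, &Request, &dyn Fn(&Request) -> Response) -> Response,
{
    pub fn new(state: Arc<AppState>, func: F) -> Self {
        Self { state, func }
    }

    pub fn call(&self, req: &Request, next: &dyn Fn(&Request) -> Response) -> Response {
        (self.func)(Arc::clone(&self.state), req, next)
    }
}

/// Hypothetical stand-in for validate_token: a plain string comparison.
pub fn auth(state: Arc<AppState>, req: &Request, next: &dyn Fn(&Request) -> Response) -> Response {
    let token = req
        .authorization
        .as_deref()
        .and_then(|h| h.strip_prefix("Bearer "));
    match token {
        Some(t) if t == state.jwt_secret => next(req),
        _ => Response::Unauthorized,
    }
}
```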

Request Extensions for Scoped Data

Middleware can attach data to the request for downstream handlers using extensions:

use axum::extract::Extension;

#[derive(Clone)]
struct CurrentUser {
    id: u64,
    username: String,
}

async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    mut request: Request<axum::body::Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let token = request
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    let user = decode_token(token, &state.config.jwt_secret)
        .map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Attach the user to the request
    request.extensions_mut().insert(CurrentUser {
        id: user.id,
        username: user.username,
    });

    Ok(next.run(request).await)
}

// Handler extracts the user set by middleware
async fn profile(Extension(user): Extension<CurrentUser>) -> String {
    format!("Hello, {}", user.username)
}

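This is the idiomatic way to pass data from middleware to handlers: the middleware authenticates and decodes the user, and the handler just extracts it. Suppose a handler that takes Extension<CurrentUser> is reachable without the middleware. Extraction then fails at runtime with a 500 Internal Server Error, so make sure the layer covers every such route.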
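Request extensions are essentially a map keyed by type. This std-only sketch is modeled loosely on the http crate's Extensions; Extensions and extract are hypothetical names. It shows why each type can be stored only once, and why extracting a type that nobody inserted fails.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Std-only model of request extensions: one value per type.
#[derive(Default)]
pub struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    /// Stores a value, replacing any previous value of the same type.
    pub fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Looks a value up by its type.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.map.get(&TypeId::of::<T>()).and_then(|b| b.downcast_ref::<T>())
    }
}

#[derive(Clone, Debug, PartialEq)]
pub struct CurrentUser {
    pub id: u64,
    pub username: String,
}

/// Models the Extension<T> extractor: clone the value out, or fail with a
/// 500 status when no middleware inserted it.
pub fn extract<T: Clone + 'static>(ext: &Extensions) -> Result<T, u16> {
    ext.get::<T>().cloned().ok_or(500)
}
```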

Applying Middleware to Specific Routes

Not all middleware should apply globally. Use route_layer for route-specific middleware:

let public_routes = Router::new()
    .route("/login", post(login))
    .route("/health", get(health));

let protected_routes = Router::new()
    .route("/profile", get(profile))
    .route("/settings", get(settings).put(update_settings))
    .route_layer(middleware::from_fn_with_state(
        state.clone(),
        auth_middleware,
    ));

let app = Router::new()
    .merge(public_routes)
    .merge(protected_routes)
    .layer(TraceLayer::new_for_http()) // applies to all routes
    .with_state(state);

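Tracing applies to every request. Authentication only applies to the protected routes. Prefer route_layer over layer for the auth middleware: it runs only for requests that match one of the router's routes, so an unknown path still returns 404 rather than 401.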
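The effect of this layout fits in a small std-only model. A hypothetical App keeps a route table marking which paths sit behind the auth middleware, and records a trace entry for every request, including unmatched ones.

```rust
use std::collections::HashMap;

/// Hypothetical model: each path maps to whether the auth middleware wraps it.
pub struct App {
    routes: HashMap<&'static str, bool>,
    pub trace_log: Vec<String>,
}

impl App {
    pub fn new() -> Self {
        let mut routes = HashMap::new();
        // public_routes: no auth layer
        routes.insert("/login", false);
        routes.insert("/health", false);
        // protected_routes: auth applied with route_layer
        routes.insert("/profile", true);
        routes.insert("/settings", true);
        Self { routes, trace_log: Vec::new() }
    }

    /// Tracing runs for every request; auth runs only for matched protected
    /// routes, so unknown paths get 404 rather than 401.
    pub fn handle(&mut self, path: &str, authorization: Option<&str>) -> u16 {
        let has_bearer = authorization.is_some_and(|h| h.starts_with("Bearer "));
        let status = match self.routes.get(path).copied() {
            None => 404,
            Some(true) if !has_bearer => 401,
            Some(_) => 200,
        };
        self.trace_log.push(format!("{path} -> {status}"));
        status
    }
}
```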

Common Pitfalls

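  • Forgetting Arc on state. State<T> clones on every request. Without Arc, either around the whole struct or around each shared field, you clone the entire struct. With large state this is expensive. Worse, data stored directly in the struct is copied per request, so changes made in one handler are never seen by another.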
  • Layer ordering confusion. With .layer(), the last one added is the outermost. With ServiceBuilder, the first one listed is the outermost. Pick one style and be consistent.
  • Blocking in middleware. Middleware runs on the async runtime. Do not perform blocking I/O or heavy computation there; move that work to tokio::task::spawn_blocking.
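  • Not propagating state to nested routers. A nested or merged router has two options. It can use the same state type as its parent, so that one .with_state() on the outer router covers it, or it can be given its own state with .with_state() before nesting. A mismatch causes compile errors that point to trait bounds or Router<S> type mismatches, not the actual problem.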
  • Using Extension when State would suffice. State is for data that lives for the application lifetime. Extension is for request-scoped data set by middleware. Mixing them up works but confuses intent.

Key Takeaways

  • Use State<Arc<T>> to share application-wide data like database pools and configuration.
  • Tower middleware composes via .layer() — tracing, CORS, compression, and auth are all just layers.
  • Write custom middleware with middleware::from_fn for simple cases and from_fn_with_state when you need application state.
  • Request extensions let middleware pass data to handlers — ideal for authentication.
  • Apply middleware selectively by merging routers with different layer stacks.
  • ServiceBuilder offers a more readable way to compose multiple middleware layers.