Implementing JWT Authentication in Rust using Axum


This is part 4 of this series, where I implement JWT authentication using the Axum framework in Rust.

The previous part is here.



Source Code

The repository is here.



Quick Recap

In the previous chapter we added database and JWT configurations. We created an application context, which is initialised once and shared throughout the application.



Introduction

In this chapter we will see how to create middlewares in Axum using Tower. The middlewares here will be used for authentication purposes. We will use the argon2 crate added in the previous section to hash passwords; the jsonwebtoken crate will provide methods to sign and verify tokens.



Dependencies:

Add the following dependencies.

cargo add time -F serde
cargo add axum-extra -F "cookie, typed-header, typed-routing"
cargo add futures-util
cargo add tower -F "tracing, tokio"




Sharing State with Handlers.

Add the AppContext we created in the previous chapters to our router.

// src/app.rs
        let ctx = Arc::new(AppContext::try_from(&config)?);

        let router = Router::new()
            .route("/hello", get(|| async { "Hello from axum!" }))
// We will share the state with handlers, i.e.
// .route("/auth", auth::routers(&ctx))
            .layer(
                TraceLayer::new_for_http()
                    .make_span_with(middlewares::make_span_with)
                    .on_request(middlewares::on_request)
                    .on_response(middlewares::on_response)
                    .on_failure(middlewares::on_failure),
            );
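
The context is built once, and every layer or handler that needs it receives a cheap clone of the Arc pointer rather than a copy of the data. Here is a minimal std-only sketch of that idea. The `AppContext` field and the `Component` type are placeholders invented for illustration, not the real types from the previous part.

```rust
use std::sync::Arc;

/// Placeholder stand-in for the real `AppContext`; the field is hypothetical.
#[derive(Debug)]
pub struct AppContext {
    pub jwt_exp_seconds: i64,
}

/// Hypothetical component that keeps its own handle to the shared context,
/// the same way `AuthLayer::new(&ctx)` does later in this part.
pub struct Component {
    ctx: Arc<AppContext>,
}

impl Component {
    pub fn new(ctx: &Arc<AppContext>) -> Self {
        // Cloning an `Arc` only bumps a reference count; the context is not copied.
        Self { ctx: Arc::clone(ctx) }
    }

    pub fn exp(&self) -> i64 {
        self.ctx.jwt_exp_seconds
    }
}
```

Every `Component` created from the same `ctx` reads the same underlying value, which is why the constructors in this part take `&Arc<AppContext>` and clone it.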





Signing & Verifying JWTs.

Let’s create functions and data structures that will encode and decode tokens. Inside the model module, create a token.rs file and add the following contents.

use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// The token string deserialises to this struct
/// The `sub` field will be the user's pid
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TokenClaims {
    pub sub: String,
    pub id: String,
    pub exp: i64,
    pub iat: i64,
    pub nbf: i64,
}

/// This struct will let us store our token in Redis
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TokenDetails {
    pub token: Option<String>,
    pub token_id: Uuid,
    pub user_pid: Uuid,
    pub expires_in: Option<i64>,
}




Understanding Token Details

The TokenDetails struct serves as a bridge between our JWT implementation and Redis storage. For refresh tokens, which need to persist across sessions, we can store the TokenDetails in Redis using the token_id as the key. This allows us to:

  • Track Active Sessions: Store long-lived refresh tokens to track user sessions.
  • Enable Token Revocation: Delete a token from Redis to immediately invalidate it, even before expiration.
  • Link Tokens to Users: Use the user_pid to find all active tokens for a specific user.
  • Automatic Cleanup: expires_in lets us implement Redis TTL (time-to-live).
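
The four points above can be sketched with an in-memory map standing in for Redis. This is only an illustration under stated assumptions: `StoredToken`, `TokenStore` and `ttl_seconds` are hypothetical names, ids are plain strings instead of Uuid so that no external crates are needed, and expiry is checked by hand where Redis would apply a TTL. Note that generate_token stores an absolute timestamp in expires_in, so a TTL has to be derived by subtracting the current time.

```rust
use std::collections::HashMap;

/// Simplified stand-in for `TokenDetails` (hypothetical, std-only).
#[derive(Debug, Clone, PartialEq)]
pub struct StoredToken {
    pub token_id: String,
    pub user_pid: String,
    /// Absolute expiry as a Unix timestamp in seconds.
    pub expires_at: i64,
}

/// In-memory stand-in for Redis, keyed as `refresh_token:{token_id}`.
#[derive(Default)]
pub struct TokenStore {
    entries: HashMap<String, StoredToken>,
}

impl TokenStore {
    fn key(token_id: &str) -> String {
        format!("refresh_token:{token_id}")
    }

    /// Track an active session.
    pub fn insert(&mut self, token: StoredToken) {
        self.entries.insert(Self::key(&token.token_id), token);
    }

    /// Returns the token only if it is present and not yet expired at `now`.
    pub fn get(&self, token_id: &str, now: i64) -> Option<&StoredToken> {
        self.entries
            .get(&Self::key(token_id))
            .filter(|t| t.expires_at > now)
    }

    /// Revocation: deleting the entry invalidates the token immediately.
    pub fn revoke(&mut self, token_id: &str) -> bool {
        self.entries.remove(&Self::key(token_id)).is_some()
    }

    /// All stored tokens that belong to one user.
    pub fn tokens_for_user(&self, user_pid: &str) -> Vec<&StoredToken> {
        self.entries
            .values()
            .filter(|t| t.user_pid == user_pid)
            .collect()
    }
}

/// Seconds left until expiry, the value a Redis TTL would need.
/// Returns `None` when the token has already expired.
pub fn ttl_seconds(expires_at: i64, now: i64) -> Option<u64> {
    u64::try_from(expires_at - now).ok().filter(|ttl| *ttl > 0)
}
```

With a real Redis connection the expiry check in `get` would be unnecessary, because an expired key simply disappears.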

In our context.rs file, add this implementation to the JwtContext.

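use jsonwebtoken::{Algorithm, Header, Validation};
use uuid::Uuid;

use crate::models::token::{TokenClaims, TokenDetails};

impl JwtContext {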
    pub fn generate_token(&self, sub: Uuid) -> Result<TokenDetails, Report> {
        let now = chrono::Utc::now();

        let mut token_details = TokenDetails {
            user_pid: sub,
            token_id: Uuid::new_v4(),
            expires_in: Some((now + chrono::Duration::seconds(self.exp)).timestamp()),
            token: None,
        };

        let claims = TokenClaims {
            sub: token_details.user_pid.to_string(),
            id: token_details.token_id.to_string(),
            exp: token_details.expires_in.ok_or(crate::Error::TokenError)?,
            iat: now.timestamp(),
            nbf: now.timestamp(),
        };

        let header = Header::new(Algorithm::RS256); // import from jsonwebtoken

        let token = jsonwebtoken::encode(&header, &claims, &self.encoding_key)?;

        token_details.token = Some(token);

        Ok(token_details)
    }

    pub fn verify_token(&self, token: &str) -> Result<TokenDetails, Report> {
        let validation = Validation::new(Algorithm::RS256);

        let token_data =
            jsonwebtoken::decode::<TokenClaims>(token, &self.decoding_key, &validation)?;

        let user_pid = Uuid::parse_str(&token_data.claims.sub)?; // import from uuid
        let token_id = Uuid::parse_str(&token_data.claims.id)?;

        Ok(TokenDetails {
            token: None,
            token_id,
            user_pid,
            expires_in: None,
        })
    }
}

Change the type of the exp field on JwtContext to i64, because chrono::Duration::seconds takes an i64.



Understanding Generating & Verifying Tokens.

The generate_token method creates a new JWT by initialising token metadata (TokenDetails) and building JWT claims (TokenClaims). You can read more about claims here. We then sign the claims with the encoding key (the private key, since the header selects RS256). The resulting token is a string of three base64url-encoded segments (header, payload and signature) separated by dots.

The verify_token method is simpler: it uses the decoding key (the public key) to check the signature, then deserialises the payload into the TokenClaims struct. The sub and id claims are strings, so we parse them back into Uuid values for user_pid and token_id.
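
To make the token layout and the time-based claims more concrete, here is a small std-only sketch. It is not how the jsonwebtoken crate is implemented and it does not verify signatures; `split_jwt`, `check_times` and the `leeway` parameter are hypothetical helpers that only illustrate the three-segment format and the comparisons behind exp and nbf.

```rust
/// The three dot-separated segments of a compact JWT.
#[derive(Debug, PartialEq)]
pub struct JwtParts<'a> {
    pub header: &'a str,
    pub payload: &'a str,
    pub signature: &'a str,
}

/// Splits a token into its segments, rejecting anything that does not have
/// exactly three non-empty parts. No decoding or signature check happens here.
pub fn split_jwt(token: &str) -> Option<JwtParts<'_>> {
    let mut it = token.split('.');
    let (header, payload, signature) = (it.next()?, it.next()?, it.next()?);
    if it.next().is_some() || header.is_empty() || payload.is_empty() || signature.is_empty() {
        return None;
    }
    Some(JwtParts { header, payload, signature })
}

/// Outcome of comparing the time claims against the current time.
#[derive(Debug, PartialEq)]
pub enum TimeCheck {
    Valid,
    Expired,
    NotYetValid,
}

/// `exp`, `nbf` and `now` are Unix timestamps in seconds.
/// `leeway` tolerates small clock differences between machines.
pub fn check_times(exp: i64, nbf: i64, now: i64, leeway: i64) -> TimeCheck {
    if now - leeway >= exp {
        TimeCheck::Expired
    } else if now + leeway < nbf {
        TimeCheck::NotYetValid
    } else {
        TimeCheck::Valid
    }
}
```

In generate_token both iat and nbf are set to the current time, so a freshly issued token is usable immediately and stops being valid once exp has passed.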



Middlewares

Middlewares are functions which usually run before a request is processed by a handler. Therefore they are perfect for checking if a user is authenticated and authorised, etc.

Inside the middlewares module create auth.rs and refresh.rs files and add the following contents:

auth.rs

//! This module contains middleware code to check if a user is authenticated.
//! It uses `tower::Service` and `tower::Layer` to create Request middleware.
use std::{
    convert::Infallible,
    sync::Arc,
    task::{Context, Poll},
};

use axum::{
    RequestPartsExt,
    body::Body,
    http::{Request, Response},
    response::IntoResponse,
};
use axum_extra::{
    TypedHeader,
    headers::{Authorization, Cookie, authorization::Bearer},
    typed_header::TypedHeaderRejectionReason,
};
use futures_util::future::BoxFuture;
use tower::{Layer, Service};

use crate::{context::AppContext, middlewares::AuthError};

#[derive(Clone)]
pub struct AuthLayer {
    ctx: Arc<AppContext>,
}

impl AuthLayer {
    pub fn new(ctx: &Arc<AppContext>) -> Self {
        Self { ctx: ctx.clone() }
    }
}

impl<S> Layer<S> for AuthLayer {
    type Service = AuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Self::Service {
            inner,
            ctx: self.ctx.clone(),
        }
    }
}

#[derive(Clone)]
pub struct AuthService<S> {
    inner: S,
    ctx: Arc<AppContext>,
}

impl<S, B> Service<Request<B>> for AuthService<S>
where
    S: Service<Request<B>, Response = Response<Body>, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let ctx = self.ctx.clone();
        let clone = self.inner.clone();

        // Take the service that is ready
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            let (mut parts, body) = req.into_parts();

            let access_token = match parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
                Ok(header) => Some(header.token().to_string()),
                Err(err) => {
                    // Access Token not in authorisation header; so check cookies
                    if matches!(err.reason(), TypedHeaderRejectionReason::Missing) {
                        parts.extract::<TypedHeader<Cookie>>().await.ok().and_then(
                            |TypedHeader(cookies)| {
                                cookies.get("access_token").map(ToString::to_string)
                            },
                        )
                    } else {
                        // We wrap the return value in Ok even though it represents a failure
                        // because middlewares in Axum cannot return errors, i.e. `Error =
                        // Infallible`
                        return Ok::<Response<Body>, Self::Error>(
                            AuthError::InvalidToken.into_response(),
                        );
                    }
                }
            };

            let Some(access_token) = access_token else {
                return Ok(AuthError::MissingCredentials.into_response());
            };

            // verify the access token
            let token_details = match ctx.auth.access.verify_token(&access_token) {
                Ok(details) => details,
                Err(err) => return Ok(err.into_response()),
            };

            // Reconstruct the Request and insert the token details into it.

            let mut req = Request::from_parts(parts, body);
            req.extensions_mut().insert(token_details);

            inner.call(req).await
        })
    }
}
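
The lookup order used in call above (Authorization header first, then the access_token cookie) can be isolated as a plain function. The following is a std-only sketch that works on raw header strings; `extract_access_token`, `cookie_value` and `ExtractError` are hypothetical names, and the real middleware relies on the typed extractors from axum-extra instead of manual parsing.

```rust
/// Why extraction failed. Mirrors the two `AuthError` variants the middleware returns.
#[derive(Debug, PartialEq)]
pub enum ExtractError {
    /// An Authorization header was present but was not a usable `Bearer` value.
    InvalidToken,
    /// Neither the header nor the cookie carried a token.
    MissingCredentials,
}

/// `authorization` and `cookie` are the raw header values, if the request had them.
pub fn extract_access_token(
    authorization: Option<&str>,
    cookie: Option<&str>,
) -> Result<String, ExtractError> {
    match authorization {
        // A header that exists but is malformed is rejected outright.
        Some(value) => value
            .strip_prefix("Bearer ")
            .map(str::trim)
            .filter(|token| !token.is_empty())
            .map(str::to_owned)
            .ok_or(ExtractError::InvalidToken),
        // Only a missing header falls back to the cookie.
        None => cookie
            .and_then(|raw| cookie_value(raw, "access_token"))
            .map(str::to_owned)
            .ok_or(ExtractError::MissingCredentials),
    }
}

/// Finds `name` in a raw `Cookie` header such as `a=1; b=2`.
fn cookie_value<'a>(raw: &'a str, name: &str) -> Option<&'a str> {
    raw.split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| *key == name)
        .map(|(_, value)| value)
}
```

Supporting both sources lets API clients send a bearer token while browsers rely on the HTTP-only cookie.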


refresh.rs

//! This module contains middleware code to refresh a user's expired access token.
//! It uses `tower::Service` and `tower::Layer` to create Request middleware.
use std::{
    convert::Infallible,
    sync::Arc,
    task::{Context, Poll},
};

use axum::{
    RequestPartsExt,
    body::Body,
    http::{
        HeaderValue, Request, Response,
        header::{AUTHORIZATION, SET_COOKIE},
    },
    response::IntoResponse,
};
use axum_extra::{
    TypedHeader,
    extract::cookie,
    headers::{Authorization, Cookie, authorization::Bearer},
    typed_header::TypedHeaderRejectionReason,
};
use futures_util::future::BoxFuture;
use redis::AsyncCommands as _;
use tower::{Layer, Service};

use crate::{context::AppContext, middlewares::AuthError, models::token::TokenDetails};

#[derive(Clone)]
pub struct RefreshLayer {
    ctx: Arc<AppContext>,
}

impl RefreshLayer {
    pub fn new(ctx: &Arc<AppContext>) -> Self {
        Self { ctx: ctx.clone() }
    }
}

impl<S> Layer<S> for RefreshLayer {
    type Service = RefreshService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Self::Service {
            inner,
            ctx: self.ctx.clone(),
        }
    }
}

#[derive(Clone)]
pub struct RefreshService<S> {
    inner: S,
    ctx: Arc<AppContext>,
}

impl<S, B> Service<Request<B>> for RefreshService<S>
where
    S: Service<Request<B>, Response = Response<Body>, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let ctx = self.ctx.clone();
        let clone = self.inner.clone();

        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            let (mut parts, body) = req.into_parts();

            let refresh_token = match parts.extract::<TypedHeader<Cookie>>().await {
                Ok(cookies) => cookies.get("refresh_token").map(ToString::to_string),
                Err(_) => return Ok(AuthError::MissingCredentials.into_response()),
            };

            let Some(refresh_token) = refresh_token else {
                return Ok(AuthError::MissingCredentials.into_response());
            };

            let access_token = match parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
                Ok(header) => Some(header.token().to_string()),
                Err(err) => {
                    // extract from the cookie object
                    if matches!(err.reason(), TypedHeaderRejectionReason::Missing) {
                        parts.extract::<TypedHeader<Cookie>>().await.ok().and_then(
                            |TypedHeader(cookies)| {
                                cookies.get("access_token").map(ToString::to_string)
                            },
                        )
                    } else {
                        // We wrap the return value in Ok even though it represents a failure
                        // because middlewares in Axum cannot return errors, i.e. `Error =
                        // Infallible`
                        return Ok::<Response<Body>, Self::Error>(
                            AuthError::InvalidToken.into_response(),
                        );
                    }
                }
            };

            // Verify the refresh token and get the user's pid
            let refresh_token_details = match ctx.auth.refresh.verify_token(&refresh_token) {
                Ok(details) => details,
                Err(err) => return Ok(err.into_response()),
            };

            // Check if the Refresh Token is in cache
            let mut redis_conn = ctx.redis.clone();
            let redis_key = format!("refresh_token:{}", refresh_token_details.token_id);

            let stored_token: Result<String, crate::Error> = redis_conn
                .get(&redis_key)
                .await
                .map_err(crate::Error::Redis);

            let stored_details = match stored_token {
                Ok(token) => match serde_json::from_str::<TokenDetails>(&token)
                    .map_err(crate::Error::SerdeJson)
                {
                    Ok(details) => details,
                    Err(err) => return Ok(err.response()),
                },
                Err(err) => return Ok(err.response()),
            };

            // If the access token is missing we issue a new one.
            let new_access_token = match access_token {
                Some(token) => token,
                None => match ctx.auth.access.generate_token(stored_details.user_pid) {
                    // `generate_token` always sets `token`; treat `None` as a
                    // token creation failure instead of panicking.
                    Ok(TokenDetails { token: Some(token), .. }) => token,
                    Ok(_) => return Ok(AuthError::TokenCreation.into_response()),
                    Err(e) => return Ok(e.into_response()),
                },
            };

            let access_cookie = cookie::Cookie::build(("access_token", &new_access_token))
                .path("/")
                .max_age(time::Duration::seconds(ctx.auth.access.exp))
                .same_site(cookie::SameSite::Lax)
                .http_only(true)
                .to_string();

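            let mut req = Request::from_parts(parts, body);
            // `insert` replaces any existing Authorization header, so downstream
            // middleware sees exactly one bearer token.
            req.headers_mut().insert(
                AUTHORIZATION,
                HeaderValue::from_str(format!("Bearer {}", &new_access_token).as_str()).unwrap(),
            );

            // Set-Cookie is a response header, so attach the cookie to the response
            // rather than to the incoming request.
            let set_cookie = HeaderValue::from_str(access_cookie.as_str()).unwrap();

            inner.call(req).await.map(|mut res| {
                res.headers_mut().append(SET_COOKIE, set_cookie);
                res
            })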
        })
    }
}
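
Two small decisions inside the refresh middleware are easier to see in isolation: whether to reuse or issue an access token, and what the cookie value ends up looking like. The sketch below is std-only and hypothetical: `resolve_access_token` takes a closure standing in for generate_token, and `access_cookie` formats the header value by hand, whereas the real code uses the cookie builder from axum-extra (whose attribute order may differ).

```rust
/// Result of the "reuse or issue" decision made by the refresh middleware.
#[derive(Debug, PartialEq)]
pub struct Resolved {
    pub access_token: String,
    pub newly_issued: bool,
}

/// Reuses the access token from the request when there is one, otherwise
/// calls `issue` (a stand-in for `generate_token`) to mint a new one.
pub fn resolve_access_token<F>(existing: Option<String>, issue: F) -> Result<Resolved, String>
where
    F: FnOnce() -> Result<String, String>,
{
    match existing {
        Some(access_token) => Ok(Resolved { access_token, newly_issued: false }),
        None => issue().map(|access_token| Resolved { access_token, newly_issued: true }),
    }
}

/// Hand-rolled `Set-Cookie` value with the same attributes the middleware sets.
/// `Path` is a URL path such as `/`, not a full URL.
pub fn access_cookie(token: &str, max_age_secs: i64) -> String {
    format!("access_token={token}; Path=/; Max-Age={max_age_secs}; SameSite=Lax; HttpOnly")
}
```

Max-Age is given in seconds, which is why the middleware passes the same exp value that generate_token uses for the token lifetime.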




Understanding The Middleware

This is the usual way to implement middleware in Axum. You may notice that the implementation looks similar to the Future trait from the standard library; that is because a Service's job is to produce a Future.

  • Layer is a factory for middleware: it takes an inner service and produces a wrapped service. This is how we compose middleware at the router.
  • Service<Request<B>> for T is the middleware itself. Axum treats HTTP handling as calling a Service that returns a Response, so implementing Service is the canonical, low-level way to write middleware.
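
The relationship between the two traits can be shown in miniature without any async machinery. The traits below are simplified, synchronous stand-ins written for this sketch; they are not the real tower traits (which also have poll_ready, an Error type and a Future), and `Request`, `Hello` and `RequireToken` are hypothetical.

```rust
/// Simplified, synchronous stand-in for `tower::Service`.
pub trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

/// Simplified stand-in for `tower::Layer`: wraps one service in another.
pub trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

/// Hypothetical request type: just an optional token.
pub struct Request {
    pub token: Option<String>,
}

/// The innermost service, playing the role of a handler.
pub struct Hello;

impl Service<Request> for Hello {
    type Response = (u16, String);

    fn call(&mut self, _req: Request) -> Self::Response {
        (200, "Hello from the handler".to_string())
    }
}

/// The factory: knows how to wrap any inner service.
pub struct RequireToken;

/// The middleware itself: holds the inner service and decides whether to call it.
pub struct RequireTokenService<S> {
    inner: S,
}

impl<S> Layer<S> for RequireToken {
    type Service = RequireTokenService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RequireTokenService { inner }
    }
}

impl<S> Service<Request> for RequireTokenService<S>
where
    S: Service<Request, Response = (u16, String)>,
{
    type Response = (u16, String);

    fn call(&mut self, req: Request) -> Self::Response {
        if req.token.is_none() {
            // Short-circuit: the inner service never runs.
            return (400, "Credentials missing from request".to_string());
        }
        self.inner.call(req)
    }
}
```

Calling `RequireToken.layer(Hello)` yields a service that answers 400 for a request without a token and otherwise delegates to `Hello`, which is the same shape as AuthLayer producing AuthService.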



Authentication Errors

Create an error.rs file inside the middlewares module and add the following contents.

use axum::{
    Json,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use serde_json::json;

#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    #[error("Invalid token")]
    InvalidToken,
    #[error("Credentials missing from request")]
    MissingCredentials,
    #[error("Token creation failed")]
    TokenCreation,
    #[error("Wrong credentials")]
    WrongCredentials,
}

pub type AuthResult<T, E = AuthError> = Result<T, E>;

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        self.response()
    }
}

impl AuthError {
    pub fn response(&self) -> Response {
        let (status, message) = match self {
            Self::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
            Self::MissingCredentials => {
                (StatusCode::BAD_REQUEST, "Credentials missing from request")
            }
            Self::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
            Self::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
        };

        let body = Json(json!({
            "error": message
        }));

        (status, body).into_response()
    }
}

Then add the AuthError to the main Error enum in error/mod.rs

use std::{
    env::VarError,
    fmt::{self, Display},
};

use argon2::password_hash::Error as PasswordHashError;
use axum::{
    Json,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use serde_json::json;
use tracing_subscriber::filter::FromEnvError;

use crate::{middlewares::AuthError, models::ModelError};

#[derive(Debug)]
pub struct Report(pub color_eyre::Report);

impl IntoResponse for Report {
    fn into_response(self) -> Response {
        let err = self.0;
        let err_string = format!("{:?}", &err);

        tracing::error!("An error occurred {}", err_string);

        if let Some(error) = err.downcast_ref::<Error>() {
            return error.response();
        }

        // backup response
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({"error": "Something went wrong on our end."})),
        )
            .into_response()
    }
}

impl<E> From<E> for Report
where
    E: Into<color_eyre::Report>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}

impl Display for Report {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

pub type Result<T, E = Report> = std::result::Result<T, E>;

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error(transparent)]
    Axum(#[from] axum::Error),
    #[error(transparent)]
    Config(#[from] config::ConfigError),
    #[error(transparent)]
    DirectiveParseError(#[from] tracing_subscriber::filter::ParseError),
    #[error(transparent)]
    EnvFilter(#[from] VarError),
    #[error(transparent)]
    FromEnv(#[from] FromEnvError),
    #[error(transparent)]
    IOError(#[from] std::io::Error),
    #[error(transparent)]
    TryInit(#[from] tracing_subscriber::util::TryInitError),
    #[error(transparent)]
    Migrate(#[from] sqlx::migrate::MigrateError),
    #[error(transparent)]
    Redis(#[from] redis::RedisError),
    #[error(transparent)]
    JsonWebToken(#[from] jsonwebtoken::errors::Error),
    #[error("{0}")]
    Argon2(argon2::Error),
    #[error("{0}")]
    PasswordHash(argon2::password_hash::Error),
    #[error("Invalid email or password")]
    InvalidCredentials,
    #[error("Error occurred when signing or verifying token")]
    TokenError,
    #[error(transparent)]
    SerdeJson(#[from] serde_json::error::Error),
    #[error(transparent)]
    Auth(#[from] AuthError),
    #[error(transparent)]
    Model(#[from] ModelError),
}

impl From<argon2::Error> for Error {
    fn from(err: argon2::Error) -> Self {
        Self::Argon2(err)
    }
}

impl From<PasswordHashError> for Error {
    fn from(err: PasswordHashError) -> Self {
        match err {
            PasswordHashError::Password => Self::InvalidCredentials,
            _ => Self::PasswordHash(err),
        }
    }
}

impl Error {
    pub fn response(&self) -> Response {
        let (status, message) = match self {
            Self::InvalidCredentials => (StatusCode::UNAUTHORIZED, "Invalid email or password"),
            Self::Auth(err) => return err.response(),
            Self::Model(err) => return err.response(),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
        };

        let body = Json(json!({
            "error": message
        }));

        (status, body).into_response()
    }
}
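
The Report type above relies on one idea: try to downcast the opaque error to a type we know, and fall back to a generic 500 otherwise. Here is a std-only sketch of that pattern using `dyn std::error::Error` in place of color_eyre::Report. `AppError` and `status_for` are hypothetical and return a plain status code and message instead of an Axum response.

```rust
use std::{error::Error as StdError, fmt};

/// Hypothetical application error with one client-facing variant.
#[derive(Debug)]
pub enum AppError {
    InvalidCredentials,
    Database,
}

impl fmt::Display for AppError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidCredentials => write!(f, "invalid credentials"),
            Self::Database => write!(f, "database error"),
        }
    }
}

impl StdError for AppError {}

/// Maps any error to a status code and a message that is safe to show to clients.
pub fn status_for(err: &(dyn StdError + 'static)) -> (u16, &'static str) {
    match err.downcast_ref::<AppError>() {
        // A known error with a specific client-facing response.
        Some(AppError::InvalidCredentials) => (401, "Invalid email or password"),
        // Known-but-internal errors and unknown errors share the backup response.
        Some(AppError::Database) | None => (500, "Something went wrong on our end."),
    }
}
```

Keeping the detailed error for the logs and returning only a generic message for unknown failures avoids leaking internals to clients.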



That concludes this part, where we have seen how to generate and verify JWTs and how to create middlewares in Axum using the tower crate.

In the following section I will show you how to use the above code to create, sign in and sign out users.



This Series

Part 1: Project Setup & Configuration
Part 2: Implementing Logging
Part 3: Database Setup with SQLx and PostgreSQL.
Part 4: JWTs & Middlewares (You are here)
Next Part 5: Creating & Authenticating Users (Coming Soon)


