How to create Axum middleware using tower::Service and tower::Layer?

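I'm trying to implement JWT authentication in Axum with a custom middleware. However, when I try to return an error when validation fails, the code doesn't compile. Below is the JWT middleware I wrote. Could you tell me how to change it so that it does what I want?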

//custom_middleware.rs
use axum::http::StatusCode;
use axum::{extract::Request, response::Response};
use futures_util::future::BoxFuture;
use jsonwebtoken::{
    decode, errors::Error as JwtError, Algorithm, DecodingKey, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use std::task::{Context, Poll};
use tower::{Layer, Service};

#[derive(Serialize, Deserialize)]
pub struct Claims {
    pub id: usize,
    pub exp: usize,
}

#[derive(Clone)]
pub struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
pub struct MyMiddleware<S> {
    inner: S,
}

impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        match has_permission(&req) {
            Ok(_) => {
                let future = self.inner.call(req);
                Box::pin(async move {
                    let response: Response = future.await?;
                    Ok(response)
                })
            }
            Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
        }
    }
}

fn has_permission(req: &Request) -> Result<TokenData<Claims>, (StatusCode, &'static str)> {
    let secret = "baby195lxl";
    let authorization_header_option = req.headers().get("authorization");
    if authorization_header_option.is_none() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is none"));
    }
    let authentication_token: String = authorization_header_option
        .unwrap()
        .to_str()
        .unwrap_or("")
        .to_string();

    if authentication_token.is_empty() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is empty"));
    }
    let token_result: Result<TokenData<Claims>, JwtError> = decode::<Claims>(
        &authentication_token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    );
    match token_result {
        Ok(_token) => Ok(_token),
        Err(_e) => Err((StatusCode::UNAUTHORIZED, "Token Error")),
    }
}
//main.rs
use axum::{
    body::Bytes,
    extract::{Json, Request, State},
    routing::{get, post},
    Router,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::Deserialize;

use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod custom_middleware;
use custom_middleware::Claims;
use custom_middleware::MyLayer;

mod state;
use state::AppState;

#[derive(Deserialize, Debug, PartialEq)]
struct User {
    account: usize,
    password: String,
}

async fn register(State(state): State<AppState>, Json(user): Json<User>) -> String {
    let store_user = User {
        account: 195,
        password: "world".to_string(),
    };
    if user == store_user {
        let expiration = SystemTime::now() + Duration::from_secs(30 * 60);
        let exp_timestamp = expiration.duration_since(UNIX_EPOCH).unwrap().as_secs();
        let claims = Claims {
            id: user.account,
            exp: exp_timestamp as usize,
        };
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(state.secret.as_bytes()),
        )
        .unwrap();
        token
    } else {
        "hello, world!".to_string()
    }
}

async fn login(State(state): State<AppState>, req: Request) -> Json<Claims> {
    let token = req
        .headers()
        .get("Authorization")
        .unwrap()
        .to_str()
        .unwrap();
    let payload = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .unwrap();
    Json(payload.claims)
}

async fn protected(_req: Request) -> String {
    "World!".to_string()
}
#[tokio::main]
async fn main() {
    let state = AppState {
        secret: "baby195lxl".to_string(),
    };
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new("debug"))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new()
        .route("/protected", get(protected))
        .layer(MyLayer)
        .route("/register", post(register))
        .route("/login", post(login))
        .with_state(state)
        .layer(TraceLayer::new_for_http().on_body_chunk(
            |chunk: &Bytes, latency: Duration, _span: &Span| {
                tracing::debug!("streaming {} bytes in {:?}", chunk.len(), latency);
            },
        ));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:5000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

Compiling this fails with the following error:

Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
   |                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `Pin<Box<dyn Future<Output = ...> + Send>>`, found `Result<_, (StatusCode, &str)>`

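I haven't found a way to make this compile without errors. My Cargo.toml is below, and I'm using rustc 1.76.0 (07dca489a 2024-02-04). I'd be grateful if anyone could help clear up my confusion; any reply is welcome. Thanks.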

[dependencies]
axum = "^0.7"
tokio = { version = "^1.36", features = ["full"] }
tower-http = { version = "^0.5", features = ["trace"] }
tracing = "^0.1"
tracing-subscriber = { version = "^0.3", features = ["env-filter"] }
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9.2.0"
tower = "0.4.13"
futures-util = "0.3.30"
Tags: rust, jwt, rust-axum

1 Answer

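The immediate fix for this error is to wrap your Err(...) in Box::pin(async { ... }), because you have declared your Self::Future to be a BoxFuture, so every branch of call must return a boxed future.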
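
To illustrate why the wrapping is needed, here is a std-only sketch in which both match arms must have the declared boxed-future type. The local BoxFuture alias has the same definition as futures_util::future::BoxFuture; the function check_token, its String error type and the tiny block_on executor are hypothetical stand-ins for illustration, not part of the question's code.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Same definition as futures_util::future::BoxFuture.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

// Hypothetical stand-in for MyMiddleware::call: because the return type is a
// BoxFuture, every match arm has to produce a boxed future, including the error arm.
fn check_token(header: Option<&str>) -> BoxFuture<'static, Result<String, String>> {
    match header {
        Some(token) if !token.is_empty() => {
            let token = token.to_string();
            // Success path: in the real middleware this is the inner service's future.
            Box::pin(async move { Ok(format!("forwarded with {token}")) })
        }
        _ => {
            // A bare Err(...) is a Result, not a future, so it is moved into
            // an async block and boxed, just like the success path.
            let rejection: Result<String, String> = Err("authorization header is none".to_string());
            Box::pin(async move { rejection })
        }
    }
}

// Minimal executor for this demo: the async blocks above never wait on I/O,
// so they complete on their first poll and a no-op waker is enough.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    println!("{:?}", block_on(check_token(Some("abc"))));
    println!("{:?}", block_on(check_token(None)));
}
```

With the error arm boxed, the type mismatch goes away; the remaining problem is which error type that boxed future may carry.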

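However, that alone won't solve it, because you'll then find that your error type is S::Error rather than (StatusCode, &'static str), which means you have to use the nested service's error type. You probably don't want that. You could create a wrapper error enum that can represent either your own error or the nested service's error, but that isn't the solution either, because...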
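
For reference, the wrapper-enum idea could look like the std-only sketch below. AuthError, its variants, the u16 status code and the inner_call/middleware_call functions are hypothetical names chosen for illustration; as the next paragraph explains, axum would not accept a layered service with such an error type anyway.

```rust
use std::fmt;

// Hypothetical error type for the middleware: either our own rejection
// or whatever error the wrapped (inner) service produced.
#[derive(Debug, PartialEq)]
enum AuthError<E> {
    Rejected { status: u16, message: &'static str },
    Inner(E),
}

impl<E: fmt::Display> fmt::Display for AuthError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AuthError::Rejected { status, message } => write!(f, "{status}: {message}"),
            AuthError::Inner(e) => write!(f, "inner service failed: {e}"),
        }
    }
}

// Stand-in for calling the inner service; its error type is String here.
fn inner_call(path: &str) -> Result<String, String> {
    if path == "/protected" {
        Ok("World!".to_string())
    } else {
        Err(format!("no route for {path}"))
    }
}

// With Error = AuthError<S::Error>, both failure sources fit into one type.
fn middleware_call(token: Option<&str>, path: &str) -> Result<String, AuthError<String>> {
    if token.map_or(true, str::is_empty) {
        return Err(AuthError::Rejected { status: 401, message: "Token Error" });
    }
    inner_call(path).map_err(AuthError::Inner)
}

fn main() {
    println!("{:?}", middleware_call(Some("t"), "/protected"));
    println!("{}", middleware_call(None, "/protected").unwrap_err());
}
```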

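Axum expects its Services not to fail. You can see that Router::layer is constrained so that the layered service's Error type must be Into<Infallible>, which means it cannot return an error. This is because, while tower is designed to describe error types well, axum expects errors to be converted into Responses immediately.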
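
In other words, a rejection has to leave the middleware as an Ok(Response), not as an Err. The std-only sketch below models this with a simplified, synchronous Service trait and toy Request/Response structs; these are hypothetical stand-ins, not tower's or axum's real types (tower's trait also has poll_ready and returns a future). The requires_infallible function mimics the Error: Into<Infallible> bound that Router::layer places on layered services.

```rust
use std::convert::Infallible;

// Toy request/response types (hypothetical, for illustration only).
struct Request {
    authorization: Option<String>,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Simplified, synchronous stand-in for tower::Service.
trait Service {
    type Error;
    fn call(&mut self, req: Request) -> Result<Response, Self::Error>;
}

// The wrapped handler: it never fails.
struct Protected;

impl Service for Protected {
    type Error = Infallible;
    fn call(&mut self, _req: Request) -> Result<Response, Infallible> {
        Ok(Response { status: 200, body: "World!".to_string() })
    }
}

// Auth middleware that keeps the inner error type and turns a failed
// check into an ordinary 401 response instead of an Err.
struct Auth<S> {
    inner: S,
}

impl<S: Service> Service for Auth<S> {
    type Error = S::Error;
    fn call(&mut self, req: Request) -> Result<Response, S::Error> {
        let authorized = req.authorization.as_deref().map_or(false, |t| !t.is_empty());
        if authorized {
            self.inner.call(req)
        } else {
            Ok(Response { status: 401, body: "Token Error".to_string() })
        }
    }
}

// Mimics the `Error: Into<Infallible>` bound on Router::layer.
fn requires_infallible<S>(svc: S) -> S
where
    S: Service,
    S::Error: Into<Infallible>,
{
    svc
}

fn main() {
    let mut app = requires_infallible(Auth { inner: Protected });
    println!("{:?}", app.call(Request { authorization: None }));
    println!("{:?}", app.call(Request { authorization: Some("abc".to_string()) }));
}
```

Because Auth<Protected> keeps Error = Infallible, it satisfies the bound; an Auth whose Error were (u16, &'static str) would not.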

So you could refactor your Layer and Service implementations to follow this rule, but a better solution is to use axum's middleware::from_fn helper, which makes this much nicer:

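use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};

// Runs before the handler. Returning Err is fine here: axum converts the
// (StatusCode, &'static str) error into a response via IntoResponse.
async fn auth_middleware(req: Request, next: Next) -> Result<Response, (StatusCode, &'static str)> {
    match has_permission(&req) {
        // Token is valid: pass the request on to the next layer / the handler.
        Ok(_) => Ok(next.run(req).await),
        // Forward the specific rejection (400 or 401) produced by has_permission.
        Err(rejection) => Err(rejection),
    }
}

let app = Router::new()
    ...
    .layer(axum::middleware::from_fn(auth_middleware))
    ...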
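
Returning Err works here because from_fn only requires the function's output to implement IntoResponse, and axum implements IntoResponse for Result<T, E> whenever both T and E implement it, which (StatusCode, &'static str) does. The std-only sketch below models that conversion. IntoResponse, Next and Response are simplified hypothetical stand-ins (the real API is async and next.run is awaited), and the request is reduced to the Authorization header value as a String.

```rust
// Toy response type (hypothetical stand-in for axum::response::Response).
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Simplified stand-in for axum's IntoResponse trait.
trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for Response {
    fn into_response(self) -> Response {
        self
    }
}

// Modeled on axum's impl for (StatusCode, &'static str): status plus text body.
impl IntoResponse for (u16, &'static str) {
    fn into_response(self) -> Response {
        Response { status: self.0, body: self.1.to_string() }
    }
}

// Modeled on axum's impl for Result<T, E>: whichever side is present becomes the response.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(ok) => ok.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

// Stand-in for Next: the rest of the stack, which can be run once.
struct Next(Box<dyn FnOnce(String) -> Response>);

impl Next {
    fn run(self, req: String) -> Response {
        (self.0)(req)
    }
}

// Same shape as auth_middleware above, in synchronous form.
fn auth_middleware(req: String, next: Next) -> Result<Response, (u16, &'static str)> {
    if req.is_empty() {
        return Err((401, "Token Error"));
    }
    Ok(next.run(req))
}

// What the from_fn wrapper does with the return value: convert it into a
// response, so the layered service itself never returns an error.
fn serve(token: &str) -> Response {
    let next = Next(Box::new(|_req: String| Response { status: 200, body: "World!".to_string() }));
    auth_middleware(token.to_string(), next).into_response()
}

fn main() {
    println!("{:?}", serve(""));
    println!("{:?}", serve("abc"));
}
```
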
© www.soinside.com 2019 - 2024. All rights reserved.