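I'm trying to implement JWT authentication with a custom middleware in Axum. However, when I try to return an error when validation fails, the code won't compile. Below is the JWT-validation middleware I wrote. Could you tell me how to change it so that it does what I want?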
//custom_middleware.rs
use axum::http::StatusCode;
use axum::{extract::Request, response::Response};
use futures_util::future::BoxFuture;
use jsonwebtoken::{
    decode, errors::Error as JwtError, Algorithm, DecodingKey, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use std::task::{Context, Poll};
use tower::{Layer, Service};

#[derive(Serialize, Deserialize)]
pub struct Claims {
    pub id: usize,
    pub exp: usize,
}

#[derive(Clone)]
pub struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
pub struct MyMiddleware<S> {
    inner: S,
}

impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        match has_permission(&req) {
            Ok(_) => {
                let future = self.inner.call(req);
                Box::pin(async move {
                    let response: Response = future.await?;
                    Ok(response)
                })
            }
            Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
        }
    }
}

fn has_permission(req: &Request) -> Result<TokenData<Claims>, (StatusCode, &'static str)> {
    let secret = "baby195lxl";
    let authorization_header_option = req.headers().get("authorization");
    if authorization_header_option.is_none() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is none"));
    }
    let authentication_token: String = authorization_header_option
        .unwrap()
        .to_str()
        .unwrap_or("")
        .to_string();
    if authentication_token.is_empty() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is empty"));
    }
    let token_result: Result<TokenData<Claims>, JwtError> = decode::<Claims>(
        &authentication_token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    );
    match token_result {
        Ok(_token) => Ok(_token),
        Err(_e) => Err((StatusCode::UNAUTHORIZED, "Token Error")),
    }
}
//main.rs
use axum::{
    body::Bytes,
    extract::{Json, Request, State},
    routing::{get, post},
    Router,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::Deserialize;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod custom_middleware;
use custom_middleware::Claims;
use custom_middleware::MyLayer;
mod state;
use state::AppState;

#[derive(Deserialize, Debug, PartialEq)]
struct User {
    account: usize,
    password: String,
}

async fn register(State(state): State<AppState>, Json(user): Json<User>) -> String {
    let store_user = User {
        account: 195,
        password: "world".to_string(),
    };
    if user == store_user {
        let expiration = SystemTime::now() + Duration::from_secs(30 * 60);
        let exp_timestamp = expiration.duration_since(UNIX_EPOCH).unwrap().as_secs();
        let claims = Claims {
            id: user.account,
            exp: exp_timestamp as usize,
        };
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(state.secret.as_bytes()),
        )
        .unwrap();
        token
    } else {
        "hello, world!".to_string()
    }
}

async fn login(State(state): State<AppState>, req: Request) -> Json<Claims> {
    let token = req
        .headers()
        .get("Authorization")
        .unwrap()
        .to_str()
        .unwrap();
    let payload = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .unwrap();
    Json(payload.claims)
}

async fn protected(_req: Request) -> String {
    "World!".to_string()
}

#[tokio::main]
async fn main() {
    let state = AppState {
        secret: "baby195lxl".to_string(),
    };
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new("debug"))
        .with(tracing_subscriber::fmt::layer())
        .init();
    let app = Router::new()
        .route("/protected", get(protected))
        .layer(MyLayer)
        .route("/register", post(register))
        .route("/login", post(login))
        .with_state(state)
        .layer(TraceLayer::new_for_http().on_body_chunk(
            |chunk: &Bytes, latency: Duration, _span: &Span| {
                tracing::debug!("streaming {} bytes in {:?}", chunk.len(), latency);
            },
        ));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:5000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `Pin<Box<dyn Future<Output = ...> + Send>>`, found `Result<_, (StatusCode, &str)>`
I haven't found a solution that compiles without errors. My Cargo.toml is below, and I'm using rustc 1.76.0 (07dca489a 2024-02-04). I'd be grateful if anyone could help clear this up; any reply is welcome. Thanks.
[dependencies]
axum = "^0.7"
tokio = { version = "^1.36", features = ["full"] }
tower-http = { version = "^0.5", features = ["trace"] }
tracing = "^0.1"
tracing-subscriber = { version = "^0.3", features = ["env-filter"] }
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9.2.0"
tower = "0.4.13"
futures-util = "0.3.30"
The immediate fix for your current problem is to wrap your `Err(...)` in `Box::pin(async { ... })`, since you've declared your `Self::Future` to be a `BoxFuture`.
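To see why the compiler complains, here is a minimal std-only sketch (no axum or tower) of a function that returns a boxed future. Every branch has to produce the same `Pin<Box<dyn Future ...>>` type, so the early-error branch is wrapped in `Box::pin(async { ... })`. The names `call` and `inner_call`, the `String` error type, and the tiny busy-polling `block_on` executor are illustrative assumptions; in the real middleware the error type would have to be `S::Error`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Same shape as futures_util::future::BoxFuture<'static, T>.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

// Stand-in for the inner service: its future always succeeds.
fn inner_call(path: &str) -> BoxFuture<Result<String, String>> {
    let body = format!("handled {path}");
    Box::pin(async move { Ok(body) })
}

// Both branches must have the same type, BoxFuture<...>. A bare `Err(...)`
// in one branch is a plain Result, which is what triggers E0308.
fn call(authorized: bool, path: &str) -> BoxFuture<Result<String, String>> {
    if authorized {
        // Mirrors the question's Ok arm: await the inner future in a boxed block.
        let fut = inner_call(path);
        Box::pin(async move { fut.await })
    } else {
        // Wrap the early error in a boxed future that is ready immediately.
        Box::pin(async { Err("bad request".to_string()) })
    }
}

// Waker that does nothing; enough for futures that are ready immediately.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Busy-polls a future to completion (sketch only; use tokio in real code).
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    println!("{:?}", block_on(call(true, "/protected"))); // Ok("handled /protected")
    println!("{:?}", block_on(call(false, "/protected"))); // Err("bad request")
}
```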
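However, that alone won't get you there, because you'll then discover that your error type is `S::Error` rather than `(StatusCode, &'static str)`, which means you are bound to the nested service's error type. You probably don't want that. You could create a wrapper error enum that can hold either your own error type or the nested service's, but that isn't the solution either, because...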
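A rough sketch of that wrapper-enum idea, again with plain std types rather than the real tower traits. `MiddlewareError` and `call` are hypothetical names, the `u16` status stands in for `StatusCode`, and the closure stands in for the inner service; in a real `Service` impl you would declare `type Error = MiddlewareError<S::Error>`.

```rust
use std::fmt;

// Hypothetical wrapper error: either our own auth rejection
// (status code + message) or whatever the inner service failed with.
#[derive(Debug, PartialEq)]
enum MiddlewareError<E> {
    Auth(u16, &'static str),
    Inner(E),
}

impl<E: fmt::Display> fmt::Display for MiddlewareError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MiddlewareError::Auth(code, msg) => write!(f, "{code} {msg}"),
            MiddlewareError::Inner(e) => write!(f, "inner service error: {e}"),
        }
    }
}

// Rejects early with our own error, otherwise forwards to the inner
// "service" and wraps its error type.
fn call<E>(
    token_ok: bool,
    inner: impl FnOnce() -> Result<String, E>,
) -> Result<String, MiddlewareError<E>> {
    if !token_ok {
        return Err(MiddlewareError::Auth(401, "Token Error"));
    }
    inner().map_err(MiddlewareError::Inner)
}

fn main() {
    let denied = call(false, || Ok::<_, String>("World!".to_string()));
    println!("{denied:?}"); // Err(Auth(401, "Token Error"))

    let failed = call(true, || Err::<String, _>("db down".to_string()));
    if let Err(e) = failed {
        println!("{e}"); // inner service error: db down
    }
}
```

This compiles on its own, but as explained next, axum's `Router::layer` will not accept a layered service whose error type is anything other than (convertible to) `Infallible`.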
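Axum expects its `Service`s to be infallible. You can see that `Router::layer` is bounded so that the layered service's `Error` type must be `Into<Infallible>`, which means it cannot return an error at all. This is because, although tower is designed to describe error types well, axum expects errors to be converted into `Response`s immediately.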
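A std-only sketch of what "errors become responses" means for the hand-written `Service` approach: the error type is `Infallible`, and a failed token check is returned as `Ok(response)` with a 4xx status rather than as `Err`. `Response` is a simplified stand-in for `http::Response`, `call` stands in for `Service::call`, and this `has_permission` just compares against the hypothetical value `"valid-token"` instead of decoding a JWT; its status codes mirror the question's version.

```rust
use std::convert::Infallible;

// Simplified stand-in for http::Response.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Hypothetical check mirroring the question's has_permission; a real
// version would decode the JWT instead of comparing strings.
fn has_permission(auth_header: Option<&str>) -> Result<(), (u16, &'static str)> {
    match auth_header {
        None => Err((400, "authorization header is none")),
        Some("") => Err((400, "authorization header is empty")),
        Some("valid-token") => Ok(()),
        Some(_) => Err((401, "Token Error")),
    }
}

// The error type is Infallible: a rejected request is not an Err but an
// ordinary Ok(Response) that carries a 4xx status.
fn call(
    auth_header: Option<&str>,
    inner: impl FnOnce() -> Result<Response, Infallible>,
) -> Result<Response, Infallible> {
    match has_permission(auth_header) {
        Ok(()) => inner(),
        Err((status, msg)) => Ok(Response { status, body: msg.to_string() }),
    }
}

// Stand-in for the protected handler.
fn protected() -> Result<Response, Infallible> {
    Ok(Response { status: 200, body: "World!".to_string() })
}

fn main() {
    for header in [Some("valid-token"), Some("forged"), None] {
        let resp = call(header, protected).unwrap();
        println!("{:?} -> {} {}", header, resp.status, resp.body);
    }
}
```

In the real tower implementation, the `Err` arm would become something like `Box::pin(async move { Ok((status, msg).into_response()) })` with `axum::response::IntoResponse` in scope; that type-checks because the question's bound `S: Service<Request, Response = Response>` fixes the response type to axum's `Response`.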
So you could refactor your `Layer` and `Service` implementations to follow this, but a better solution is to use axum's `middleware::from_fn` helper, which makes this much easier:
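// Goes in custom_middleware.rs, which already imports Request, Response
// and StatusCode and defines the (private) has_permission function.
use axum::middleware::Next;

pub async fn auth_middleware(req: Request, next: Next) -> Result<Response, (StatusCode, &'static str)> {
    match has_permission(&req) {
        Ok(_) => {
            // Token is valid: pass the request on to the inner handler.
            let response = next.run(req).await;
            Ok(response)
        }
        // Returning Err is fine here: axum turns the tuple into a response.
        Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
    }
}

// main.rs (with `use custom_middleware::auth_middleware;`)
let app = Router::new()
    ...
    .layer(axum::middleware::from_fn(auth_middleware))
    ...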
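The reason the `from_fn` version may return `Err(...)` is that axum implements `IntoResponse` for `Result<T, E>` when both `T` and `E` implement it, and for `(StatusCode, body)` tuples, so both arms end up as an HTTP response. The sketch below models that with a local stand-in trait (not axum's actual trait), a `u16` status instead of `StatusCode`, and a synchronous closure in place of `Next`.

```rust
// Simplified stand-in for axum's Response.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Local stand-in for axum's IntoResponse trait.
trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for Response {
    fn into_response(self) -> Response {
        self
    }
}

// Mirrors the (StatusCode, body) tuple impl: status plus text body.
impl IntoResponse for (u16, &'static str) {
    fn into_response(self) -> Response {
        Response { status: self.0, body: self.1.to_string() }
    }
}

// Mirrors the Result impl: both Ok and Err become a response.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(v) => v.into_response(),
            Err(e) => e.into_response(),
        }
    }
}

// Same shape as the from_fn middleware; `next` is a plain closure here.
fn auth_middleware(
    token_valid: bool,
    next: impl FnOnce() -> Response,
) -> Result<Response, (u16, &'static str)> {
    if token_valid {
        Ok(next())
    } else {
        Err((400, "bad request"))
    }
}

fn main() {
    let allowed = auth_middleware(true, || Response { status: 200, body: "World!".to_string() });
    let denied = auth_middleware(false, || Response { status: 200, body: "World!".to_string() });
    println!("{:?}", allowed.into_response()); // Response { status: 200, body: "World!" }
    println!("{:?}", denied.into_response()); // Response { status: 400, body: "bad request" }
}
```

Since `has_permission` already returns a `(StatusCode, &'static str)` error, the real middleware could also forward it directly (`Err(e) => Err(e)`) so that an invalid token yields the 401 "Token Error" response instead of a generic 400.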