如何在自定义提取器中获取状态?

问题描述 投票:0回答:1

我正在尝试使用 Axum 的自定义提取器来实现 JWT 身份验证。在提取器内部,我可以打印出 State 的内容,但无法把它当作 &str 来使用(例如调用 as_bytes())。下面是我编写的 JWT 验证代码。你能告诉我应该如何修改,才能在提取器中读取状态吗?

//main.rs
use axum::body::Bytes;
use axum::extract::{Json, Request, State};
use axum::{
    routing::{get, post},
    Router,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod authentication_token;
use authentication_token::ExtractAuthorization;

/// Credentials sent by the client as the JSON body of the `register` handler.
///
/// `PartialEq` is derived so the posted credentials can be compared against
/// the stored user with `==`.
#[derive(Deserialize, Debug, PartialEq)]
struct User {
    /// Numeric account identifier; copied into `Claims::id` when a token is issued.
    account: usize,
    /// Password, compared verbatim against the stored value.
    password: String,
}

/// Payload carried inside the JWT: encoded by `register`, decoded by `login`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct Claims {
    /// Account number of the authenticated user.
    pub id: usize,
    /// Expiry time in seconds since the Unix epoch (set by `register` to now + 30 minutes).
    pub exp: usize,
}

/// Issues a JWT for the posted credentials.
///
/// When the JSON body matches the hard-coded account (`195` / `"world"`), the
/// response body is a token signed with the router state as the secret and
/// expiring 30 minutes from now. Any other credentials get the plain string
/// `"hello, world!"` instead.
///
/// # Panics
/// Panics if the computed expiry lies before the Unix epoch or if encoding
/// the token fails.
async fn register(State(secret): State<&str>, Json(user): Json<User>) -> String {
    let known_user = User {
        account: 195,
        password: "world".to_string(),
    };

    // Guard clause: unknown credentials never reach the token-signing path.
    if user != known_user {
        return "hello, world!".to_string();
    }

    let expires_at = (SystemTime::now() + Duration::from_secs(30 * 60))
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let claims = Claims {
        id: user.account,
        exp: expires_at as usize,
    };
    let signing_key = EncodingKey::from_secret(secret.as_bytes());

    encode(&Header::default(), &claims, &signing_key).unwrap()
}

/// Decodes the JWT found in the `Authorization` header and echoes its claims as JSON.
///
/// The header value is used as the token verbatim and is verified with HS256,
/// using the router state as the secret.
///
/// # Panics
/// Panics if the header is missing, is not valid visible ASCII, or does not
/// decode as a valid token.
async fn login(State(secret): State<&str>, req: Request) -> Json<Claims> {
    let header_value = req.headers().get("Authorization").unwrap();
    let raw_token = header_value.to_str().unwrap();

    let verification_key = DecodingKey::from_secret(secret.as_bytes());
    let validation = Validation::new(Algorithm::HS256);
    let decoded = decode::<Claims>(raw_token, &verification_key, &validation).unwrap();

    Json(decoded.claims)
}

/// Handler guarded by the `ExtractAuthorization` extractor.
///
/// The extractor runs before the body of this function, so reaching it means
/// the request carried an acceptable token. The request is dumped to stdout
/// and a fixed greeting is returned.
async fn protected(_auth_token: ExtractAuthorization, req: Request) -> String {
    println!("{req:?}");
    String::from("World!")
}

/// Sets up tracing, builds the router and serves it on `127.0.0.1:5000`.
///
/// # Panics
/// Panics if the listener cannot be bound, its local address cannot be read,
/// or the server exits with an error.
#[tokio::main]
async fn main() {
    // Emit `debug`-level (and more severe) events through the fmt layer.
    let env_filter = tracing_subscriber::EnvFilter::new("debug");
    tracing_subscriber::registry()
        .with(env_filter)
        .with(tracing_subscriber::fmt::layer())
        .init();

    // HTTP tracing that additionally logs every streamed body chunk.
    let trace_layer = TraceLayer::new_for_http().on_body_chunk(
        |chunk: &Bytes, latency: Duration, _span: &Span| {
            tracing::debug!("streaming {} bytes in {:?}", chunk.len(), latency);
        },
    );

    // The string handed to `with_state` is the JWT secret read by the handlers.
    let app = Router::new()
        .route("/register", post(register))
        .route("/login", post(login))
        .route("/protected", get(protected))
        .with_state("baby195lxl")
        .layer(trace_layer);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:5000")
        .await
        .unwrap();
    let bound_addr = listener.local_addr().unwrap();
    tracing::debug!("listening on {}", bound_addr);

    axum::serve(listener, app).await.unwrap();
}
//authentication_token.rs
use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{header::AUTHORIZATION, request::Parts, StatusCode},
};
use jsonwebtoken::{
    decode, errors::Error as JwtError, Algorithm, DecodingKey, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// Result of a successful JWT check on a request's `Authorization` header.
///
/// Taking this type as a handler argument makes the handler reachable only
/// when the token decodes successfully; the decoded account id is exposed
/// through [`ExtractAuthorization::id`] so the handler can tell who is calling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractAuthorization {
    /// The `id` claim taken from the verified token.
    pub id: usize,
}

/// JWT payload decoded by the `ExtractAuthorization` extractor.
///
/// Has the same fields as the `Claims` struct declared in `main.rs`.
#[derive(Serialize, Deserialize)]
pub struct Claims {
    /// Account number of the token's owner.
    pub id: usize,
    /// Expiry time of the token — presumably seconds since the Unix epoch, as written by
    /// `register` in `main.rs`.
    pub exp: usize,
}

#[async_trait]
impl<S> FromRequestParts<S> for ExtractAuthorization
where
    S: Send + Sync + Debug,
{
    type Rejection = (StatusCode, &'static str);
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        println!("{:?}", state);   // <-- I can print the content
        let auth_header = parts.headers.get(AUTHORIZATION);
        if auth_header.is_none() {
            return Err((StatusCode::BAD_REQUEST, "Authorization is Missing"));
        }
        let auth_token: String = auth_header.unwrap().to_str().unwrap_or("").to_string();

        if auth_token.is_empty() {
            return Err((
                StatusCode::BAD_REQUEST,
                "Authentication token has foreign chars!",
            ));
        }
        let token_result: Result<TokenData<Claims>, JwtError> = decode::<Claims>(
            &auth_token,
            &DecodingKey::from_secret(state.as_bytes()),  // <-- but i cant use it as &str
            &Validation::new(Algorithm::HS256),
        );
        match token_result {
            Ok(token) => Ok(ExtractAuthorization {
                id: token.claims.id,
            }),
            Err(_) => Err((StatusCode::BAD_REQUEST, "Token Error")),
        }
    }
}

我还没有找到不会产生错误的解决方案。

Cargo.toml 的配置如下,我使用的 Rust 编译器版本是 rustc 1.76.0 (07dca489a 2024-02-04)。如果有人能帮我解开这个困惑,我将不胜感激,欢迎任何答复。谢谢。

[dependencies]
axum = "^0.7"
tokio = { version = "^1.36", features = ["full"] }
tower-http = { version = "^0.5", features = ["trace"] }
tracing = "^0.1"
tracing-subscriber = { version = "^0.3", features = ["env-filter"] }
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9.2.0"

rust jwt rust-axum
1个回答
0
投票

不要为所有状态通用地实现提取器,而是为您拥有的特定状态实现它:

// Implemented for the concrete state type `&str` instead of a generic `S`,
// so `state` has a known type inside the extractor.
#[async_trait]
impl<'a> FromRequestParts<&'a str> for ExtractAuthorization {
    type Rejection = (StatusCode, &'static str);
    // `state` is `&&'a str` here, so string methods such as `as_bytes()` are available on it.
    async fn from_request_parts(parts: &mut Parts, state: &&'a str) -> Result<Self, Self::Rejection> {
        // ... (body elided in the answer)
    }
}
© www.soinside.com 2019 - 2024. All rights reserved.