How do I get the headers when using tokio-tungstenite?


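I'm trying to use the tokio-tungstenite crate to create chat rooms based on the URL. For example, I have a client connecting to ws://localhost:8080/abcd. My understanding is that I have to use the tokio_tungstenite::accept_hdr_async function to access the headers in order to get the /abcd path, but I'm having trouble using it. What should the second argument of copy_headers_callback be?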

My code is based on this example:

use std::{
    collections::HashMap,
    env,
    io::Error as IoError,
    net::SocketAddr,
    sync::{Arc, Mutex},
    marker::Unpin,
};

use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};

use tokio::net::{TcpListener, TcpStream};
use tungstenite::{
    protocol::Message,
    handshake::server::{Request},
};

type Sender = UnboundedSender<Message>;
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Sender>>>;

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct BroadcastJsonStruct {
    message: String,
    sender_addr: SocketAddr,
}

async fn handle_connection(peer_map: PeerMap, raw_stream: TcpStream, client_addr: SocketAddr) {
    println!("Incoming TCP connection from: {}, raw stream: {}", client_addr, raw_stream.local_addr().unwrap());

    let copy_headers_callback = |request: &Request| {
            for (name, value) in request.headers().iter() {
                println!("Name: {}, value: ", name.to_string())
                // headers.insert(name.to_string(), value.clone());
            }
            Ok(None)
        };

    //accept a new asynchronous WebSocket connection
    let ws_stream = tokio_tungstenite::accept_hdr_async(
        raw_stream,
        copy_headers_callback,
    )
        .await
        .expect("Error during the websocket handshake occurred");
    println!("WebSocket connection established: {}", client_addr);

    // Insert the write part of this peer to the peer map.
    let (sender, receiver) = unbounded();
    peer_map.lock().unwrap().insert(client_addr, sender);

    //set up the incoming and outgoing
    let (outgoing, incoming) = ws_stream.split();

    let broadcast_incoming = incoming.try_for_each(|msg| {
        println!("Received a message from {}: {}", client_addr, msg.to_text().unwrap());
        let peers = peer_map.lock().unwrap();

        //make a new struct to be serialized
        let broadcast_data = BroadcastJsonStruct {
            message: msg.to_text().unwrap().to_owned(),
            sender_addr: client_addr.to_owned(),
        };
        let new_msg = Message::Text(
            serde_json::to_string(&broadcast_data).expect("problem serializing broadcast_data")
        );
        println!("New message {}", new_msg.to_text().unwrap());

        // We want to broadcast the message to everyone except ourselves.
        //filter returns addresses that aren't our current address
        let broadcast_recipients =
            peers.iter().filter(|(peer_addr, _)| peer_addr != &&client_addr).map(|(_, ws_sink)| ws_sink);

        //send the message to all the recipients
        for recp in broadcast_recipients {
            recp.unbounded_send(new_msg.clone()).unwrap();
        }

        future::ok(())
    });

    let receive_from_others = receiver.map(Ok).forward(outgoing);

    pin_mut!(broadcast_incoming, receive_from_others);
    future::select(broadcast_incoming, receive_from_others).await;

    println!("{} disconnected", &client_addr);
    peer_map.lock().unwrap().remove(&client_addr);
}

#[tokio::main]
async fn main() -> Result<(), IoError> {
    //see if there is a server address specified in the command line argument, else use default address
    let server_addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());

    let state = PeerMap::new(Mutex::new(HashMap::new()));

    // Create the event loop and TCP listener we'll accept connections on.
    let try_socket = TcpListener::bind(&server_addr).await; //create the server on the address
    let listener = try_socket.expect("Failed to bind");
    println!("Listening on: {}", server_addr);

    // Let's spawn the handling of each connection in a separate task.
    while let Ok((stream, client_addr)) = listener.accept().await {
        tokio::spawn(handle_connection(state.clone(), stream, client_addr));
    }

    Ok(())
}

My error:

error[E0593]: closure is expected to take 2 arguments, but it takes 1 argument
   --> src/main.rs:45:9
    |
34  |     let copy_headers_callback = |request: &Request| {
    |                                 ------------------- takes 1 argument
...
45  |         copy_headers_callback,
    |         ^^^^^^^^^^^^^^^^^^^^^ expected closure that takes 2 arguments
    | 
   ::: /Users/harryli/.cargo/registry/src/github.com-1ecc6299db9ec823/tokio-tungstenite-0.12.0/src/lib.rs:151:8
    |
151 |     C: Callback + Unpin,
    |        -------- required by this bound in `accept_hdr_async`
    |
    = note: required because of the requirements on the impl of `Callback` for `[closure@src/main.rs:34:33: 40:10]`

error: aborting due to previous error; 1 warning emitted

I'm not sure whether my approach makes sense. I'm new to Rust, so I'd appreciate any insight!

rust websocket header rust-tokio
2 answers
6 votes

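I eventually got the callback working so that it prints the header names and values, as shown below.

I did not find the path /abcd in the headers. My workaround was to pass window.location.pathname as the WebSocket protocol and then read that protocol on the server.

The code below is a working example of a server that broadcasts messages only to clients using the same protocol.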

JS

const socket = new WebSocket("ws://localhost:8080", window.location.pathname.replace(/\//ig, "-"));

Rust callback:

let mut protocol = HeaderValue::from_static("");

    let copy_headers_callback = |request: &Request, mut response: Response| -> Result<Response, ErrorResponse> {
        for (name, value) in request.headers().iter() {
            println!("Name: {}, value: {}", name.to_string(), value.to_str().expect("expected a value"));
        }

        //access the protocol in the request, then set it in the response
        protocol = request.headers().get(SEC_WEBSOCKET_PROTOCOL).expect("the client should specify a protocol").to_owned(); //save the protocol to use outside the closure
        let response_protocol = request.headers().get(SEC_WEBSOCKET_PROTOCOL).expect("the client should specify a protocol").to_owned();
        response.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, response_protocol);
        Ok(response)
    };

    //accept a new asynchronous WebSocket connection
    let ws_stream = tokio_tungstenite::accept_hdr_async(
        raw_stream,
        copy_headers_callback,
    )
        .await
        .expect("Error during the websocket handshake occurred");

Full Rust code:

use std::{
    collections::HashMap,
    env,
    io::Error as IoError,
    net::SocketAddr,
    sync::{Arc, Mutex},
};

use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};

use tokio::net::{TcpListener, TcpStream};
use tungstenite::{
    protocol::Message,
    handshake::server::{Request, Response, ErrorResponse},
};
use http::header::{
    HeaderValue,
    SEC_WEBSOCKET_PROTOCOL,
};

type Sender = UnboundedSender<Message>;
struct PeerStruct {
    protocol: HeaderValue,
    sender: Sender,
}

type PeerMap = Arc<Mutex<HashMap<SocketAddr, PeerStruct>>>;

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct BroadcastJsonStruct {
    message: String,
    sender_addr: SocketAddr,
}


async fn handle_connection(peer_map: PeerMap, raw_stream: TcpStream, client_addr: SocketAddr) {
    println!("Incoming TCP connection from: {}, raw stream: {}", client_addr, raw_stream.local_addr().unwrap());
    let mut protocol = HeaderValue::from_static("");

    let copy_headers_callback = |request: &Request, mut response: Response| -> Result<Response, ErrorResponse> {
        for (name, value) in request.headers().iter() {
            println!("Name: {}, value: {}", name.to_string(), value.to_str().expect("expected a value"));
        }

        //access the protocol in the request, then set it in the response
        protocol = request.headers().get(SEC_WEBSOCKET_PROTOCOL).expect("the client should specify a protocol").to_owned(); //save the protocol to use outside the closure
        let response_protocol = request.headers().get(SEC_WEBSOCKET_PROTOCOL).expect("the client should specify a protocol").to_owned();
        response.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, response_protocol);
        Ok(response)
    };

    //accept a new asynchronous WebSocket connection
    let ws_stream = tokio_tungstenite::accept_hdr_async(
        raw_stream,
        copy_headers_callback,
    )
    .await
    .expect("Error during the websocket handshake occurred");
    println!("WebSocket connection established: {}", client_addr);

    // Insert the write part of this peer to the peer map.
    let (sender, receiver) = unbounded();
    peer_map.lock().unwrap().insert(client_addr, PeerStruct {
        protocol: protocol.to_owned(),
        sender: sender,
    });

    //set up the incoming and outgoing
    let (outgoing, incoming) = ws_stream.split();

    //this function broadcasts messages to all other connected clients using the same protocol
    let broadcast_incoming = incoming.try_for_each(|msg| {
        println!("Received a message from {}: {}", client_addr, msg.to_text().unwrap());
        let peers = peer_map.lock().unwrap();

        //make a new struct to be serialized
        let broadcast_data = BroadcastJsonStruct {
            message: msg.to_text().unwrap().to_owned(),
            sender_addr: client_addr.to_owned(),
        };
        let new_msg = Message::Text(
            serde_json::to_string(&broadcast_data).expect("problem serializing broadcast_data")
        );
        println!("New message {}", new_msg.to_text().unwrap());

        //filter addresses that aren't the message sender's address AND are using the same protocol
        let broadcast_recipients = peers.iter().filter(
            |(peer_addr, _)|
            peer_addr != &&client_addr
            && peers.get(peer_addr).expect("peer_addr should be a key in the HashMap").protocol.to_str().expect("expected a string")==protocol.to_str().expect("expected a string")
        ).map(|(_, ws_sink)| ws_sink);

        //send the message to all the recipients
        for recp in broadcast_recipients {
            recp.sender.unbounded_send(new_msg.clone()).unwrap();
        }

        future::ok(())
    });

    let receive_from_others = receiver.map(Ok).forward(outgoing);

    pin_mut!(broadcast_incoming, receive_from_others);
    future::select(broadcast_incoming, receive_from_others).await;

    println!("{} disconnected", &client_addr);
    peer_map.lock().unwrap().remove(&client_addr);
}

#[tokio::main]
async fn main() -> Result<(), IoError> {
    //see if there is a server address specified in the command line argument, else use default address
    let server_addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());

    let state = PeerMap::new(Mutex::new(HashMap::new()));

    // Create the event loop and TCP listener we'll accept connections on.
    let try_socket = TcpListener::bind(&server_addr).await; //create the server on the address
    let listener = try_socket.expect("Failed to bind");
    println!("Listening on: {}", server_addr);

    // Let's spawn the handling of each connection in a separate task.
    while let Ok((stream, client_addr)) = listener.accept().await {
        tokio::spawn(handle_connection(state.clone(), stream, client_addr));
    }

    Ok(())
}

Cargo.toml

[package]
name = "server"
version = "0.1.0"
authors = ["harryli0088 <[email protected]>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio-tungstenite = "*"
tokio = { version = "0.3", features = ["full"] }
futures-channel = "0.3"
futures-util = "0.3.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
http = "0.2.2"

[dependencies.tungstenite]
version = "0.11.1"
default-features = false
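
The filter in the full code above walks every connected peer and compares protocol strings for each message. One possible alternative, shown here only as a sketch, is to index peers by room name so that a broadcast touches a single room. `Rooms`, `join`, `leave`, and `broadcast` are hypothetical names, and `std::sync::mpsc` with `String` messages stands in for the `futures_channel` sender and the tungstenite `Message` used in the answer.

```rust
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::mpsc::Sender;

/// Peers grouped by room name (for example a protocol or path such as "/abcd").
/// Hypothetical type for illustration; not part of tokio-tungstenite.
#[derive(Default)]
struct Rooms {
    by_room: HashMap<String, HashMap<SocketAddr, Sender<String>>>,
}

impl Rooms {
    /// Registers a peer's sending half under the given room.
    fn join(&mut self, room: &str, addr: SocketAddr, sender: Sender<String>) {
        self.by_room
            .entry(room.to_owned())
            .or_default()
            .insert(addr, sender);
    }

    /// Removes a peer and drops the room once it is empty.
    fn leave(&mut self, room: &str, addr: &SocketAddr) {
        if let Some(peers) = self.by_room.get_mut(room) {
            peers.remove(addr);
            if peers.is_empty() {
                self.by_room.remove(room);
            }
        }
    }

    /// Sends `msg` to every peer in `room` except `from`.
    /// Returns the number of peers the message was queued for.
    fn broadcast(&self, room: &str, from: &SocketAddr, msg: &str) -> usize {
        let mut delivered = 0;
        if let Some(peers) = self.by_room.get(room) {
            for (addr, tx) in peers {
                // Skip the sender; count only sends whose receiver is still alive.
                if addr != from && tx.send(msg.to_owned()).is_ok() {
                    delivered += 1;
                }
            }
        }
        delivered
    }
}
```

In the server above, a value like this would sit behind the same Arc<Mutex<...>> as PeerMap, with join called after the handshake and leave called on disconnect.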

0 votes

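This question is fairly old, but for completeness: your own answer was very close.

"I did not find the path /abcd in the headers."

It is not in the headers; it is available through

tungstenite::handshake::server::Request::uri()

Inside the callback, request.uri().path() returns /abcd.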
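
The path is missing from the header list because, in the HTTP upgrade request that opens a WebSocket connection, the path is the request target on the first line rather than a header field. The stdlib-only sketch below illustrates this on a hand-written request. `split_handshake` is a hypothetical helper, not part of tungstenite, and the sample request in the comment is an assumption about what a client would send.

```rust
/// Splits a raw HTTP upgrade request into its request target (the path)
/// and its "Name: value" header pairs.
///
/// Hypothetical helper for illustration only. A request such as
/// "GET /abcd HTTP/1.1\r\nHost: localhost:8080\r\nUpgrade: websocket\r\n\r\n"
/// carries "/abcd" on the first line, not in any header.
fn split_handshake(raw: &str) -> Option<(&str, Vec<(&str, &str)>)> {
    let mut lines = raw.split("\r\n");

    // The first line is the request line: "<method> <target> <version>".
    let request_line = lines.next()?;
    let mut parts = request_line.split(' ');
    let _method = parts.next()?;
    let target = parts.next()?;

    // Every line up to the first empty one is a header field.
    let headers = lines
        .take_while(|line| !line.is_empty())
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .collect();

    Some((target, headers))
}
```

With tungstenite this parsing is already done for you: the callback's request exposes the target through uri(), and the headers through headers().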
