使用 Rust 实现可跨域访问的 WebSocket 服务端

前言

在浏览器里使用 WebSocket 时,连接往往来自与接口不同源的页面,例如前端跑在 http://localhost:5173,而服务监听 http://127.0.0.1:3000
浏览器会把这次连接当作跨源请求,并在握手阶段的 HTTP 请求里带上 Origin 等头信息。
若服务端返回的握手响应缺少浏览器期望的 CORS 相关响应头,连接可能被拒绝或表现异常。
本文用 axum 提供 WebSocket 升级处理,用 tower-httpCorsLayer 统一为 HTTP 与升级响应附加跨源策略,便于本地或前后端分离场景联调。
文中还会在进程内维护「用户名到连接数」的表,连接建立时加一、连接结束时减一,并用只读接口查看当前有多少个不同用户在线。
读完你可以得到一个最小可运行的回声服务,并知道如何把「允许任意源」收紧为白名单。

依赖

新建二进制项目后,在 Cargo.toml 中加入下文列出的依赖。
版本号可按团队锁定策略调整,下列为撰写时的示例组合。

1
2
3
4
5
6
7
8
9
10
[package]
name = "ws-cors-demo"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = { version = "0.8", features = ["ws"] }
tokio = { version = "1", features = ["full"] }
tower-http = { version = "0.6", features = ["cors"] }
serde = { version = "1", features = ["derive"] }

上述各依赖的大致分工如下。

  • axum:Web 框架,提供 Router、路由与提取器;开启 ws 特性后支持 WebSocketUpgrade,完成 HTTP 到 WebSocket 的升级与后续帧读写。
  • tokio:异步运行时,驱动 async/await、监听 TcpListener、运行 axum::serve,并提供 tokio::sync::RwLock 等可在异步任务间共享的状态原语。
  • tower-http:基于 tower 的 HTTP 中间件库;开启 cors 后使用 CorsLayer,在握手响应与普通 HTTP 响应上附加 Access-Control-* 等跨源相关头。
  • serde:序列化与反序列化基础设施;配合 derive 为结构体生成 Deserialize,供 Query 等提取器把查询串解析成 WsQuery 这类类型。

实现

在线用户与共享状态

在线用户统计放在 AppState 里,用 Arc<RwLock<HashMap<String, usize>>> 表示「每个 user 当前占几条 WebSocket 连接」。
同一账号打开多个标签页时,连接数会大于一,断开时做自减,减到零就从表里删掉,这样 HashMap::len 就是「当前至少有一条连接的不同用户数」。
登记写在进入消息循环之前,清理写在循环结束之后,这样不依赖在 Drop 里执行异步逻辑,路径简单清晰。
生产环境仍应对 user 做鉴权,本文用查询参数仅作演示,避免把示例写成长篇认证流程。

路由、跨源与完整示例

先组装带 AppStateRouter,为 /ws 注册 GET 升级处理,为 /online 提供只读统计,并挂上 CorsLayer
allow_origin(Any) 表示不校验来源、允许任意 Origin 连接,仅适合开发环境,生产环境应改为明确白名单。
with_state 写在 layer(cors) 之前,使跨源中间件包在外层,握手与普通 GET 都能带上 CORS 响应头。

下列代码为完整可替换的 src/main.rs,其中 /health 仍用于快速确认进程已监听。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::IntoResponse,
routing::get,
Router,
};
use serde::Deserialize;
use std::{
collections::HashMap,
net::SocketAddr,
sync::Arc,
};
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};

#[derive(Clone)]
struct AppState {
online: Arc<RwLock<HashMap<String, usize>>>,
}

#[derive(Deserialize)]
struct WsQuery {
user: String,
}

#[tokio::main]
async fn main() {
let state = AppState {
online: Arc::new(RwLock::new(HashMap::new())),
};

let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers(Any);

let app = Router::new()
.route("/ws", get(ws_handler))
.route("/health", get(|| async { "ok" }))
.route("/online", get(online_users_handler))
.with_state(state)
.layer(cors);

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}

async fn online_users_handler(State(state): State<AppState>) -> String {
let g = state.online.read().await;
let n = g.len();
format!("online_users: {n}")
}

async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(q): Query<WsQuery>,
) -> impl IntoResponse {
let user = q.user;
ws.on_upgrade(move |socket| handle_socket(socket, state, user))
}

async fn handle_socket(mut socket: WebSocket, state: AppState, user: String) {
{
let mut g = state.online.write().await;
*g.entry(user.clone()).or_insert(0) += 1;
}

while let Some(Ok(msg)) = socket.recv().await {
if let Message::Text(t) = msg {
let reply = format!("echo: {t}");
if socket.send(Message::Text(reply.into())).await.is_err() {
break;
}
}
}

let mut g = state.online.write().await;
if let Some(c) = g.get_mut(&user) {
*c = c.saturating_sub(1);
if *c == 0 {
g.remove(&user);
}
}
}

|socket| ... 是什么?

竖线里是这个匿名函数的参数。

只有 socket 是「调用方会传进来的」:升级完成后,Axum 把建好的 WebSocket 传给你,所以写在 |socket| 里。

写成:|socket| handle_socket(socket, state, user)

意思是:谁调用这个闭包,都要给我一个 socket;我拿到以后,就去调用 handle_socket(socket, state, user)

也就是说:handle_socket 需要三个参数,其中 socket 来自 Axum 稍后的调用,stateuser 来自当前这次请求里你已经有的值。

验证

  1. 在项目根目录执行 cargo run,确认终端无报错且监听 127.0.0.1:3000
  2. 浏览器或使用 curl 访问 http://127.0.0.1:3000/health,应返回纯文本 ok
  3. 再访问 http://127.0.0.1:3000/online,未连接任何 WebSocket 时应看到 online_users: 0
  4. 用另一端口或本地静态页发起跨源连接,地址需带 user 查询参数,打开开发者工具查看是否收到 echo: 前缀消息,并在连接保持期间刷新 /online 观察计数变化。

下面页面在连接时使用 user=demo,若你改了端口或路径请同步修改 url 变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
<!DOCTYPE html>
<html>

<head>
<meta charset="UTF-8">
<title>WebSocket 聊天</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}

body {
font-family: Arial, sans-serif;
max-width: 600px;
margin: 20px auto;
padding: 0 20px;
}

h1 {
text-align: center;
color: #333;
margin-bottom: 10px;
}

.status {
text-align: center;
margin-bottom: 15px;
padding: 8px;
border-radius: 4px;
}

.status.connected {
background: #d4edda;
color: #155724;
}

.status.disconnected {
background: #f8d7da;
color: #721c24;
}

#url-display {
text-align: center;
font-size: 14px;
color: #666;
margin-bottom: 15px;
word-break: break-all;
}

#chat {
border: 1px solid #ddd;
height: 400px;
overflow-y: auto;
padding: 15px;
background: #f9f9f9;
}

.msg {
margin-bottom: 10px;
padding: 8px 12px;
border-radius: 8px;
max-width: 80%;
}

.msg.sent {
background: #007bff;
color: white;
margin-left: auto;
}

.msg.received {
background: #e9ecef;
color: #333;
}

.msg .time {
font-size: 10px;
opacity: 0.7;
margin-top: 4px;
}

.input-area {
display: flex;
gap: 10px;
margin-top: 15px;
}

#msg-input {
flex: 1;
padding: 12px;
border: 1px solid #ddd;
outline: none;
border-radius: 4px;
font-size: 16px;
}

#send-btn {
padding: 12px 24px;
background: #007bff;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
}

#send-btn:hover {
background: #0056b3;
}

#send-btn:disabled {
background: #ccc;
cursor: not-allowed;
}
</style>
</head>

<body>
<h1>WebSocket 聊天</h1>
<div id="url-display">连接: <span id="ws-url"></span></div>
<div id="status" class="status disconnected">未连接</div>
<div id="chat"></div>
<div class="input-area">
<input type="text" id="msg-input" placeholder="输入消息..." disabled>
<button id="send-btn" disabled>发送</button>
</div>

<script>
const wsUrl = "ws://127.0.0.1:3000/ws?user=demo";
document.getElementById('ws-url').textContent = wsUrl;

const ws = new WebSocket(wsUrl);
const chat = document.getElementById('chat');
const status = document.getElementById('status');
const input = document.getElementById('msg-input');
const sendBtn = document.getElementById('send-btn');

ws.onopen = () => {
status.textContent = '已连接';
status.className = 'status connected';
input.disabled = false;
sendBtn.disabled = false;
input.focus();
};

ws.onclose = () => {
status.textContent = '连接已断开';
status.className = 'status disconnected';
input.disabled = true;
sendBtn.disabled = true;
};

ws.onerror = () => {
status.textContent = '连接错误';
status.className = 'status disconnected';
};

ws.onmessage = (ev) => addMsg(ev.data, 'received');

sendBtn.onclick = sendMsg;
input.onkeypress = (e) => { if (e.key === 'Enter') sendMsg(); };

function sendMsg () {
const text = input.value.trim();
if (!text) return;
ws.send(text);
addMsg(text, 'sent');
input.value = '';
input.focus();
}

function addMsg (text, type) {
const div = document.createElement('div');
div.className = 'msg ' + type;
const time = new Date().toLocaleTimeString();
div.innerHTML = `<div>${text}</div><div class="time">${time}</div>`;
chat.appendChild(div);
chat.scrollTop = chat.scrollHeight;
}
</script>
</body>

</html>

扩展

若只需允许固定前端地址,可将 allow_origin(Any) 换成 AllowOrigin::list 或自定义谓词,避免任意网站都能连上你的 WebSocket。
对外部署时务必配合 TLS,使用 wss://,并在反向代理上正确转发 UpgradeConnection 头。
业务层可在此之上增加鉴权,例如在查询串或首条消息里校验令牌,但不要依赖前端传来的 Origin 作为唯一安全依据。
多实例部署时,进程内 HashMap 彼此不可见,在线用户应落到 Redis 等外部存储,并用 TTL 或心跳处理僵死连接。

总结

步骤:

  1. 引入 axum(含 ws)、tokiotower-http(含 cors)、serde(含 derive)。
  2. 定义 AppStateArc<RwLock<HashMap<String, usize>>>,在连接建立与断开时分别增减计数。
  3. Routerwith_statelayer(CorsLayer),覆盖 GET 与 OPTIONS 等握手所需方法。
  4. WebSocketUpgradeQuery 解析 user,在 handle_socket 中读写 Message 并保证退出时减掉连接数。

注意:

  • allow_origin(Any) 仅适合联调,上线请改为白名单。
  • 文本消息发送时注意 Message::Text 的目标类型,必要时使用 into()
  • 跨源问题既涉及响应头,也涉及浏览器安全策略,异常时优先在开发者工具网络面板查看握手请求与响应头。
  • /online 统计的是「至少有一条连接的不同 user 数量」,若要推送昵称列表或踢人,需要在此基础上扩展结构与协议。