我有一个用 ASP.NET Core 7 构建的 API,其中有一个 MediatR 控制器,用于读取消息并作出响应。客户端是从 ReactJS 网站发起连接的。我计划将客户端站点多次部署到不同的域名上。
我正在寻找一种最不容易被伪造的方式,来获取客户端发起连接时所在的域名。我事先知道哪些域名是允许连接的,因此如果有帮助的话,也可以使用 API 密钥。虽然安全性不是主要问题,但了解是哪些域名在连接,对我的指标统计很有帮助。
到目前为止,我考虑过使用 API 密钥、自定义标头(例如 Referer 标头),或者读取 CORS 相关的标头。
我为你创建了一个非常简单的项目。请先检查测试结果。
我建议您可以生成一个临时安全密钥并在创建集线器连接时附加它。就像我的示例代码中的 112233 。
客户端创建连接
// Open the hub connection; uid/sid identify the caller and api-key carries the temporary key
// that the server validates in OnConnectedAsync.
var connection = new signalR.HubConnectionBuilder()
    .withUrl("/mainHub?uid=jason&sid=test&api-key=112233")
    .configureLogging(signalR.LogLevel.Trace)
    .build();
然后你可以在
OnConnectedAsync
中验证它。当然,你也可以将临时密钥存储在 Redis 缓存或数据库中,并在它被使用后将其删除,这样可以更安全地保护你的连接。
/// <summary>
/// Accepts a new hub connection only when the request's <c>api-key</c> query value matches the
/// expected temporary key and a non-empty <c>uid</c> query value is present; the connection id is
/// then recorded in <c>ConnectedUsers</c> under that user id.
/// </summary>
/// <remarks>
/// Rejected connections are aborted and are NOT registered. The key is hard-coded for this demo;
/// in practice look it up in Redis or a database and delete it once it has been used.
/// The system id can be read the same way via <c>Request.Query["sid"]</c> if it is needed.
/// </remarks>
public override async Task OnConnectedAsync()
{
    // The HttpContext of the request that opened the connection; it may be null, hence the ?. below.
    var httpContext = Context.GetHttpContext();
    string? apiKey = httpContext?.Request.Query["api-key"].ToString();
    string? uid = httpContext?.Request.Query["uid"].ToString();

    // retrieve apiKey from db/redis and verify it
    if (apiKey != "112233")
    {
        // Return immediately: an aborted connection must not fall through to the registration below.
        Context.Abort();
        return;
    }

    if (string.IsNullOrEmpty(uid))
    {
        Trace.TraceInformation("userid is required, can't connect signalr service");
        // Without a user id the connection cannot be tracked, so close it instead of leaving it open.
        Context.Abort();
        return;
    }

    Trace.TraceInformation($"{uid} connected");

    // NOTE(review): ConnectedUsers is declared outside this view. The TryGetValue/TryAdd pair and the
    // List<string>.Add are not atomic; if it is a ConcurrentDictionary, consider GetOrAdd and
    // synchronizing access to each per-user list.
    ConnectedUsers.TryGetValue(uid, out List<string>? connectionIds);
    connectionIds ??= new List<string>();
    connectionIds.Add(Context.ConnectionId);
    ConnectedUsers.TryAdd(uid, connectionIds);

    await base.OnConnectedAsync();
}
您可以使用
// Request.Host is this server's own host name; the site the browser page was loaded from is reported
// in the Origin header, so take the host part of that instead (null when absent or not a valid URL).
var domain = Uri.TryCreate(context?.Request.Headers.Origin.ToString(), UriKind.Absolute, out var origin)
    ? origin.Host
    : null;
来检索 HubLogFilter 内的域。以下是示例代码。
using Microsoft.AspNetCore.SignalR;
using System.Text.Json;
using System.Text;
namespace AspNetCore_SignalR
{
    /// <summary>
    /// Hub filter that rejects hub method invocations whose browser Origin host is not in an
    /// allow-list, and logs path, method, caller, timing and argument size for each invocation.
    /// </summary>
    public class HubLogFilter : IHubFilter
    {
        private readonly ILogger<HubLogFilter> _logger;

        // Host names of the client sites allowed to call the hub. Host names are case-insensitive,
        // so the set compares ordinally ignoring case. Consider loading these from configuration.
        private readonly HashSet<string> allowedDomains;

        public HubLogFilter(ILogger<HubLogFilter> logger)
        {
            _logger = logger;
            allowedDomains = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
            {
                "a.com",
                "b.com",
                "localhost"
            };
        }

        /// <summary>
        /// Blocks the call (aborting the connection and returning <c>null</c>) when the caller's Origin
        /// host is not allowed; otherwise invokes the hub method, logs the invocation, and returns the
        /// hub method's own result unchanged.
        /// </summary>
        /// <param name="invocationContext">The hub invocation being filtered.</param>
        /// <param name="next">The next filter or the hub method itself.</param>
        /// <returns>The hub method's return value, or <c>null</c> when the call was blocked.</returns>
        /// <exception cref="Exception">Any exception thrown by <paramref name="next"/> is logged and rethrown.</exception>
        public async ValueTask<object?> InvokeMethodAsync(
            HubInvocationContext invocationContext,
            Func<HubInvocationContext, ValueTask<object?>> next
        )
        {
            var stopwatch = System.Diagnostics.Stopwatch.StartNew();
            var context = invocationContext.Context.GetHttpContext();
            var remoteIp = GetRemoteIpAddress(context);
            var userId = invocationContext.Context.UserIdentifier ?? "NonUser";

            // Request.Host would be this server's host name; the client site is in the Origin header.
            var domain = GetOriginHost(context);

            if (!IsDomainAllowed(domain))
            {
                _logger.LogWarning("Blocked WebSocket connection from {Domain}", domain);
                // Abort the SignalR connection itself; the HttpContext may be null here.
                invocationContext.Context.Abort();
                return null;
            }

            try
            {
                // The hub method's result must reach the client unchanged, so it is never overwritten.
                var result = await next(invocationContext);

                var elapsed = stopwatch.Elapsed;
                var contentLength = CalculateContentLength(invocationContext.HubMethodArguments);

                _logger.LogInformation(
                    "WebSocket {RequestPath}/{HubMethodName} {RemoteIpAddress} {UserId} executed in {Elapsed:0.0000} ms with content length {ContentLength}",
                    context?.Request.Path,
                    invocationContext.HubMethodName,
                    remoteIp,
                    userId,
                    elapsed.TotalMilliseconds,
                    contentLength
                );
                return result;
            }
            catch (Exception ex)
            {
                // Pass the exception object so the logger records the full stack trace.
                _logger.LogError(ex, "Exception calling '{HubMethodName}'", invocationContext.HubMethodName);
                throw;
            }
        }

        /// <summary>
        /// Returns the client IP: the first entry of X-Forwarded-For when present, otherwise the
        /// socket's remote address, otherwise "Unknown".
        /// </summary>
        /// <remarks>
        /// X-Forwarded-For is supplied by the client unless a trusted proxy overwrites it; prefer the
        /// ForwardedHeaders middleware with known proxies when the value matters.
        /// </remarks>
        private static string GetRemoteIpAddress(HttpContext? context)
        {
            string? forwardedFor = context?.Request.Headers["X-Forwarded-For"].ToString();
            if (!string.IsNullOrWhiteSpace(forwardedFor))
            {
                // The header may carry a chain "client, proxy1, proxy2"; the first entry is the client.
                return forwardedFor.Split(',', 2)[0].Trim();
            }

            return context?.Connection.RemoteIpAddress?.ToString() ?? "Unknown";
        }

        /// <summary>
        /// Returns the host part of the request's Origin header, or <c>null</c> when the header is
        /// missing or is not an absolute URL (e.g. the literal "null").
        /// </summary>
        /// <remarks>
        /// Browsers set Origin themselves, but non-browser clients can omit or forge it, so this
        /// identifies well-behaved browser clients only.
        /// </remarks>
        private static string? GetOriginHost(HttpContext? context)
        {
            string? origin = context?.Request.Headers.Origin.ToString();
            return Uri.TryCreate(origin, UriKind.Absolute, out var originUri) ? originUri.Host : null;
        }

        /// <summary>Returns the UTF-8 byte length of the JSON serialization of the hub arguments.</summary>
        private static int CalculateContentLength(IReadOnlyList<object?> arguments)
        {
            var serializedArgs = JsonSerializer.Serialize(arguments);
            return Encoding.UTF8.GetByteCount(serializedArgs);
        }

        /// <summary>Returns true when <paramref name="domain"/> is non-null and in the allow-list.</summary>
        private bool IsDomainAllowed(string? domain)
        {
            return domain is not null && allowedDomains.Contains(domain);
        }
    }
}
我的 Program.cs:
// Register SignalR and attach HubLogFilter as a global hub filter.
builder.Services.AddSignalR(options => options.AddFilter<HubLogFilter>());