如何在 SignalR 中获取所连接客户端的域名

问题描述 投票:0回答:1

我有一个基于 ASP.NET Core 7 构建的 API,其中有一个 MediatR 控制器,用于读取消息并作出响应。客户端是一个 ReactJS 网站,我计划把这个客户端站点分别部署到多个不同的域名上。

我想找到一种最难以伪造的方式来获取客户端所在的域名。我事先知道允许哪些域名连接,所以如果有帮助的话,也可以使用 API 密钥。安全性并不是主要问题,但知道有哪些域名在连接,对我的指标统计很有帮助。

到目前为止,我考虑过的方案有:使用 API 密钥、使用自定义标头(如 Referer 标头),或者读取 CORS 相关的标头(如 Origin)。

c# asp.net-core signalr
1个回答
0
投票

我为你创建了一个非常简单的项目。请先检查测试结果。

建议生成一个临时安全密钥,并在创建集线器(Hub)连接时将其附加到 URL 上,就像下面示例代码中的 112233。

客户端创建连接

// Build the hub connection. uid/sid/api-key travel in the query string because browsers
// cannot attach custom headers to a WebSocket handshake.
// NOTE(review): the api-key literal ships in client JavaScript, so anyone can read it;
// treat it as an identifier for metrics, not a secret.
var connection = new signalR.HubConnectionBuilder()
    .withUrl("/mainHub?uid=jason&sid=test&api-key=112233")
    .configureLogging(signalR.LogLevel.Trace)
    .build();

然后可以在 `OnConnectedAsync` 中验证该密钥。当然,你也可以把临时密钥存储在 Redis 缓存或数据库中,并在使用后将其删除,这样能更安全地保护连接。

    /// <summary>
    /// Called by SignalR when a client connects. Checks the <c>api-key</c> query-string value,
    /// requires a non-empty <c>uid</c>, and records the connection id under that user in
    /// <c>ConnectedUsers</c>.
    /// </summary>
    /// <remarks>
    /// The client connects with <c>/mainHub?uid=...&amp;sid=...&amp;api-key=...</c>.
    /// A connection that fails either check is aborted and is not recorded.
    /// </remarks>
    public override async Task OnConnectedAsync()
    {
        var query = Context.GetHttpContext()?.Request.Query;

        // Sample check only: in practice look the temporary key up in Redis or a database and
        // delete it once used. Return immediately after aborting so a rejected connection is
        // never registered and the rest of this method does not run for it.
        string? apiKey = query?["api-key"].ToString();
        if (apiKey != "112233")
        {
            Context.Abort();
            return;
        }

        string? userid = query?["uid"].ToString();
        if (string.IsNullOrEmpty(userid))
        {
            Trace.TraceInformation("userid is required, can't connect signalr service");
            Context.Abort();
            return;
        }

        Trace.TraceInformation($"{userid} connected");

        // Publish a fresh list for a first connection; otherwise append to the user's existing list.
        // Trying TryAdd first avoids the get-then-add race in which two simultaneous first
        // connections each create a list and one of them is silently dropped.
        // NOTE(review): ConnectedUsers is declared outside this snippet and holds List<string>
        // values, which are not thread-safe; the shared list is locked while it is mutated, and
        // any other code that touches these lists (e.g. OnDisconnectedAsync) must take the same lock.
        string connectionId = Context.ConnectionId;
        if (!ConnectedUsers.TryAdd(userid, new List<string> { connectionId })
            && ConnectedUsers.TryGetValue(userid, out List<string>? existUserConnectionIds)
            && existUserConnectionIds is not null)
        {
            lock (existUserConnectionIds)
            {
                existUserConnectionIds.Add(connectionId);
            }
        }

        await base.OnConnectedAsync();
    }

您可以使用

// Request.Host is the host name the client dialed (this API server), identical for every
// client site. Browsers report the site a page was loaded from in the Origin header, which
// page script cannot override (a non-browser client can still send any value).
var origin = context?.Request.Headers["Origin"].ToString();
var domain = Uri.TryCreate(origin, UriKind.Absolute, out var originUri) ? originUri.Host : null;
来获取 `HubLogFilter` 中的域名。注意:`Request.Host` 是客户端所访问的主机名(即 API 服务器自身),对所有客户端站点都相同;要识别客户端站点的域名,应读取浏览器发送的 `Origin` 标头。示例代码如下。

using System.Diagnostics;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.SignalR;

namespace AspNetCore_SignalR
{
    /// <summary>
    /// Global hub filter that rejects clients whose site is not in an allow-list and logs each
    /// hub method invocation (path, method, client IP, user, duration, argument payload size).
    /// </summary>
    /// <remarks>
    /// The client's site is identified from the <c>Origin</c> request header, falling back to
    /// <c>Referer</c>. <c>Request.Host</c> is not usable for this: it is the host name the client
    /// dialed, i.e. this API server, and is the same for every client site.
    /// Browsers set <c>Origin</c> themselves and page script cannot change it, but a non-browser
    /// client can send any value, so treat the result as a metrics signal, not authentication.
    /// </remarks>
    public class HubLogFilter : IHubFilter
    {
        private readonly ILogger<HubLogFilter> _logger;

        // Host names are case-insensitive, so compare them that way.
        private readonly HashSet<string> allowedDomains;

        public HubLogFilter(ILogger<HubLogFilter> logger)
        {
            _logger = logger;
            allowedDomains = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
            {
                "a.com",
                "b.com",
                "localhost"
            };
        }

        /// <summary>
        /// Aborts a new connection whose origin is not allowed, before the hub's own
        /// <c>OnConnectedAsync</c> runs; otherwise continues the pipeline.
        /// </summary>
        public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
        {
            var domain = GetClientDomain(context.Context.GetHttpContext());
            if (!IsDomainAllowed(domain))
            {
                _logger.LogWarning("Blocked WebSocket connection from {Domain}", domain);
                context.Context.Abort();
                return Task.CompletedTask;
            }

            return next(context);
        }

        /// <summary>
        /// Re-checks the caller's origin, runs the hub method, and logs its duration and the
        /// UTF-8 size of its JSON-serialized arguments.
        /// </summary>
        /// <returns>
        /// The hub method's own result, unchanged; <c>null</c> when the caller is blocked.
        /// </returns>
        public async ValueTask<object?> InvokeMethodAsync(
            HubInvocationContext invocationContext,
            Func<HubInvocationContext, ValueTask<object?>> next
        )
        {
            var stopwatch = Stopwatch.StartNew();
            var context = invocationContext.Context.GetHttpContext();
            var remoteIp = GetRemoteIpAddress(context);
            var userId = invocationContext.Context.UserIdentifier ?? "NonUser";
            var domain = GetClientDomain(context);

            if (!IsDomainAllowed(domain))
            {
                _logger.LogWarning("Blocked WebSocket connection from {Domain}", domain);
                // Abort the SignalR connection itself; the HttpContext can be null here.
                invocationContext.Context.Abort();
                return null;
            }

            try
            {
                // Pass the hub method's result through unchanged so the client receives it.
                var result = await next(invocationContext);
                var elapsed = stopwatch.Elapsed;

                var contentLength = CalculateContentLength(invocationContext.HubMethodArguments);

                _logger.LogInformation(
                    "WebSocket {RequestPath}/{HubMethodName} {RemoteIpAddress} {UserId} executed in {Elapsed:0.0000} ms with content length {ContentLength}",
                    context?.Request.Path,
                    invocationContext.HubMethodName,
                    remoteIp,
                    userId,
                    elapsed.TotalMilliseconds,
                    contentLength
                );

                return result;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Exception calling '{HubMethodName}'", invocationContext.HubMethodName);
                throw;
            }
        }

        /// <summary>
        /// Returns the host name of the site the client was loaded from, read from the
        /// <c>Origin</c> header or, when that is empty, the <c>Referer</c> header.
        /// Returns <c>null</c> when neither header holds an absolute URI (e.g. Origin "null").
        /// </summary>
        private static string? GetClientDomain(HttpContext? context)
        {
            if (context is null)
            {
                return null;
            }

            var headers = context.Request.Headers;
            string source = headers["Origin"].ToString();
            if (string.IsNullOrEmpty(source))
            {
                source = headers["Referer"].ToString();
            }

            return Uri.TryCreate(source, UriKind.Absolute, out var uri) ? uri.Host : null;
        }

        /// <summary>
        /// Returns the client IP: the first entry of <c>X-Forwarded-For</c> when present,
        /// otherwise the socket's remote address, otherwise "Unknown".
        /// </summary>
        /// <remarks>
        /// <c>X-Forwarded-For</c> is sent by the client or intermediaries and is trivially spoofable
        /// unless a trusted reverse proxy overwrites it; the Forwarded Headers middleware is the
        /// safer way to resolve it.
        /// </remarks>
        private static string GetRemoteIpAddress(HttpContext? context)
        {
            var forwardedFor = context?.Request.Headers["X-Forwarded-For"].ToString();
            if (!string.IsNullOrEmpty(forwardedFor))
            {
                // The header may be "client, proxy1, proxy2"; the first entry is the original client.
                var first = forwardedFor.Split(',', 2)[0].Trim();
                if (first.Length > 0)
                {
                    return first;
                }
            }

            return context?.Connection.RemoteIpAddress?.ToString() ?? "Unknown";
        }

        /// <summary>
        /// Returns the UTF-8 byte count of the JSON-serialized arguments, or -1 when they cannot
        /// be serialized, so that logging never turns an already-completed call into a failure.
        /// </summary>
        private static int CalculateContentLength(IReadOnlyList<object?> arguments)
        {
            try
            {
                var serializedArgs = JsonSerializer.Serialize(arguments);
                return Encoding.UTF8.GetByteCount(serializedArgs);
            }
            catch (Exception ex) when (ex is JsonException or NotSupportedException)
            {
                return -1;
            }
        }

        /// <summary>Returns true when <paramref name="domain"/> is non-null and in the allow-list.</summary>
        private bool IsDomainAllowed(string? domain) =>
            domain is not null && allowedDomains.Contains(domain);
    }
}

我的 Program.cs:

        // Register HubLogFilter as a global hub filter so it runs for every hub.
        builder.Services.AddSignalR(options => options.AddFilter<HubLogFilter>());
© www.soinside.com 2019 - 2024. All rights reserved.