Refactor endpoints and authorization

This commit is contained in:
chylex 2025-03-27 09:03:58 +01:00
parent b2276600c7
commit 3b569ad5d6
No known key found for this signature in database
11 changed files with 56 additions and 64 deletions

View File

@ -3,18 +3,17 @@ using System.Net;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Utils.Http;
using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.StaticFiles;
using Microsoft.Extensions.Primitives;
namespace DHT.Server.Endpoints;
abstract class BaseEndpoint(IDatabaseFile db) {
abstract class BaseEndpoint {
private static readonly Log Log = Log.ForType<BaseEndpoint>();
protected IDatabaseFile Db { get; } = db;
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
public async Task Handle(HttpContext ctx) {
HttpResponse response = ctx.Response;
@ -49,6 +48,16 @@ abstract class BaseEndpoint(IDatabaseFile db) {
}
}
protected static async Task WriteFileIfFound(HttpResponse response, string relativeFilePath, byte[]? bytes, CancellationToken cancellationToken) {
if (bytes == null) {
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + relativeFilePath);
}
else {
string? contentType = ContentTypeProvider.TryGetContentType(relativeFilePath, out string? type) ? type : null;
await response.WriteFileAsync(contentType, bytes, cancellationToken);
}
}
protected static Guid GetSessionId(HttpRequest request) {
if (request.Query.TryGetValue("session", out StringValues sessionIdValue) && sessionIdValue.Count == 1 && Guid.TryParse(sessionIdValue[0], out Guid sessionId)) {
return sessionId;

View File

@ -10,12 +10,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!);
string normalizedUrl = DiscordCdn.NormalizeUrl(url);
if (!await Db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) {
if (!await db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) {
response.Redirect(url, permanent: false);
}
}

View File

@ -2,7 +2,6 @@ using System.Net.Mime;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http;
using DHT.Utils.Resources;
@ -10,7 +9,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters, ResourceLoader resources) : BaseEndpoint(db) {
sealed class GetTrackingScriptEndpoint(ServerParameters parameters, ResourceLoader resources) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
string bootstrap = await resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";")

View File

@ -8,12 +8,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) {
sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
Guid sessionId = GetSessionId(request);
ViewerSession session = viewerSessions.Get(sessionId);
response.ContentType = "application/x-ndjson";
return ViewerJsonExport.GetMessages(response.Body, Db, session.MessageFilter, cancellationToken);
return ViewerJsonExport.GetMessages(response.Body, db, session.MessageFilter, cancellationToken);
}
}

View File

@ -9,12 +9,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) {
sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
Guid sessionId = GetSessionId(request);
ViewerSession session = viewerSessions.Get(sessionId);
response.ContentType = MediaTypeNames.Application.Json;
return ViewerJsonExport.GetMetadata(response.Body, Db, session.MessageFilter, cancellationToken);
return ViewerJsonExport.GetMetadata(response.Body, db, session.MessageFilter, cancellationToken);
}
}

View File

@ -9,14 +9,14 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
JsonElement root = await ReadJson(request);
Data.Server server = ReadServer(root.RequireObject("server"), "server");
Channel channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
await Db.Servers.Add([server]);
await Db.Channels.Add([channel]);
await db.Servers.Add([server]);
await db.Channels.Add([channel]);
}
private static Data.Server ReadServer(JsonElement json, string path) {

View File

@ -16,7 +16,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint {
private const string HasNewMessages = "1";
private const string NoNewMessages = "0";
@ -38,9 +38,9 @@ sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
}
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
bool anyNewMessages = await Db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count;
bool anyNewMessages = await db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count;
await Db.Messages.Add(messages);
await db.Messages.Add(messages);
await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages, cancellationToken);
}

View File

@ -9,7 +9,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
JsonElement root = await ReadJson(request);
@ -24,7 +24,7 @@ sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
users[i++] = ReadUser(user, "user");
}
await Db.Users.Add(users);
await db.Users.Add(users);
}
private static User ReadUser(JsonElement json, string path) {

View File

@ -1,18 +1,14 @@
using System.Collections.Generic;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Utils.Http;
using DHT.Server.Service.Middlewares;
using DHT.Utils.Resources;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.StaticFiles;
namespace DHT.Server.Endpoints;
sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEndpoint(db) {
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
[ServerAuthorizationMiddleware.NoAuthorization]
sealed class ViewerEndpoint(ResourceLoader resources) : BaseEndpoint {
private readonly Dictionary<string, byte[]?> cache = new ();
private readonly SemaphoreSlim cacheSemaphore = new (1);
@ -31,12 +27,6 @@ sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEn
cacheSemaphore.Release();
}
if (resourceBytes == null) {
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + path);
}
else {
string? contentType = ContentTypeProvider.TryGetContentType(path, out string? type) ? type : null;
await response.WriteFileAsync(contentType, resourceBytes, cancellationToken);
}
await WriteFileIfFound(response, path, resourceBytes, cancellationToken);
}
}

View File

@ -1,4 +1,6 @@
using System;
using System.Net;
using System.Reflection;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http;
@ -6,25 +8,11 @@ using Microsoft.Extensions.Primitives;
namespace DHT.Server.Service.Middlewares;
sealed class ServerAuthorizationMiddleware {
sealed class ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
private static readonly Log Log = Log.ForType<ServerAuthorizationMiddleware>();
private readonly RequestDelegate next;
private readonly ServerParameters serverParameters;
public ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
this.next = next;
this.serverParameters = serverParameters;
}
public async Task InvokeAsync(HttpContext context) {
HttpRequest request = context.Request;
bool success = HttpMethods.IsGet(request.Method)
? CheckToken(request.Query["token"])
: CheckToken(request.Headers["X-DHT-Token"]);
if (success) {
if (SkipAuthorization(context) || CheckToken(context.Request)) {
await next(context);
}
else {
@ -32,6 +20,16 @@ sealed class ServerAuthorizationMiddleware {
}
}
private static bool SkipAuthorization(HttpContext context) {
return context.GetEndpoint()?.RequestDelegate?.Target?.GetType().GetCustomAttribute<NoAuthorization>() != null;
}
private bool CheckToken(HttpRequest request) {
return HttpMethods.IsGet(request.Method)
? CheckToken(request.Query["token"])
: CheckToken(request.Headers["X-DHT-Token"]);
}
private bool CheckToken(StringValues token) {
if (token.Count == 1 && token[0] == serverParameters.Token) {
return true;
@ -41,4 +39,7 @@ sealed class ServerAuthorizationMiddleware {
return false;
}
}
[AttributeUsage(AttributeTargets.Class)]
public sealed class NoAuthorization : Attribute;
}

View File

@ -38,25 +38,18 @@ sealed class Startup {
public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters, ResourceLoader resources, ViewerSessions viewerSessions) {
app.UseMiddleware<ServerLoggingMiddleware>();
app.UseCors();
app.Map("/viewer", node => {
node.UseRouting();
node.UseEndpoints(endpoints => {
endpoints.MapGet("/{**path}", new ViewerEndpoint(db, resources).Handle);
});
});
app.UseMiddleware<ServerAuthorizationMiddleware>();
app.UseRouting();
app.UseMiddleware<ServerAuthorizationMiddleware>();
app.UseEndpoints(endpoints => {
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(db, parameters, resources).Handle);
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-downloaded-file/{url}", new GetDownloadedFileEndpoint(db).Handle);
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(parameters, resources).Handle);
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/viewer/{**path}", new ViewerEndpoint(resources).Handle);
endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle);
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle);
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
});
}
}