mirror of
https://github.com/chylex/Discord-History-Tracker.git
synced 2025-04-12 06:50:01 +03:00
Refactor endpoints and authorization
This commit is contained in:
parent
b2276600c7
commit
3b569ad5d6
@ -3,18 +3,17 @@ using System.Net;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using DHT.Server.Database;
|
||||
using DHT.Utils.Http;
|
||||
using DHT.Utils.Logging;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.StaticFiles;
|
||||
using Microsoft.Extensions.Primitives;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
abstract class BaseEndpoint(IDatabaseFile db) {
|
||||
abstract class BaseEndpoint {
|
||||
private static readonly Log Log = Log.ForType<BaseEndpoint>();
|
||||
|
||||
protected IDatabaseFile Db { get; } = db;
|
||||
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
|
||||
|
||||
public async Task Handle(HttpContext ctx) {
|
||||
HttpResponse response = ctx.Response;
|
||||
@ -49,6 +48,16 @@ abstract class BaseEndpoint(IDatabaseFile db) {
|
||||
}
|
||||
}
|
||||
|
||||
protected static async Task WriteFileIfFound(HttpResponse response, string relativeFilePath, byte[]? bytes, CancellationToken cancellationToken) {
|
||||
if (bytes == null) {
|
||||
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + relativeFilePath);
|
||||
}
|
||||
else {
|
||||
string? contentType = ContentTypeProvider.TryGetContentType(relativeFilePath, out string? type) ? type : null;
|
||||
await response.WriteFileAsync(contentType, bytes, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
protected static Guid GetSessionId(HttpRequest request) {
|
||||
if (request.Query.TryGetValue("session", out StringValues sessionIdValue) && sessionIdValue.Count == 1 && Guid.TryParse(sessionIdValue[0], out Guid sessionId)) {
|
||||
return sessionId;
|
||||
|
@ -10,12 +10,12 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint {
|
||||
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!);
|
||||
string normalizedUrl = DiscordCdn.NormalizeUrl(url);
|
||||
|
||||
if (!await Db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) {
|
||||
if (!await db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) {
|
||||
response.Redirect(url, permanent: false);
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ using System.Net.Mime;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Web;
|
||||
using DHT.Server.Database;
|
||||
using DHT.Server.Service;
|
||||
using DHT.Utils.Http;
|
||||
using DHT.Utils.Resources;
|
||||
@ -10,7 +9,7 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters, ResourceLoader resources) : BaseEndpoint(db) {
|
||||
sealed class GetTrackingScriptEndpoint(ServerParameters parameters, ResourceLoader resources) : BaseEndpoint {
|
||||
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
string bootstrap = await resources.ReadTextAsync("Tracker/bootstrap.js");
|
||||
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";")
|
||||
|
@ -8,12 +8,12 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) {
|
||||
sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
|
||||
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
Guid sessionId = GetSessionId(request);
|
||||
ViewerSession session = viewerSessions.Get(sessionId);
|
||||
|
||||
response.ContentType = "application/x-ndjson";
|
||||
return ViewerJsonExport.GetMessages(response.Body, Db, session.MessageFilter, cancellationToken);
|
||||
return ViewerJsonExport.GetMessages(response.Body, db, session.MessageFilter, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
@ -9,12 +9,12 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) {
|
||||
sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
|
||||
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
Guid sessionId = GetSessionId(request);
|
||||
ViewerSession session = viewerSessions.Get(sessionId);
|
||||
|
||||
response.ContentType = MediaTypeNames.Application.Json;
|
||||
return ViewerJsonExport.GetMetadata(response.Body, Db, session.MessageFilter, cancellationToken);
|
||||
return ViewerJsonExport.GetMetadata(response.Body, db, session.MessageFilter, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
@ -9,14 +9,14 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint {
|
||||
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
JsonElement root = await ReadJson(request);
|
||||
Data.Server server = ReadServer(root.RequireObject("server"), "server");
|
||||
Channel channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
|
||||
|
||||
await Db.Servers.Add([server]);
|
||||
await Db.Channels.Add([channel]);
|
||||
await db.Servers.Add([server]);
|
||||
await db.Channels.Add([channel]);
|
||||
}
|
||||
|
||||
private static Data.Server ReadServer(JsonElement json, string path) {
|
||||
|
@ -16,7 +16,7 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint {
|
||||
private const string HasNewMessages = "1";
|
||||
private const string NoNewMessages = "0";
|
||||
|
||||
@ -38,9 +38,9 @@ sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
}
|
||||
|
||||
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
|
||||
bool anyNewMessages = await Db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count;
|
||||
bool anyNewMessages = await db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count;
|
||||
|
||||
await Db.Messages.Add(messages);
|
||||
await db.Messages.Add(messages);
|
||||
|
||||
await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages, cancellationToken);
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint {
|
||||
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
|
||||
JsonElement root = await ReadJson(request);
|
||||
|
||||
@ -24,7 +24,7 @@ sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
|
||||
users[i++] = ReadUser(user, "user");
|
||||
}
|
||||
|
||||
await Db.Users.Add(users);
|
||||
await db.Users.Add(users);
|
||||
}
|
||||
|
||||
private static User ReadUser(JsonElement json, string path) {
|
||||
|
@ -1,18 +1,14 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Net;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using DHT.Server.Database;
|
||||
using DHT.Utils.Http;
|
||||
using DHT.Server.Service.Middlewares;
|
||||
using DHT.Utils.Resources;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.StaticFiles;
|
||||
|
||||
namespace DHT.Server.Endpoints;
|
||||
|
||||
sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEndpoint(db) {
|
||||
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
|
||||
|
||||
[ServerAuthorizationMiddleware.NoAuthorization]
|
||||
sealed class ViewerEndpoint(ResourceLoader resources) : BaseEndpoint {
|
||||
private readonly Dictionary<string, byte[]?> cache = new ();
|
||||
private readonly SemaphoreSlim cacheSemaphore = new (1);
|
||||
|
||||
@ -31,12 +27,6 @@ sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEn
|
||||
cacheSemaphore.Release();
|
||||
}
|
||||
|
||||
if (resourceBytes == null) {
|
||||
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + path);
|
||||
}
|
||||
else {
|
||||
string? contentType = ContentTypeProvider.TryGetContentType(path, out string? type) ? type : null;
|
||||
await response.WriteFileAsync(contentType, resourceBytes, cancellationToken);
|
||||
}
|
||||
await WriteFileIfFound(response, path, resourceBytes, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,6 @@
|
||||
using System;
|
||||
using System.Net;
|
||||
using System.Reflection;
|
||||
using System.Threading.Tasks;
|
||||
using DHT.Utils.Logging;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
@ -6,25 +8,11 @@ using Microsoft.Extensions.Primitives;
|
||||
|
||||
namespace DHT.Server.Service.Middlewares;
|
||||
|
||||
sealed class ServerAuthorizationMiddleware {
|
||||
sealed class ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
|
||||
private static readonly Log Log = Log.ForType<ServerAuthorizationMiddleware>();
|
||||
|
||||
private readonly RequestDelegate next;
|
||||
private readonly ServerParameters serverParameters;
|
||||
|
||||
public ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
|
||||
this.next = next;
|
||||
this.serverParameters = serverParameters;
|
||||
}
|
||||
|
||||
public async Task InvokeAsync(HttpContext context) {
|
||||
HttpRequest request = context.Request;
|
||||
|
||||
bool success = HttpMethods.IsGet(request.Method)
|
||||
? CheckToken(request.Query["token"])
|
||||
: CheckToken(request.Headers["X-DHT-Token"]);
|
||||
|
||||
if (success) {
|
||||
if (SkipAuthorization(context) || CheckToken(context.Request)) {
|
||||
await next(context);
|
||||
}
|
||||
else {
|
||||
@ -32,6 +20,16 @@ sealed class ServerAuthorizationMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
private static bool SkipAuthorization(HttpContext context) {
|
||||
return context.GetEndpoint()?.RequestDelegate?.Target?.GetType().GetCustomAttribute<NoAuthorization>() != null;
|
||||
}
|
||||
|
||||
private bool CheckToken(HttpRequest request) {
|
||||
return HttpMethods.IsGet(request.Method)
|
||||
? CheckToken(request.Query["token"])
|
||||
: CheckToken(request.Headers["X-DHT-Token"]);
|
||||
}
|
||||
|
||||
private bool CheckToken(StringValues token) {
|
||||
if (token.Count == 1 && token[0] == serverParameters.Token) {
|
||||
return true;
|
||||
@ -41,4 +39,7 @@ sealed class ServerAuthorizationMiddleware {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
[AttributeUsage(AttributeTargets.Class)]
|
||||
public sealed class NoAuthorization : Attribute;
|
||||
}
|
||||
|
@ -38,25 +38,18 @@ sealed class Startup {
|
||||
public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters, ResourceLoader resources, ViewerSessions viewerSessions) {
|
||||
app.UseMiddleware<ServerLoggingMiddleware>();
|
||||
app.UseCors();
|
||||
|
||||
app.Map("/viewer", node => {
|
||||
node.UseRouting();
|
||||
node.UseEndpoints(endpoints => {
|
||||
endpoints.MapGet("/{**path}", new ViewerEndpoint(db, resources).Handle);
|
||||
});
|
||||
});
|
||||
|
||||
app.UseMiddleware<ServerAuthorizationMiddleware>();
|
||||
|
||||
app.UseRouting();
|
||||
app.UseMiddleware<ServerAuthorizationMiddleware>();
|
||||
app.UseEndpoints(endpoints => {
|
||||
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(db, parameters, resources).Handle);
|
||||
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
|
||||
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
|
||||
endpoints.MapGet("/get-downloaded-file/{url}", new GetDownloadedFileEndpoint(db).Handle);
|
||||
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(parameters, resources).Handle);
|
||||
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
|
||||
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
|
||||
endpoints.MapGet("/viewer/{**path}", new ViewerEndpoint(resources).Handle);
|
||||
|
||||
endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle);
|
||||
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
|
||||
endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle);
|
||||
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user