diff --git a/app/Desktop/Main/Pages/DatabasePage.axaml b/app/Desktop/Main/Pages/DatabasePage.axaml
index 6f320d9..7cc5bd0 100644
--- a/app/Desktop/Main/Pages/DatabasePage.axaml
+++ b/app/Desktop/Main/Pages/DatabasePage.axaml
@@ -24,6 +24,7 @@
+
diff --git a/app/Desktop/Main/Pages/DatabasePageModel.cs b/app/Desktop/Main/Pages/DatabasePageModel.cs
index 2fb4a58..d31e1f5 100644
--- a/app/Desktop/Main/Pages/DatabasePageModel.cs
+++ b/app/Desktop/Main/Pages/DatabasePageModel.cs
@@ -1,6 +1,8 @@
using System;
+using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
+using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -9,7 +11,10 @@ using Avalonia.Threading;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
+using DHT.Desktop.Dialogs.TextBox;
+using DHT.Server.Data;
using DHT.Server.Database;
+using DHT.Server.Database.Import;
using DHT.Utils.Logging;
using DHT.Utils.Models;
@@ -92,13 +97,13 @@ namespace DHT.Desktop.Main.Pages {
return DialogResult.YesNo.Yes == upgradeResult;
}
-
+
await PerformImport(target, paths, dialog, callback, "Database Merge", "Database Error", "database file", async path => {
SynchronizationContext? prevSyncContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
IDatabaseFile? db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, dialog, CheckCanUpgradeDatabase);
SynchronizationContext.SetSynchronizationContext(prevSyncContext);
-
+
if (db == null) {
return false;
}
@@ -112,6 +117,73 @@ namespace DHT.Desktop.Main.Pages {
});
}
+ public async void ImportLegacyArchive() {
+ var fileDialog = new OpenFileDialog {
+ Title = "Open Legacy DHT Archive",
+ Directory = Path.GetDirectoryName(Db.Path),
+ AllowMultiple = true
+ };
+
+ string[]? paths = await fileDialog.ShowAsync(window);
+ if (paths == null || paths.Length == 0) {
+ return;
+ }
+
+ ProgressDialog progressDialog = new ProgressDialog();
+ progressDialog.DataContext = new ProgressDialogModel(async callback => await ImportLegacyArchiveFromPaths(Db, paths, progressDialog, callback)) {
+ Title = "Legacy Archive Import"
+ };
+
+ await progressDialog.ShowDialog(window);
+ }
+
+ private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
+ var fakeSnowflake = new FakeSnowflake();
+
+ await PerformImport(target, paths, dialog, callback, "Legacy Archive Import", "Legacy Archive Error", "archive file", async path => {
+ await using var jsonStream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read);
+
+ return await LegacyArchiveImport.Read(jsonStream, target, fakeSnowflake, async servers => {
+ SynchronizationContext? prevSyncContext = SynchronizationContext.Current;
+ SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
+ Dictionary? result = await Dispatcher.UIThread.InvokeAsync(() => AskForServerIds(dialog, servers));
+ SynchronizationContext.SetSynchronizationContext(prevSyncContext);
+ return result;
+ });
+ });
+ }
+
+ private static async Task?> AskForServerIds(Window window, DHT.Server.Data.Server[] servers) {
+ static bool IsValidSnowflake(string value) {
+ return string.IsNullOrEmpty(value) || ulong.TryParse(value, out _);
+ }
+
+ var items = new List>();
+
+ foreach (var server in servers.OrderBy(static server => server.Type).ThenBy(static server => server.Name)) {
+ items.Add(new TextBoxItem(server) {
+ Title = server.Name + " (" + ServerTypes.ToNiceString(server.Type) + ")",
+ ValidityCheck = IsValidSnowflake
+ });
+ }
+
+ var model = new TextBoxDialogModel(items) {
+ Title = "Imported Server IDs",
+ Description = "Please fill in the IDs of servers and direct messages. First enable Developer Mode in Discord, then right-click each server or direct message, click 'Copy ID', and paste it into the input field. If a server no longer exists, leave its input field empty to use a random ID."
+ };
+
+ var dialog = new TextBoxDialog { DataContext = model };
+ var result = await dialog.ShowDialog(window);
+
+ if (result != DialogResult.OkCancel.Ok) {
+ return null;
+ }
+
+ return model.ValidItems
+ .Where(static item => !string.IsNullOrEmpty(item.Value))
+ .ToDictionary(static item => item.Item, static item => ulong.Parse(item.Value));
+ }
+
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func> performImport) {
int total = paths.Length;
var oldStatistics = target.SnapshotStatistics();
diff --git a/app/Server/Data/ServerType.cs b/app/Server/Data/ServerType.cs
index e3cdc57..8e6e43a 100644
--- a/app/Server/Data/ServerType.cs
+++ b/app/Server/Data/ServerType.cs
@@ -24,6 +24,15 @@ namespace DHT.Server.Data {
};
}
+ public static string ToNiceString(ServerType? type) {
+ return type switch {
+ ServerType.Server => "Server",
+ ServerType.Group => "Group",
+ ServerType.DirectMessage => "DM",
+ _ => "Unknown"
+ };
+ }
+
internal static string ToJsonViewerString(ServerType? type) {
return type switch {
ServerType.Server => "server",
diff --git a/app/Server/Database/DatabaseExtensions.cs b/app/Server/Database/DatabaseExtensions.cs
index 55d69ca..38aa6e8 100644
--- a/app/Server/Database/DatabaseExtensions.cs
+++ b/app/Server/Database/DatabaseExtensions.cs
@@ -1,16 +1,11 @@
+using System.Collections.Generic;
using DHT.Server.Data;
namespace DHT.Server.Database {
public static class DatabaseExtensions {
public static void AddFrom(this IDatabaseFile target, IDatabaseFile source) {
- foreach (var server in source.GetAllServers()) {
- target.AddServer(server);
- }
-
- foreach (var channel in source.GetAllChannels()) {
- target.AddChannel(channel);
- }
-
+ target.AddServers(source.GetAllServers());
+ target.AddChannels(source.GetAllChannels());
target.AddUsers(source.GetAllUsers().ToArray());
target.AddMessages(source.GetMessages().ToArray());
@@ -18,5 +13,17 @@ namespace DHT.Server.Database {
target.AddDownload(download.Status == DownloadStatus.Success ? source.GetDownloadWithData(download) : download);
}
}
+
+ internal static void AddServers(this IDatabaseFile target, IEnumerable servers) {
+ foreach (var server in servers) {
+ target.AddServer(server);
+ }
+ }
+
+ internal static void AddChannels(this IDatabaseFile target, IEnumerable channels) {
+ foreach (var channel in channels) {
+ target.AddChannel(channel);
+ }
+ }
}
}
diff --git a/app/Server/Database/DummyDatabaseFile.cs b/app/Server/Database/DummyDatabaseFile.cs
index b771d73..ca5affd 100644
--- a/app/Server/Database/DummyDatabaseFile.cs
+++ b/app/Server/Database/DummyDatabaseFile.cs
@@ -45,6 +45,10 @@ namespace DHT.Server.Database {
return new();
}
+ public HashSet GetMessageIds(MessageFilter? filter = null) {
+ return new();
+ }
+
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {}
public int CountAttachments(AttachmentFilter? filter = null) {
diff --git a/app/Server/Database/IDatabaseFile.cs b/app/Server/Database/IDatabaseFile.cs
index e219097..1fb3ed8 100644
--- a/app/Server/Database/IDatabaseFile.cs
+++ b/app/Server/Database/IDatabaseFile.cs
@@ -23,6 +23,7 @@ namespace DHT.Server.Database {
void AddMessages(Message[] messages);
int CountMessages(MessageFilter? filter = null);
List GetMessages(MessageFilter? filter = null);
+ HashSet GetMessageIds(MessageFilter? filter = null);
void RemoveMessages(MessageFilter filter, FilterRemovalMode mode);
int CountAttachments(AttachmentFilter? filter = null);
diff --git a/app/Server/Database/Import/FakeSnowflake.cs b/app/Server/Database/Import/FakeSnowflake.cs
new file mode 100644
index 0000000..8478c93
--- /dev/null
+++ b/app/Server/Database/Import/FakeSnowflake.cs
@@ -0,0 +1,21 @@
+using System;
+
+namespace DHT.Server.Database.Import {
+ ///
+ /// https://discord.com/developers/docs/reference#snowflakes
+ ///
+ public sealed class FakeSnowflake {
+ private const ulong DiscordEpoch = 1420070400000UL;
+
+ private ulong id;
+
+ public FakeSnowflake() {
+ var unixMillis = (ulong) (DateTime.UtcNow.Subtract(DateTime.UnixEpoch).Ticks / TimeSpan.TicksPerMillisecond);
+ this.id = (unixMillis - DiscordEpoch) << 22;
+ }
+
+ internal ulong Next() {
+ return id++;
+ }
+ }
+}
diff --git a/app/Server/Database/Import/LegacyArchiveImport.cs b/app/Server/Database/Import/LegacyArchiveImport.cs
new file mode 100644
index 0000000..7e95744
--- /dev/null
+++ b/app/Server/Database/Import/LegacyArchiveImport.cs
@@ -0,0 +1,263 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Linq;
+using System.Text.Json;
+using System.Threading.Tasks;
+using DHT.Server.Data;
+using DHT.Utils.Collections;
+using DHT.Utils.Http;
+using DHT.Utils.Logging;
+using Microsoft.AspNetCore.StaticFiles;
+
+namespace DHT.Server.Database.Import {
+ public static class LegacyArchiveImport {
+ private static readonly Log Log = Log.ForType(typeof(LegacyArchiveImport));
+
+ private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
+
+ public static async Task Read(Stream stream, IDatabaseFile db, FakeSnowflake fakeSnowflake, Func?>> askForServerIds) {
+ var perf = Log.Start();
+ var root = await JsonSerializer.DeserializeAsync(stream);
+
+ try {
+ var meta = root.RequireObject("meta");
+ var data = root.RequireObject("data");
+
+ perf.Step("Deserialize JSON");
+
+ var users = ReadUserList(meta);
+ var servers = ReadServerList(meta, fakeSnowflake);
+
+ var newServersOnly = new HashSet(servers);
+ var oldServersById = db.GetAllServers().ToDictionary(static server => server.Id, static server => server);
+
+ var oldChannels = db.GetAllChannels();
+ var oldChannelsById = oldChannels.ToDictionary(static channel => channel.Id, static channel => channel);
+
+ foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) {
+ if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) {
+ servers[serverIndex] = oldServer;
+ }
+ }
+
+ perf.Step("Read server and user list");
+
+ if (newServersOnly.Count > 0) {
+ var askedServerIds = await askForServerIds(newServersOnly.ToArray());
+ if (askedServerIds == null) {
+ return false;
+ }
+
+ perf.Step("Ask for server IDs");
+
+ for (var i = 0; i < servers.Length; i++) {
+ var server = servers[i];
+ if (askedServerIds.TryGetValue(server, out var serverId)) {
+ servers[i] = new Data.Server {
+ Id = serverId,
+ Name = server.Name,
+ Type = server.Type
+ };
+ }
+ }
+ }
+
+ var channels = ReadChannelList(meta, servers);
+
+ perf.Step("Read channel list");
+
+ var oldMessageIds = db.GetMessageIds();
+ var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake))
+ .Where(message => !oldMessageIds.Contains(message.Id))
+ .ToArray();
+
+ perf.Step("Read messages");
+
+ db.AddUsers(users);
+ db.AddServers(servers);
+ db.AddChannels(channels);
+ db.AddMessages(newMessages);
+
+ perf.Step("Import into database");
+ } catch (HttpException e) {
+ throw new JsonException(e.Message);
+ }
+
+ perf.End();
+ return true;
+ }
+
+ private static User[] ReadUserList(JsonElement meta) {
+ const string UsersPath = "meta.users[]";
+
+ static ulong ParseUserIndex(JsonElement element, int index) {
+ return ulong.Parse(element.GetString() ?? throw new JsonException("Expected key 'meta.userindex[" + index + "]' to be a string."));
+ }
+
+ var userindex = meta.RequireArray("userindex", "meta")
+ .Select(static (item, index) => (ParseUserIndex(item, index), index))
+ .ToDictionary();
+
+ var users = new User[userindex.Count];
+
+ foreach (var item in meta.RequireObject("users", "meta").EnumerateObject()) {
+ var path = UsersPath + "." + item.Name;
+ var userId = ulong.Parse(item.Name);
+ var userObj = item.Value;
+
+ users[userindex[userId]] = new User {
+ Id = userId,
+ Name = userObj.RequireString("name", path),
+ AvatarUrl = userObj.HasKey("avatar") ? userObj.RequireString("avatar", path) : null,
+ Discriminator = userObj.HasKey("tag") ? userObj.RequireString("tag", path) : null
+ };
+ }
+
+ return users;
+ }
+
+ private static Data.Server[] ReadServerList(JsonElement meta, FakeSnowflake fakeSnowflake) {
+ const string ServersPath = "meta.servers[]";
+
+ return meta.RequireArray("servers", "meta").Select(serverObj => new Data.Server {
+ Id = fakeSnowflake.Next(),
+ Name = serverObj.RequireString("name", ServersPath),
+ Type = ServerTypes.FromString(serverObj.RequireString("type", ServersPath))
+ }).ToArray();
+ }
+
+ private const string ChannelsPath = "meta.channels";
+
+ private static Dictionary ReadChannelToServerIndexMapping(JsonElement meta, Data.Server[] servers) {
+ return meta.RequireObject("channels", "meta").EnumerateObject().Select(item => {
+ var path = ChannelsPath + "." + item.Name;
+ var channelId = ulong.Parse(item.Name);
+ var channelObj = item.Value;
+
+ return (channelId, channelObj.RequireInt("server", path, min: 0, max: servers.Length - 1));
+ }).ToDictionary();
+ }
+
+ private static Channel[] ReadChannelList(JsonElement meta, Data.Server[] servers) {
+ return meta.RequireObject("channels", "meta").EnumerateObject().Select(item => {
+ var path = ChannelsPath + "." + item.Name;
+ var channelId = ulong.Parse(item.Name);
+ var channelObj = item.Value;
+
+ return new Channel {
+ Id = channelId,
+ Server = servers[channelObj.RequireInt("server", path, min: 0, max: servers.Length - 1)].Id,
+ Name = channelObj.RequireString("name", path),
+ Position = channelObj.HasKey("position") ? channelObj.RequireInt("position", path, min: 0) : null,
+ Topic = channelObj.HasKey("topic") ? channelObj.RequireString("topic", path) : null,
+ Nsfw = channelObj.HasKey("nsfw") ? channelObj.RequireBool("nsfw", path) : null
+ };
+ }).ToArray();
+ }
+
+ private static Message[] ReadMessages(JsonElement data, Channel channel, User[] users, FakeSnowflake fakeSnowflake) {
+ const string DataPath = "data";
+
+ var channelId = channel.Id;
+ var channelIdStr = channelId.ToString();
+
+ var messagesObj = data.HasKey(channelIdStr) ? data.RequireObject(channelIdStr, DataPath) : (JsonElement?) null;
+ if (messagesObj == null) {
+ return Array.Empty();
+ }
+
+ return messagesObj.Value.EnumerateObject().Select(item => {
+ var path = DataPath + "." + item.Name;
+ var messageId = ulong.Parse(item.Name);
+ var messageObj = item.Value;
+
+ return new Message {
+ Id = messageId,
+ Sender = users[messageObj.RequireInt("u", path, min: 0, max: users.Length - 1)].Id,
+ Channel = channelId,
+ Text = messageObj.HasKey("m") ? messageObj.RequireString("m", path) : string.Empty,
+ Timestamp = messageObj.RequireLong("t", path),
+ EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null,
+ RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null,
+ Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableArray() : ImmutableArray.Empty,
+ Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableArray() : ImmutableArray