diff --git a/app/Desktop/Main/Pages/DatabasePage.axaml b/app/Desktop/Main/Pages/DatabasePage.axaml index 6f320d9..7cc5bd0 100644 --- a/app/Desktop/Main/Pages/DatabasePage.axaml +++ b/app/Desktop/Main/Pages/DatabasePage.axaml @@ -24,6 +24,7 @@ + diff --git a/app/Desktop/Main/Pages/DatabasePageModel.cs b/app/Desktop/Main/Pages/DatabasePageModel.cs index 2fb4a58..d31e1f5 100644 --- a/app/Desktop/Main/Pages/DatabasePageModel.cs +++ b/app/Desktop/Main/Pages/DatabasePageModel.cs @@ -1,6 +1,8 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -9,7 +11,10 @@ using Avalonia.Threading; using DHT.Desktop.Common; using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Progress; +using DHT.Desktop.Dialogs.TextBox; +using DHT.Server.Data; using DHT.Server.Database; +using DHT.Server.Database.Import; using DHT.Utils.Logging; using DHT.Utils.Models; @@ -92,13 +97,13 @@ namespace DHT.Desktop.Main.Pages { return DialogResult.YesNo.Yes == upgradeResult; } - + await PerformImport(target, paths, dialog, callback, "Database Merge", "Database Error", "database file", async path => { SynchronizationContext? prevSyncContext = SynchronizationContext.Current; SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext()); IDatabaseFile? db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, dialog, CheckCanUpgradeDatabase); SynchronizationContext.SetSynchronizationContext(prevSyncContext); - + if (db == null) { return false; } @@ -112,6 +117,73 @@ namespace DHT.Desktop.Main.Pages { }); } + public async void ImportLegacyArchive() { + var fileDialog = new OpenFileDialog { + Title = "Open Legacy DHT Archive", + Directory = Path.GetDirectoryName(Db.Path), + AllowMultiple = true + }; + + string[]? 
paths = await fileDialog.ShowAsync(window); + if (paths == null || paths.Length == 0) { + return; + } + + ProgressDialog progressDialog = new ProgressDialog(); + progressDialog.DataContext = new ProgressDialogModel(async callback => await ImportLegacyArchiveFromPaths(Db, paths, progressDialog, callback)) { + Title = "Legacy Archive Import" + }; + + await progressDialog.ShowDialog(window); + } + + private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) { + var fakeSnowflake = new FakeSnowflake(); + + await PerformImport(target, paths, dialog, callback, "Legacy Archive Import", "Legacy Archive Error", "archive file", async path => { + await using var jsonStream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read); + + return await LegacyArchiveImport.Read(jsonStream, target, fakeSnowflake, async servers => { + SynchronizationContext? prevSyncContext = SynchronizationContext.Current; + SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext()); + Dictionary? result = await Dispatcher.UIThread.InvokeAsync(() => AskForServerIds(dialog, servers)); + SynchronizationContext.SetSynchronizationContext(prevSyncContext); + return result; + }); + }); + } + + private static async Task?> AskForServerIds(Window window, DHT.Server.Data.Server[] servers) { + static bool IsValidSnowflake(string value) { + return string.IsNullOrEmpty(value) || ulong.TryParse(value, out _); + } + + var items = new List>(); + + foreach (var server in servers.OrderBy(static server => server.Type).ThenBy(static server => server.Name)) { + items.Add(new TextBoxItem(server) { + Title = server.Name + " (" + ServerTypes.ToNiceString(server.Type) + ")", + ValidityCheck = IsValidSnowflake + }); + } + + var model = new TextBoxDialogModel(items) { + Title = "Imported Server IDs", + Description = "Please fill in the IDs of servers and direct messages. 
First enable Developer Mode in Discord, then right-click each server or direct message, click 'Copy ID', and paste it into the input field. If a server no longer exists, leave its input field empty to use a random ID." + }; + + var dialog = new TextBoxDialog { DataContext = model }; + var result = await dialog.ShowDialog(window); + + if (result != DialogResult.OkCancel.Ok) { + return null; + } + + return model.ValidItems + .Where(static item => !string.IsNullOrEmpty(item.Value)) + .ToDictionary(static item => item.Item, static item => ulong.Parse(item.Value)); + } + private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func> performImport) { int total = paths.Length; var oldStatistics = target.SnapshotStatistics(); diff --git a/app/Server/Data/ServerType.cs b/app/Server/Data/ServerType.cs index e3cdc57..8e6e43a 100644 --- a/app/Server/Data/ServerType.cs +++ b/app/Server/Data/ServerType.cs @@ -24,6 +24,15 @@ namespace DHT.Server.Data { }; } + public static string ToNiceString(ServerType? type) { + return type switch { + ServerType.Server => "Server", + ServerType.Group => "Group", + ServerType.DirectMessage => "DM", + _ => "Unknown" + }; + } + internal static string ToJsonViewerString(ServerType? 
type) { return type switch { ServerType.Server => "server", diff --git a/app/Server/Database/DatabaseExtensions.cs b/app/Server/Database/DatabaseExtensions.cs index 55d69ca..38aa6e8 100644 --- a/app/Server/Database/DatabaseExtensions.cs +++ b/app/Server/Database/DatabaseExtensions.cs @@ -1,16 +1,11 @@ +using System.Collections.Generic; using DHT.Server.Data; namespace DHT.Server.Database { public static class DatabaseExtensions { public static void AddFrom(this IDatabaseFile target, IDatabaseFile source) { - foreach (var server in source.GetAllServers()) { - target.AddServer(server); - } - - foreach (var channel in source.GetAllChannels()) { - target.AddChannel(channel); - } - + target.AddServers(source.GetAllServers()); + target.AddChannels(source.GetAllChannels()); target.AddUsers(source.GetAllUsers().ToArray()); target.AddMessages(source.GetMessages().ToArray()); @@ -18,5 +13,17 @@ namespace DHT.Server.Database { target.AddDownload(download.Status == DownloadStatus.Success ? source.GetDownloadWithData(download) : download); } } + + internal static void AddServers(this IDatabaseFile target, IEnumerable servers) { + foreach (var server in servers) { + target.AddServer(server); + } + } + + internal static void AddChannels(this IDatabaseFile target, IEnumerable channels) { + foreach (var channel in channels) { + target.AddChannel(channel); + } + } } } diff --git a/app/Server/Database/DummyDatabaseFile.cs b/app/Server/Database/DummyDatabaseFile.cs index b771d73..ca5affd 100644 --- a/app/Server/Database/DummyDatabaseFile.cs +++ b/app/Server/Database/DummyDatabaseFile.cs @@ -45,6 +45,10 @@ namespace DHT.Server.Database { return new(); } + public HashSet GetMessageIds(MessageFilter? filter = null) { + return new(); + } + public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {} public int CountAttachments(AttachmentFilter? 
filter = null) { diff --git a/app/Server/Database/IDatabaseFile.cs b/app/Server/Database/IDatabaseFile.cs index e219097..1fb3ed8 100644 --- a/app/Server/Database/IDatabaseFile.cs +++ b/app/Server/Database/IDatabaseFile.cs @@ -23,6 +23,7 @@ namespace DHT.Server.Database { void AddMessages(Message[] messages); int CountMessages(MessageFilter? filter = null); List GetMessages(MessageFilter? filter = null); + HashSet GetMessageIds(MessageFilter? filter = null); void RemoveMessages(MessageFilter filter, FilterRemovalMode mode); int CountAttachments(AttachmentFilter? filter = null); diff --git a/app/Server/Database/Import/FakeSnowflake.cs b/app/Server/Database/Import/FakeSnowflake.cs new file mode 100644 index 0000000..8478c93 --- /dev/null +++ b/app/Server/Database/Import/FakeSnowflake.cs @@ -0,0 +1,21 @@ +using System; + +namespace DHT.Server.Database.Import { + /// + /// https://discord.com/developers/docs/reference#snowflakes + /// + public sealed class FakeSnowflake { + private const ulong DiscordEpoch = 1420070400000UL; + + private ulong id; + + public FakeSnowflake() { + var unixMillis = (ulong) (DateTime.UtcNow.Subtract(DateTime.UnixEpoch).Ticks / TimeSpan.TicksPerMillisecond); + this.id = (unixMillis - DiscordEpoch) << 22; + } + + internal ulong Next() { + return id++; + } + } +} diff --git a/app/Server/Database/Import/LegacyArchiveImport.cs b/app/Server/Database/Import/LegacyArchiveImport.cs new file mode 100644 index 0000000..7e95744 --- /dev/null +++ b/app/Server/Database/Import/LegacyArchiveImport.cs @@ -0,0 +1,263 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; +using DHT.Server.Data; +using DHT.Utils.Collections; +using DHT.Utils.Http; +using DHT.Utils.Logging; +using Microsoft.AspNetCore.StaticFiles; + +namespace DHT.Server.Database.Import { + public static class LegacyArchiveImport 
{ + private static readonly Log Log = Log.ForType(typeof(LegacyArchiveImport)); + + private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new (); + + public static async Task Read(Stream stream, IDatabaseFile db, FakeSnowflake fakeSnowflake, Func?>> askForServerIds) { + var perf = Log.Start(); + var root = await JsonSerializer.DeserializeAsync(stream); + + try { + var meta = root.RequireObject("meta"); + var data = root.RequireObject("data"); + + perf.Step("Deserialize JSON"); + + var users = ReadUserList(meta); + var servers = ReadServerList(meta, fakeSnowflake); + + var newServersOnly = new HashSet(servers); + var oldServersById = db.GetAllServers().ToDictionary(static server => server.Id, static server => server); + + var oldChannels = db.GetAllChannels(); + var oldChannelsById = oldChannels.ToDictionary(static channel => channel.Id, static channel => channel); + + foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) { + if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) { + servers[serverIndex] = oldServer; + } + } + + perf.Step("Read server and user list"); + + if (newServersOnly.Count > 0) { + var askedServerIds = await askForServerIds(newServersOnly.ToArray()); + if (askedServerIds == null) { + return false; + } + + perf.Step("Ask for server IDs"); + + for (var i = 0; i < servers.Length; i++) { + var server = servers[i]; + if (askedServerIds.TryGetValue(server, out var serverId)) { + servers[i] = new Data.Server { + Id = serverId, + Name = server.Name, + Type = server.Type + }; + } + } + } + + var channels = ReadChannelList(meta, servers); + + perf.Step("Read channel list"); + + var oldMessageIds = db.GetMessageIds(); + var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake)) + .Where(message => 
!oldMessageIds.Contains(message.Id)) + .ToArray(); + + perf.Step("Read messages"); + + db.AddUsers(users); + db.AddServers(servers); + db.AddChannels(channels); + db.AddMessages(newMessages); + + perf.Step("Import into database"); + } catch (HttpException e) { + throw new JsonException(e.Message, e); + } + + perf.End(); + return true; + } + + private static User[] ReadUserList(JsonElement meta) { + const string UsersPath = "meta.users[]"; + + static ulong ParseUserIndex(JsonElement element, int index) { + return ulong.Parse(element.GetString() ?? throw new JsonException("Expected key 'meta.userindex[" + index + "]' to be a string.")); + } + + var userindex = meta.RequireArray("userindex", "meta") + .Select(static (item, index) => (ParseUserIndex(item, index), index)) + .ToDictionary(); + + var users = new User[userindex.Count]; + + foreach (var item in meta.RequireObject("users", "meta").EnumerateObject()) { + var path = UsersPath + "." + item.Name; + var userId = ulong.Parse(item.Name); + var userObj = item.Value; + + users[userindex[userId]] = new User { + Id = userId, + Name = userObj.RequireString("name", path), + AvatarUrl = userObj.HasKey("avatar") ? userObj.RequireString("avatar", path) : null, + Discriminator = userObj.HasKey("tag") ? 
userObj.RequireString("tag", path) : null + }; + } + + return users; + } + + private static Data.Server[] ReadServerList(JsonElement meta, FakeSnowflake fakeSnowflake) { + const string ServersPath = "meta.servers[]"; + + return meta.RequireArray("servers", "meta").Select(serverObj => new Data.Server { + Id = fakeSnowflake.Next(), + Name = serverObj.RequireString("name", ServersPath), + Type = ServerTypes.FromString(serverObj.RequireString("type", ServersPath)) + }).ToArray(); + } + + private const string ChannelsPath = "meta.channels"; + + private static Dictionary ReadChannelToServerIndexMapping(JsonElement meta, Data.Server[] servers) { + return meta.RequireObject("channels", "meta").EnumerateObject().Select(item => { + var path = ChannelsPath + "." + item.Name; + var channelId = ulong.Parse(item.Name); + var channelObj = item.Value; + + return (channelId, channelObj.RequireInt("server", path, min: 0, max: servers.Length - 1)); + }).ToDictionary(); + } + + private static Channel[] ReadChannelList(JsonElement meta, Data.Server[] servers) { + return meta.RequireObject("channels", "meta").EnumerateObject().Select(item => { + var path = ChannelsPath + "." + item.Name; + var channelId = ulong.Parse(item.Name); + var channelObj = item.Value; + + return new Channel { + Id = channelId, + Server = servers[channelObj.RequireInt("server", path, min: 0, max: servers.Length - 1)].Id, + Name = channelObj.RequireString("name", path), + Position = channelObj.HasKey("position") ? channelObj.RequireInt("position", path, min: 0) : null, + Topic = channelObj.HasKey("topic") ? channelObj.RequireString("topic", path) : null, + Nsfw = channelObj.HasKey("nsfw") ? 
channelObj.RequireBool("nsfw", path) : null + }; + }).ToArray(); + } + + private static Message[] ReadMessages(JsonElement data, Channel channel, User[] users, FakeSnowflake fakeSnowflake) { + const string DataPath = "data"; + + var channelId = channel.Id; + var channelIdStr = channelId.ToString(); + + var messagesObj = data.HasKey(channelIdStr) ? data.RequireObject(channelIdStr, DataPath) : (JsonElement?) null; + if (messagesObj == null) { + return Array.Empty(); + } + + return messagesObj.Value.EnumerateObject().Select(item => { + var path = DataPath + "." + item.Name; + var messageId = ulong.Parse(item.Name); + var messageObj = item.Value; + + return new Message { + Id = messageId, + Sender = users[messageObj.RequireInt("u", path, min: 0, max: users.Length - 1)].Id, + Channel = channelId, + Text = messageObj.HasKey("m") ? messageObj.RequireString("m", path) : string.Empty, + Timestamp = messageObj.RequireLong("t", path), + EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null, + RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null, + Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableArray() : ImmutableArray.Empty, + Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableArray() : ImmutableArray.Empty, + Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableArray() : ImmutableArray.Empty + }; + }).ToArray(); + } + + [SuppressMessage("ReSharper", "ConvertToLambdaExpression")] + private static IEnumerable ReadMessageAttachments(JsonElement.ArrayEnumerator attachmentsArray, FakeSnowflake fakeSnowflake, string path) { + return attachmentsArray.Select(attachmentObj => { + string url = attachmentObj.RequireString("url", path); + string name = url[(url.LastIndexOf('/') + 1)..]; + string? 
type = ContentTypeProvider.TryGetContentType(name, out var contentType) ? contentType : null; + + return new Attachment { + Id = fakeSnowflake.Next(), + Name = name, + Type = type, + Url = url, + Size = 0 // unknown size + }; + }).DistinctByKeyStable(static attachment => { + // Legacy archives have no attachment ids and every id above is freshly generated, so duplicates can only be detected by url. + return attachment.Url; + }); + } + + private static IEnumerable ReadMessageEmbeds(JsonElement.ArrayEnumerator embedsArray, string path) { + // Some rich embeds are missing URLs which causes a missing 'url' key. + return embedsArray.Where(static embedObj => embedObj.HasKey("url")).Select(embedObj => { + string url = embedObj.RequireString("url", path); + string type = embedObj.RequireString("type", path); + + var embedJson = new Dictionary { + { "url", url }, + { "type", type }, + { "dht_legacy", true } + }; + + if (type == "image") { + embedJson["image"] = new Dictionary { + { "url", url } + }; + } + else if (type == "rich") { + if (embedObj.HasKey("t")) { + embedJson["title"] = embedObj.RequireString("t", path); + } + + if (embedObj.HasKey("d")) { + embedJson["description"] = embedObj.RequireString("d", path); + } + } + + return new Embed { + Json = JsonSerializer.Serialize(embedJson) + }; + }); + } + + private static IEnumerable ReadMessageReactions(JsonElement.ArrayEnumerator reactionsArray, string path) { + return reactionsArray.Select(reactionObj => { + var id = reactionObj.HasKey("id") ? reactionObj.RequireSnowflake("id", path) : (ulong?) null; + var name = reactionObj.HasKey("n") ? reactionObj.RequireString("n", path) : null; + + if (id == null && name == null) { + throw new JsonException("Expected key '" + path + ".id' and/or '" + path + ".n' to be present."); + } + + return new Reaction { + EmojiId = id, + EmojiName = name, + EmojiFlags = reactionObj.HasKey("an") && reactionObj.RequireBool("an", path) ? 
EmojiFlags.Animated : EmojiFlags.None, + Count = reactionObj.RequireInt("c", path, min: 0) + }; + }); + } + } +} diff --git a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs index 02dccab..38bb0c6 100644 --- a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs +++ b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs @@ -386,6 +386,22 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC return list; } + public HashSet GetMessageIds(MessageFilter? filter = null) { + var perf = log.Start(); + var ids = new HashSet(); + + using var conn = pool.Take(); + using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause()); + using var reader = cmd.ExecuteReader(); + + while (reader.Read()) { + ids.Add(reader.GetUint64(0)); + } + + perf.End(); + return ids; + } + public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) { var whereClause = filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching); diff --git a/app/Utils/Collections/LinqExtensions.cs b/app/Utils/Collections/LinqExtensions.cs index c9b5173..3d6416f 100644 --- a/app/Utils/Collections/LinqExtensions.cs +++ b/app/Utils/Collections/LinqExtensions.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; namespace DHT.Utils.Collections { public static class LinqExtensions { @@ -14,5 +15,9 @@ namespace DHT.Utils.Collections { } } } + + public static Dictionary ToDictionary(this IEnumerable<(TKey, TValue)> collection) where TKey : notnull { + return collection.ToDictionary(static item => item.Item1, static item => item.Item2); + } } }