diff --git a/app/Server/Database/Sqlite/Schema.cs b/app/Server/Database/Sqlite/Schema.cs index 37bef24..073aa2d 100644 --- a/app/Server/Database/Sqlite/Schema.cs +++ b/app/Server/Database/Sqlite/Schema.cs @@ -26,7 +26,7 @@ namespace DHT.Server.Database.Sqlite { public async Task Setup(Func> checkCanUpgradeSchemas) { Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)"); - var dbVersionStr = Sql("SELECT value FROM metadata WHERE key = 'version'").ExecuteScalar(); + var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'"); if (dbVersionStr == null) { InitializeSchemas(); } diff --git a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs index 43017f4..16e33d5 100644 --- a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs +++ b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs @@ -43,13 +43,14 @@ namespace DHT.Server.Database.Sqlite { public void AddServer(Data.Server server) { using var cmd = conn.Upsert("servers", new[] { - "id", "name", "type" + ("id", SqliteType.Integer), + ("name", SqliteType.Text), + ("type", SqliteType.Text) }); - var serverParams = cmd.Parameters; - serverParams.AddAndSet(":id", server.Id); - serverParams.AddAndSet(":name", server.Name); - serverParams.AddAndSet(":type", ServerTypes.ToString(server.Type)); + cmd.Set(":id", server.Id); + cmd.Set(":name", server.Name); + cmd.Set(":type", ServerTypes.ToString(server.Type)); cmd.ExecuteNonQuery(); UpdateServerStatistics(); } @@ -62,7 +63,7 @@ namespace DHT.Server.Database.Sqlite { while (reader.Read()) { list.Add(new Data.Server { - Id = (ulong) reader.GetInt64(0), + Id = reader.GetUint64(0), Name = reader.GetString(1), Type = ServerTypes.FromString(reader.GetString(2)) }); @@ -73,17 +74,22 @@ namespace DHT.Server.Database.Sqlite { public void AddChannel(Channel channel) { using var cmd = conn.Upsert("channels", new[] { - "id", "server", "name", "parent_id", "position", "topic", "nsfw" + ("id", SqliteType.Integer), + ("server", SqliteType.Integer), + ("name", SqliteType.Text), + ("parent_id", SqliteType.Integer), + ("position", SqliteType.Integer), + ("topic", SqliteType.Text), + ("nsfw", SqliteType.Integer) }); - var channelParams = cmd.Parameters; - channelParams.AddAndSet(":id", channel.Id); - channelParams.AddAndSet(":server", channel.Server); - channelParams.AddAndSet(":name", channel.Name); - channelParams.AddAndSet(":parent_id", channel.ParentId); - channelParams.AddAndSet(":position", channel.Position); - channelParams.AddAndSet(":topic", channel.Topic); - channelParams.AddAndSet(":nsfw", channel.Nsfw); + cmd.Set(":id", channel.Id); + cmd.Set(":server", channel.Server); + cmd.Set(":name", channel.Name); + cmd.Set(":parent_id", channel.ParentId); + cmd.Set(":position", channel.Position); + cmd.Set(":topic", channel.Topic); + cmd.Set(":nsfw", channel.Nsfw); cmd.ExecuteNonQuery(); UpdateChannelStatistics(); } @@ -96,10 +102,10 @@ namespace DHT.Server.Database.Sqlite { while (reader.Read()) { list.Add(new Channel { - Id = (ulong) reader.GetInt64(0), - Server = (ulong) reader.GetInt64(1), + Id = reader.GetUint64(0), + Server = reader.GetUint64(1), Name = reader.GetString(2), - ParentId = reader.IsDBNull(3) ? null : (ulong) reader.GetInt64(3), + ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3), Position = reader.IsDBNull(4) ? null : reader.GetInt32(4), Topic = reader.IsDBNull(5) ? null : reader.GetString(5), Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6) @@ -112,20 +118,17 @@ namespace DHT.Server.Database.Sqlite { public void AddUsers(User[] users) { using var tx = conn.BeginTransaction(); using var cmd = conn.Upsert("users", new[] { - "id", "name", "avatar_url", "discriminator" + ("id", SqliteType.Integer), + ("name", SqliteType.Text), + ("avatar_url", SqliteType.Text), + ("discriminator", SqliteType.Text) }); - var userParams = cmd.Parameters; - userParams.Add(":id", SqliteType.Integer); - userParams.Add(":name", SqliteType.Text); - userParams.Add(":avatar_url", SqliteType.Text); - userParams.Add(":discriminator", SqliteType.Text); - foreach (var user in users) { - userParams.Set(":id", user.Id); - userParams.Set(":name", user.Name); - userParams.Set(":avatar_url", user.AvatarUrl); - userParams.Set(":discriminator", user.Discriminator); + cmd.Set(":id", user.Id); + cmd.Set(":name", user.Name); + cmd.Set(":avatar_url", user.AvatarUrl); + cmd.Set(":discriminator", user.Discriminator); cmd.ExecuteNonQuery(); } @@ -141,7 +144,7 @@ namespace DHT.Server.Database.Sqlite { while (reader.Read()) { list.Add(new User { - Id = (ulong) reader.GetInt64(0), + Id = reader.GetUint64(0), Name = reader.GetString(1), AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2), Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3) @@ -153,110 +156,91 @@ namespace DHT.Server.Database.Sqlite { public void AddMessages(Message[] messages) { using var tx = conn.BeginTransaction(); + using var messageCmd = conn.Upsert("messages", new[] { - "message_id", "sender_id", "channel_id", "text", "timestamp", "edit_timestamp", "replied_to_id" + ("message_id", SqliteType.Integer), + ("sender_id", SqliteType.Integer), + ("channel_id", SqliteType.Integer), + ("text", SqliteType.Text), + ("timestamp", SqliteType.Integer), + ("edit_timestamp", SqliteType.Integer), + ("replied_to_id", SqliteType.Integer) }); - using var deleteAttachmentsCmd = conn.Command("DELETE FROM attachments WHERE message_id = :message_id"); + using var deleteAttachmentsCmd = conn.Delete("attachments", ("message_id", SqliteType.Integer)); + using var deleteEmbedsCmd = conn.Delete("embeds", ("message_id", SqliteType.Integer)); + using var deleteReactionsCmd = conn.Delete("reactions", ("message_id", SqliteType.Integer)); + using var attachmentCmd = conn.Insert("attachments", new[] { - "message_id", "attachment_id", "name", "type", "url", "size" + ("message_id", SqliteType.Integer), + ("attachment_id", SqliteType.Integer), + ("name", SqliteType.Text), + ("type", SqliteType.Text), + ("url", SqliteType.Text), + ("size", SqliteType.Integer) }); - using var deleteEmbedsCmd = conn.Command("DELETE FROM embeds WHERE message_id = :message_id"); using var embedCmd = conn.Insert("embeds", new[] { - "message_id", "json" + ("message_id", SqliteType.Integer), + ("json", SqliteType.Text) }); - using var deleteReactionsCmd = conn.Command("DELETE FROM reactions WHERE message_id = :message_id"); using var reactionCmd = conn.Insert("reactions", new[] { - "message_id", "emoji_id", "emoji_name", "emoji_flags", "count" + ("message_id", SqliteType.Integer), + ("emoji_id", SqliteType.Integer), + ("emoji_name", SqliteType.Text), + ("emoji_flags", SqliteType.Integer), + ("count", SqliteType.Integer) }); - var messageParams = messageCmd.Parameters; - messageParams.Add(":message_id", SqliteType.Integer); - messageParams.Add(":sender_id", SqliteType.Integer); - messageParams.Add(":channel_id", SqliteType.Integer); - messageParams.Add(":text", SqliteType.Text); - messageParams.Add(":timestamp", SqliteType.Integer); - messageParams.Add(":edit_timestamp", SqliteType.Integer); - messageParams.Add(":replied_to_id", SqliteType.Integer); - - var deleteAttachmentsParams = deleteAttachmentsCmd.Parameters; - deleteAttachmentsParams.Add(":message_id", SqliteType.Integer); - - var attachmentParams = attachmentCmd.Parameters; - attachmentParams.Add(":message_id", SqliteType.Integer); - attachmentParams.Add(":attachment_id", SqliteType.Integer); - attachmentParams.Add(":name", SqliteType.Text); - attachmentParams.Add(":type", SqliteType.Text); - attachmentParams.Add(":url", SqliteType.Text); - attachmentParams.Add(":size", SqliteType.Integer); - - var deleteEmbedsParams = deleteEmbedsCmd.Parameters; - deleteEmbedsParams.Add(":message_id", SqliteType.Integer); - - var embedParams = embedCmd.Parameters; - embedParams.Add(":message_id", SqliteType.Integer); - embedParams.Add(":json", SqliteType.Text); - - var deleteReactionsParams = deleteReactionsCmd.Parameters; - deleteReactionsParams.Add(":message_id", SqliteType.Integer); - - var reactionParams = reactionCmd.Parameters; - reactionParams.Add(":message_id", SqliteType.Integer); - reactionParams.Add(":emoji_id", SqliteType.Integer); - reactionParams.Add(":emoji_name", SqliteType.Text); - reactionParams.Add(":emoji_flags", SqliteType.Integer); - reactionParams.Add(":count", SqliteType.Integer); - foreach (var message in messages) { object messageId = message.Id; - messageParams.Set(":message_id", messageId); - messageParams.Set(":sender_id", message.Sender); - messageParams.Set(":channel_id", message.Channel); - messageParams.Set(":text", message.Text); - messageParams.Set(":timestamp", message.Timestamp); - messageParams.Set(":edit_timestamp", message.EditTimestamp); - messageParams.Set(":replied_to_id", message.RepliedToId); + messageCmd.Set(":message_id", messageId); + messageCmd.Set(":sender_id", message.Sender); + messageCmd.Set(":channel_id", message.Channel); + messageCmd.Set(":text", message.Text); + messageCmd.Set(":timestamp", message.Timestamp); + messageCmd.Set(":edit_timestamp", message.EditTimestamp); + messageCmd.Set(":replied_to_id", message.RepliedToId); messageCmd.ExecuteNonQuery(); - deleteAttachmentsParams.Set(":message_id", messageId); + deleteAttachmentsCmd.Set(":message_id", messageId); deleteAttachmentsCmd.ExecuteNonQuery(); - deleteEmbedsParams.Set(":message_id", messageId); + deleteEmbedsCmd.Set(":message_id", messageId); deleteEmbedsCmd.ExecuteNonQuery(); - deleteReactionsParams.Set(":message_id", messageId); + deleteReactionsCmd.Set(":message_id", messageId); deleteReactionsCmd.ExecuteNonQuery(); if (!message.Attachments.IsEmpty) { foreach (var attachment in message.Attachments) { - attachmentParams.Set(":message_id", messageId); - attachmentParams.Set(":attachment_id", attachment.Id); - attachmentParams.Set(":name", attachment.Name); - attachmentParams.Set(":type", attachment.Type); - attachmentParams.Set(":url", attachment.Url); - attachmentParams.Set(":size", attachment.Size); + attachmentCmd.Set(":message_id", messageId); + attachmentCmd.Set(":attachment_id", attachment.Id); + attachmentCmd.Set(":name", attachment.Name); + attachmentCmd.Set(":type", attachment.Type); + attachmentCmd.Set(":url", attachment.Url); + attachmentCmd.Set(":size", attachment.Size); attachmentCmd.ExecuteNonQuery(); } } if (!message.Embeds.IsEmpty) { foreach (var embed in message.Embeds) { - embedParams.Set(":message_id", messageId); - embedParams.Set(":json", embed.Json); + embedCmd.Set(":message_id", messageId); + embedCmd.Set(":json", embed.Json); embedCmd.ExecuteNonQuery(); } } if (!message.Reactions.IsEmpty) { foreach (var reaction in message.Reactions) { - reactionParams.Set(":message_id", messageId); - reactionParams.Set(":emoji_id", reaction.EmojiId); - reactionParams.Set(":emoji_name", reaction.EmojiName); - reactionParams.Set(":emoji_flags", (int) reaction.EmojiFlags); - reactionParams.Set(":count", reaction.Count); + reactionCmd.Set(":message_id", messageId); + reactionCmd.Set(":emoji_id", reaction.EmojiId); + reactionCmd.Set(":emoji_name", reaction.EmojiName); + reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags); + reactionCmd.Set(":count", reaction.Count); reactionCmd.ExecuteNonQuery(); } } @@ -284,16 +268,16 @@ namespace DHT.Server.Database.Sqlite { using var reader = cmd.ExecuteReader(); while (reader.Read()) { - ulong id = (ulong) reader.GetInt64(0); + ulong id = reader.GetUint64(0); list.Add(new Message { Id = id, - Sender = (ulong) reader.GetInt64(1), - Channel = (ulong) reader.GetInt64(2), + Sender = reader.GetUint64(1), + Channel = reader.GetUint64(2), Text = reader.GetString(3), Timestamp = reader.GetInt64(4), EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5), - RepliedToId = reader.IsDBNull(6) ? null : (ulong) reader.GetInt64(6), + RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6), Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray.Empty, Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray.Empty, Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray.Empty @@ -328,14 +312,14 @@ namespace DHT.Server.Database.Sqlite { using var reader = cmd.ExecuteReader(); while (reader.Read()) { - ulong messageId = (ulong) reader.GetInt64(0); + ulong messageId = reader.GetUint64(0); dict.Add(messageId, new Attachment { - Id = (ulong) reader.GetInt64(1), + Id = reader.GetUint64(1), Name = reader.GetString(2), Type = reader.IsDBNull(3) ? null : reader.GetString(3), Url = reader.GetString(4), - Size = (ulong) reader.GetInt64(5) + Size = reader.GetUint64(5) }); } @@ -349,7 +333,7 @@ namespace DHT.Server.Database.Sqlite { using var reader = cmd.ExecuteReader(); while (reader.Read()) { - ulong messageId = (ulong) reader.GetInt64(0); + ulong messageId = reader.GetUint64(0); dict.Add(messageId, new Embed { Json = reader.GetString(1) @@ -366,10 +350,10 @@ namespace DHT.Server.Database.Sqlite { using var reader = cmd.ExecuteReader(); while (reader.Read()) { - ulong messageId = (ulong) reader.GetInt64(0); + ulong messageId = reader.GetUint64(0); dict.Add(messageId, new Reaction { - EmojiId = reader.IsDBNull(1) ? null : (ulong) reader.GetInt64(1), + EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1), EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2), EmojiFlags = (EmojiFlags) reader.GetInt16(3), Count = reader.GetInt32(4) @@ -380,23 +364,19 @@ namespace DHT.Server.Database.Sqlite { } private void UpdateServerStatistics() { - using var cmd = conn.Command("SELECT COUNT(*) FROM servers"); - Statistics.TotalServers = cmd.ExecuteScalar() as long? ?? 0; + Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0; } private void UpdateChannelStatistics() { - using var cmd = conn.Command("SELECT COUNT(*) FROM channels"); - Statistics.TotalChannels = cmd.ExecuteScalar() as long? ?? 0; + Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0; } private void UpdateUserStatistics() { - using var cmd = conn.Command("SELECT COUNT(*) FROM users"); - Statistics.TotalUsers = cmd.ExecuteScalar() as long? ?? 0; + Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0; } private void UpdateMessageStatistics() { - using var cmd = conn.Command("SELECT COUNT(*) FROM messages"); - Statistics.TotalMessages = cmd.ExecuteScalar() as long? ?? 0L; + Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L; } } } diff --git a/app/Server/Database/Sqlite/SqliteUtils.cs b/app/Server/Database/Sqlite/SqliteUtils.cs index ef897b2..53e7c2c 100644 --- a/app/Server/Database/Sqlite/SqliteUtils.cs +++ b/app/Server/Database/Sqlite/SqliteUtils.cs @@ -10,31 +10,54 @@ namespace DHT.Server.Database.Sqlite { return cmd; } - public static SqliteCommand Insert(this SqliteConnection conn, string tableName, string[] columns) { - string columnNames = string.Join(',', columns); - string columnParams = string.Join(',', columns.Select(static c => ':' + c)); - - return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" + - "VALUES (" + columnParams + ")"); + public static object? SelectScalar(this SqliteConnection conn, string sql) { + using var cmd = conn.Command(sql); + return cmd.ExecuteScalar(); } - public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, string[] columns) { - string columnNames = string.Join(',', columns); - string columnParams = string.Join(',', columns.Select(static c => ':' + c)); - string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c + " = excluded." + c)); + public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { + string columnNames = string.Join(',', columns.Select(static c => c.Name)); + string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); - return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" + - "VALUES (" + columnParams + ")" + - "ON CONFLICT (" + columns[0] + ")" + - "DO UPDATE SET " + columnUpdates); + var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" + + "VALUES (" + columnParams + ")"); + + CreateParameters(cmd, columns); + return cmd; } - public static void AddAndSet(this SqliteParameterCollection parameters, string key, object? value) { - parameters.AddWithValue(key, value ?? DBNull.Value); + public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { + string columnNames = string.Join(',', columns.Select(static c => c.Name)); + string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); + string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name)); + + var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" + + "VALUES (" + columnParams + ")" + + "ON CONFLICT (" + columns[0].Name + ")" + + "DO UPDATE SET " + columnUpdates); + + CreateParameters(cmd, columns); + return cmd; } - public static void Set(this SqliteParameterCollection parameters, string key, object? value) { - parameters[key].Value = value ?? DBNull.Value; + public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { + var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); + CreateParameters(cmd, new [] { column }); + return cmd; + } + + private static void CreateParameters(SqliteCommand cmd, (string Name, SqliteType Type)[] columns) { + foreach (var (name, type) in columns) { + cmd.Parameters.Add(":" + name, type); + } + } + + public static void Set(this SqliteCommand cmd, string key, object? value) { + cmd.Parameters[key].Value = value ?? DBNull.Value; + } + + public static ulong GetUint64(this SqliteDataReader reader, int ordinal) { + return (ulong) reader.GetInt64(ordinal); } } }