Add utilities to simplify working with SQLite

This commit is contained in:
chylex 2022-03-05 22:18:02 +01:00
parent b9899922e0
commit 6f1149ad5e
No known key found for this signature in database
GPG Key ID: 4DE42C8F19A80548
3 changed files with 137 additions and 134 deletions

View File

@ -26,7 +26,7 @@ namespace DHT.Server.Database.Sqlite {
public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) { public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) {
Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)"); Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = Sql("SELECT value FROM metadata WHERE key = 'version'").ExecuteScalar(); var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'");
if (dbVersionStr == null) { if (dbVersionStr == null) {
InitializeSchemas(); InitializeSchemas();
} }

View File

@ -43,13 +43,14 @@ namespace DHT.Server.Database.Sqlite {
public void AddServer(Data.Server server) { public void AddServer(Data.Server server) {
using var cmd = conn.Upsert("servers", new[] { using var cmd = conn.Upsert("servers", new[] {
"id", "name", "type" ("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text)
}); });
var serverParams = cmd.Parameters; cmd.Set(":id", server.Id);
serverParams.AddAndSet(":id", server.Id); cmd.Set(":name", server.Name);
serverParams.AddAndSet(":name", server.Name); cmd.Set(":type", ServerTypes.ToString(server.Type));
serverParams.AddAndSet(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
UpdateServerStatistics(); UpdateServerStatistics();
} }
@ -62,7 +63,7 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) { while (reader.Read()) {
list.Add(new Data.Server { list.Add(new Data.Server {
Id = (ulong) reader.GetInt64(0), Id = reader.GetUint64(0),
Name = reader.GetString(1), Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)) Type = ServerTypes.FromString(reader.GetString(2))
}); });
@ -73,17 +74,22 @@ namespace DHT.Server.Database.Sqlite {
public void AddChannel(Channel channel) { public void AddChannel(Channel channel) {
using var cmd = conn.Upsert("channels", new[] { using var cmd = conn.Upsert("channels", new[] {
"id", "server", "name", "parent_id", "position", "topic", "nsfw" ("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer)
}); });
var channelParams = cmd.Parameters; cmd.Set(":id", channel.Id);
channelParams.AddAndSet(":id", channel.Id); cmd.Set(":server", channel.Server);
channelParams.AddAndSet(":server", channel.Server); cmd.Set(":name", channel.Name);
channelParams.AddAndSet(":name", channel.Name); cmd.Set(":parent_id", channel.ParentId);
channelParams.AddAndSet(":parent_id", channel.ParentId); cmd.Set(":position", channel.Position);
channelParams.AddAndSet(":position", channel.Position); cmd.Set(":topic", channel.Topic);
channelParams.AddAndSet(":topic", channel.Topic); cmd.Set(":nsfw", channel.Nsfw);
channelParams.AddAndSet(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
UpdateChannelStatistics(); UpdateChannelStatistics();
} }
@ -96,10 +102,10 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) { while (reader.Read()) {
list.Add(new Channel { list.Add(new Channel {
Id = (ulong) reader.GetInt64(0), Id = reader.GetUint64(0),
Server = (ulong) reader.GetInt64(1), Server = reader.GetUint64(1),
Name = reader.GetString(2), Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : (ulong) reader.GetInt64(3), ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4), Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5), Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6) Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6)
@ -112,20 +118,17 @@ namespace DHT.Server.Database.Sqlite {
public void AddUsers(User[] users) { public void AddUsers(User[] users) {
using var tx = conn.BeginTransaction(); using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] { using var cmd = conn.Upsert("users", new[] {
"id", "name", "avatar_url", "discriminator" ("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text)
}); });
var userParams = cmd.Parameters;
userParams.Add(":id", SqliteType.Integer);
userParams.Add(":name", SqliteType.Text);
userParams.Add(":avatar_url", SqliteType.Text);
userParams.Add(":discriminator", SqliteType.Text);
foreach (var user in users) { foreach (var user in users) {
userParams.Set(":id", user.Id); cmd.Set(":id", user.Id);
userParams.Set(":name", user.Name); cmd.Set(":name", user.Name);
userParams.Set(":avatar_url", user.AvatarUrl); cmd.Set(":avatar_url", user.AvatarUrl);
userParams.Set(":discriminator", user.Discriminator); cmd.Set(":discriminator", user.Discriminator);
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
} }
@ -141,7 +144,7 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) { while (reader.Read()) {
list.Add(new User { list.Add(new User {
Id = (ulong) reader.GetInt64(0), Id = reader.GetUint64(0),
Name = reader.GetString(1), Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2), AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3) Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3)
@ -153,110 +156,91 @@ namespace DHT.Server.Database.Sqlite {
public void AddMessages(Message[] messages) { public void AddMessages(Message[] messages) {
using var tx = conn.BeginTransaction(); using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] { using var messageCmd = conn.Upsert("messages", new[] {
"message_id", "sender_id", "channel_id", "text", "timestamp", "edit_timestamp", "replied_to_id" ("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer),
("replied_to_id", SqliteType.Integer)
}); });
using var deleteAttachmentsCmd = conn.Command("DELETE FROM attachments WHERE message_id = :message_id"); using var deleteAttachmentsCmd = conn.Delete("attachments", ("message_id", SqliteType.Integer));
using var deleteEmbedsCmd = conn.Delete("embeds", ("message_id", SqliteType.Integer));
using var deleteReactionsCmd = conn.Delete("reactions", ("message_id", SqliteType.Integer));
using var attachmentCmd = conn.Insert("attachments", new[] { using var attachmentCmd = conn.Insert("attachments", new[] {
"message_id", "attachment_id", "name", "type", "url", "size" ("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("url", SqliteType.Text),
("size", SqliteType.Integer)
}); });
using var deleteEmbedsCmd = conn.Command("DELETE FROM embeds WHERE message_id = :message_id");
using var embedCmd = conn.Insert("embeds", new[] { using var embedCmd = conn.Insert("embeds", new[] {
"message_id", "json" ("message_id", SqliteType.Integer),
("json", SqliteType.Text)
}); });
using var deleteReactionsCmd = conn.Command("DELETE FROM reactions WHERE message_id = :message_id");
using var reactionCmd = conn.Insert("reactions", new[] { using var reactionCmd = conn.Insert("reactions", new[] {
"message_id", "emoji_id", "emoji_name", "emoji_flags", "count" ("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
}); });
var messageParams = messageCmd.Parameters;
messageParams.Add(":message_id", SqliteType.Integer);
messageParams.Add(":sender_id", SqliteType.Integer);
messageParams.Add(":channel_id", SqliteType.Integer);
messageParams.Add(":text", SqliteType.Text);
messageParams.Add(":timestamp", SqliteType.Integer);
messageParams.Add(":edit_timestamp", SqliteType.Integer);
messageParams.Add(":replied_to_id", SqliteType.Integer);
var deleteAttachmentsParams = deleteAttachmentsCmd.Parameters;
deleteAttachmentsParams.Add(":message_id", SqliteType.Integer);
var attachmentParams = attachmentCmd.Parameters;
attachmentParams.Add(":message_id", SqliteType.Integer);
attachmentParams.Add(":attachment_id", SqliteType.Integer);
attachmentParams.Add(":name", SqliteType.Text);
attachmentParams.Add(":type", SqliteType.Text);
attachmentParams.Add(":url", SqliteType.Text);
attachmentParams.Add(":size", SqliteType.Integer);
var deleteEmbedsParams = deleteEmbedsCmd.Parameters;
deleteEmbedsParams.Add(":message_id", SqliteType.Integer);
var embedParams = embedCmd.Parameters;
embedParams.Add(":message_id", SqliteType.Integer);
embedParams.Add(":json", SqliteType.Text);
var deleteReactionsParams = deleteReactionsCmd.Parameters;
deleteReactionsParams.Add(":message_id", SqliteType.Integer);
var reactionParams = reactionCmd.Parameters;
reactionParams.Add(":message_id", SqliteType.Integer);
reactionParams.Add(":emoji_id", SqliteType.Integer);
reactionParams.Add(":emoji_name", SqliteType.Text);
reactionParams.Add(":emoji_flags", SqliteType.Integer);
reactionParams.Add(":count", SqliteType.Integer);
foreach (var message in messages) { foreach (var message in messages) {
object messageId = message.Id; object messageId = message.Id;
messageParams.Set(":message_id", messageId); messageCmd.Set(":message_id", messageId);
messageParams.Set(":sender_id", message.Sender); messageCmd.Set(":sender_id", message.Sender);
messageParams.Set(":channel_id", message.Channel); messageCmd.Set(":channel_id", message.Channel);
messageParams.Set(":text", message.Text); messageCmd.Set(":text", message.Text);
messageParams.Set(":timestamp", message.Timestamp); messageCmd.Set(":timestamp", message.Timestamp);
messageParams.Set(":edit_timestamp", message.EditTimestamp); messageCmd.Set(":edit_timestamp", message.EditTimestamp);
messageParams.Set(":replied_to_id", message.RepliedToId); messageCmd.Set(":replied_to_id", message.RepliedToId);
messageCmd.ExecuteNonQuery(); messageCmd.ExecuteNonQuery();
deleteAttachmentsParams.Set(":message_id", messageId); deleteAttachmentsCmd.Set(":message_id", messageId);
deleteAttachmentsCmd.ExecuteNonQuery(); deleteAttachmentsCmd.ExecuteNonQuery();
deleteEmbedsParams.Set(":message_id", messageId); deleteEmbedsCmd.Set(":message_id", messageId);
deleteEmbedsCmd.ExecuteNonQuery(); deleteEmbedsCmd.ExecuteNonQuery();
deleteReactionsParams.Set(":message_id", messageId); deleteReactionsCmd.Set(":message_id", messageId);
deleteReactionsCmd.ExecuteNonQuery(); deleteReactionsCmd.ExecuteNonQuery();
if (!message.Attachments.IsEmpty) { if (!message.Attachments.IsEmpty) {
foreach (var attachment in message.Attachments) { foreach (var attachment in message.Attachments) {
attachmentParams.Set(":message_id", messageId); attachmentCmd.Set(":message_id", messageId);
attachmentParams.Set(":attachment_id", attachment.Id); attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentParams.Set(":name", attachment.Name); attachmentCmd.Set(":name", attachment.Name);
attachmentParams.Set(":type", attachment.Type); attachmentCmd.Set(":type", attachment.Type);
attachmentParams.Set(":url", attachment.Url); attachmentCmd.Set(":url", attachment.Url);
attachmentParams.Set(":size", attachment.Size); attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.ExecuteNonQuery(); attachmentCmd.ExecuteNonQuery();
} }
} }
if (!message.Embeds.IsEmpty) { if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) { foreach (var embed in message.Embeds) {
embedParams.Set(":message_id", messageId); embedCmd.Set(":message_id", messageId);
embedParams.Set(":json", embed.Json); embedCmd.Set(":json", embed.Json);
embedCmd.ExecuteNonQuery(); embedCmd.ExecuteNonQuery();
} }
} }
if (!message.Reactions.IsEmpty) { if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) { foreach (var reaction in message.Reactions) {
reactionParams.Set(":message_id", messageId); reactionCmd.Set(":message_id", messageId);
reactionParams.Set(":emoji_id", reaction.EmojiId); reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionParams.Set(":emoji_name", reaction.EmojiName); reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionParams.Set(":emoji_flags", (int) reaction.EmojiFlags); reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionParams.Set(":count", reaction.Count); reactionCmd.Set(":count", reaction.Count);
reactionCmd.ExecuteNonQuery(); reactionCmd.ExecuteNonQuery();
} }
} }
@ -284,16 +268,16 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
while (reader.Read()) { while (reader.Read()) {
ulong id = (ulong) reader.GetInt64(0); ulong id = reader.GetUint64(0);
list.Add(new Message { list.Add(new Message {
Id = id, Id = id,
Sender = (ulong) reader.GetInt64(1), Sender = reader.GetUint64(1),
Channel = (ulong) reader.GetInt64(2), Channel = reader.GetUint64(2),
Text = reader.GetString(3), Text = reader.GetString(3),
Timestamp = reader.GetInt64(4), Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5), EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : (ulong) reader.GetInt64(6), RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty, Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty,
Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty, Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty,
Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty
@ -328,14 +312,14 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
while (reader.Read()) { while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0); ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Attachment { dict.Add(messageId, new Attachment {
Id = (ulong) reader.GetInt64(1), Id = reader.GetUint64(1),
Name = reader.GetString(2), Name = reader.GetString(2),
Type = reader.IsDBNull(3) ? null : reader.GetString(3), Type = reader.IsDBNull(3) ? null : reader.GetString(3),
Url = reader.GetString(4), Url = reader.GetString(4),
Size = (ulong) reader.GetInt64(5) Size = reader.GetUint64(5)
}); });
} }
@ -349,7 +333,7 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
while (reader.Read()) { while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0); ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Embed { dict.Add(messageId, new Embed {
Json = reader.GetString(1) Json = reader.GetString(1)
@ -366,10 +350,10 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
while (reader.Read()) { while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0); ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Reaction { dict.Add(messageId, new Reaction {
EmojiId = reader.IsDBNull(1) ? null : (ulong) reader.GetInt64(1), EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1),
EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2), EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2),
EmojiFlags = (EmojiFlags) reader.GetInt16(3), EmojiFlags = (EmojiFlags) reader.GetInt16(3),
Count = reader.GetInt32(4) Count = reader.GetInt32(4)
@ -380,23 +364,19 @@ namespace DHT.Server.Database.Sqlite {
} }
private void UpdateServerStatistics() { private void UpdateServerStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM servers"); Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
Statistics.TotalServers = cmd.ExecuteScalar() as long? ?? 0;
} }
private void UpdateChannelStatistics() { private void UpdateChannelStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM channels"); Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
Statistics.TotalChannels = cmd.ExecuteScalar() as long? ?? 0;
} }
private void UpdateUserStatistics() { private void UpdateUserStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM users"); Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
Statistics.TotalUsers = cmd.ExecuteScalar() as long? ?? 0;
} }
private void UpdateMessageStatistics() { private void UpdateMessageStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM messages"); Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
Statistics.TotalMessages = cmd.ExecuteScalar() as long? ?? 0L;
} }
} }
} }

View File

@ -10,31 +10,54 @@ namespace DHT.Server.Database.Sqlite {
return cmd; return cmd;
} }
public static SqliteCommand Insert(this SqliteConnection conn, string tableName, string[] columns) { public static object? SelectScalar(this SqliteConnection conn, string sql) {
string columnNames = string.Join(',', columns); using var cmd = conn.Command(sql);
string columnParams = string.Join(',', columns.Select(static c => ':' + c)); return cmd.ExecuteScalar();
return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")");
} }
public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, string[] columns) { public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns); string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c + " = excluded." + c));
return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" + var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")" + "VALUES (" + columnParams + ")");
"ON CONFLICT (" + columns[0] + ")" +
"DO UPDATE SET " + columnUpdates); CreateParameters(cmd, columns);
return cmd;
} }
public static void AddAndSet(this SqliteParameterCollection parameters, string key, object? value) { public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
parameters.AddWithValue(key, value ?? DBNull.Value); string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name));
var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")" +
"ON CONFLICT (" + columns[0].Name + ")" +
"DO UPDATE SET " + columnUpdates);
CreateParameters(cmd, columns);
return cmd;
} }
public static void Set(this SqliteParameterCollection parameters, string key, object? value) { public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
parameters[key].Value = value ?? DBNull.Value; var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column });
return cmd;
}
private static void CreateParameters(SqliteCommand cmd, (string Name, SqliteType Type)[] columns) {
foreach (var (name, type) in columns) {
cmd.Parameters.Add(":" + name, type);
}
}
public static void Set(this SqliteCommand cmd, string key, object? value) {
cmd.Parameters[key].Value = value ?? DBNull.Value;
}
public static ulong GetUint64(this SqliteDataReader reader, int ordinal) {
return (ulong) reader.GetInt64(ordinal);
} }
} }
} }