Add utilities to simplify working with SQLite

This commit is contained in:
chylex 2022-03-05 22:18:02 +01:00
parent b9899922e0
commit 6f1149ad5e
No known key found for this signature in database
GPG Key ID: 4DE42C8F19A80548
3 changed files with 137 additions and 134 deletions

View File

@ -26,7 +26,7 @@ namespace DHT.Server.Database.Sqlite {
public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) {
Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = Sql("SELECT value FROM metadata WHERE key = 'version'").ExecuteScalar();
var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'");
if (dbVersionStr == null) {
InitializeSchemas();
}

View File

@ -43,13 +43,14 @@ namespace DHT.Server.Database.Sqlite {
public void AddServer(Data.Server server) {
using var cmd = conn.Upsert("servers", new[] {
"id", "name", "type"
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text)
});
var serverParams = cmd.Parameters;
serverParams.AddAndSet(":id", server.Id);
serverParams.AddAndSet(":name", server.Name);
serverParams.AddAndSet(":type", ServerTypes.ToString(server.Type));
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery();
UpdateServerStatistics();
}
@ -62,7 +63,7 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) {
list.Add(new Data.Server {
Id = (ulong) reader.GetInt64(0),
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2))
});
@ -73,17 +74,22 @@ namespace DHT.Server.Database.Sqlite {
public void AddChannel(Channel channel) {
using var cmd = conn.Upsert("channels", new[] {
"id", "server", "name", "parent_id", "position", "topic", "nsfw"
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer)
});
var channelParams = cmd.Parameters;
channelParams.AddAndSet(":id", channel.Id);
channelParams.AddAndSet(":server", channel.Server);
channelParams.AddAndSet(":name", channel.Name);
channelParams.AddAndSet(":parent_id", channel.ParentId);
channelParams.AddAndSet(":position", channel.Position);
channelParams.AddAndSet(":topic", channel.Topic);
channelParams.AddAndSet(":nsfw", channel.Nsfw);
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery();
UpdateChannelStatistics();
}
@ -96,10 +102,10 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) {
list.Add(new Channel {
Id = (ulong) reader.GetInt64(0),
Server = (ulong) reader.GetInt64(1),
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : (ulong) reader.GetInt64(3),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6)
@ -112,20 +118,17 @@ namespace DHT.Server.Database.Sqlite {
public void AddUsers(User[] users) {
using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] {
"id", "name", "avatar_url", "discriminator"
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text)
});
var userParams = cmd.Parameters;
userParams.Add(":id", SqliteType.Integer);
userParams.Add(":name", SqliteType.Text);
userParams.Add(":avatar_url", SqliteType.Text);
userParams.Add(":discriminator", SqliteType.Text);
foreach (var user in users) {
userParams.Set(":id", user.Id);
userParams.Set(":name", user.Name);
userParams.Set(":avatar_url", user.AvatarUrl);
userParams.Set(":discriminator", user.Discriminator);
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
cmd.ExecuteNonQuery();
}
@ -141,7 +144,7 @@ namespace DHT.Server.Database.Sqlite {
while (reader.Read()) {
list.Add(new User {
Id = (ulong) reader.GetInt64(0),
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3)
@ -153,110 +156,91 @@ namespace DHT.Server.Database.Sqlite {
public void AddMessages(Message[] messages) {
using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] {
"message_id", "sender_id", "channel_id", "text", "timestamp", "edit_timestamp", "replied_to_id"
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer),
("replied_to_id", SqliteType.Integer)
});
using var deleteAttachmentsCmd = conn.Command("DELETE FROM attachments WHERE message_id = :message_id");
using var deleteAttachmentsCmd = conn.Delete("attachments", ("message_id", SqliteType.Integer));
using var deleteEmbedsCmd = conn.Delete("embeds", ("message_id", SqliteType.Integer));
using var deleteReactionsCmd = conn.Delete("reactions", ("message_id", SqliteType.Integer));
using var attachmentCmd = conn.Insert("attachments", new[] {
"message_id", "attachment_id", "name", "type", "url", "size"
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("url", SqliteType.Text),
("size", SqliteType.Integer)
});
using var deleteEmbedsCmd = conn.Command("DELETE FROM embeds WHERE message_id = :message_id");
using var embedCmd = conn.Insert("embeds", new[] {
"message_id", "json"
("message_id", SqliteType.Integer),
("json", SqliteType.Text)
});
using var deleteReactionsCmd = conn.Command("DELETE FROM reactions WHERE message_id = :message_id");
using var reactionCmd = conn.Insert("reactions", new[] {
"message_id", "emoji_id", "emoji_name", "emoji_flags", "count"
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
});
var messageParams = messageCmd.Parameters;
messageParams.Add(":message_id", SqliteType.Integer);
messageParams.Add(":sender_id", SqliteType.Integer);
messageParams.Add(":channel_id", SqliteType.Integer);
messageParams.Add(":text", SqliteType.Text);
messageParams.Add(":timestamp", SqliteType.Integer);
messageParams.Add(":edit_timestamp", SqliteType.Integer);
messageParams.Add(":replied_to_id", SqliteType.Integer);
var deleteAttachmentsParams = deleteAttachmentsCmd.Parameters;
deleteAttachmentsParams.Add(":message_id", SqliteType.Integer);
var attachmentParams = attachmentCmd.Parameters;
attachmentParams.Add(":message_id", SqliteType.Integer);
attachmentParams.Add(":attachment_id", SqliteType.Integer);
attachmentParams.Add(":name", SqliteType.Text);
attachmentParams.Add(":type", SqliteType.Text);
attachmentParams.Add(":url", SqliteType.Text);
attachmentParams.Add(":size", SqliteType.Integer);
var deleteEmbedsParams = deleteEmbedsCmd.Parameters;
deleteEmbedsParams.Add(":message_id", SqliteType.Integer);
var embedParams = embedCmd.Parameters;
embedParams.Add(":message_id", SqliteType.Integer);
embedParams.Add(":json", SqliteType.Text);
var deleteReactionsParams = deleteReactionsCmd.Parameters;
deleteReactionsParams.Add(":message_id", SqliteType.Integer);
var reactionParams = reactionCmd.Parameters;
reactionParams.Add(":message_id", SqliteType.Integer);
reactionParams.Add(":emoji_id", SqliteType.Integer);
reactionParams.Add(":emoji_name", SqliteType.Text);
reactionParams.Add(":emoji_flags", SqliteType.Integer);
reactionParams.Add(":count", SqliteType.Integer);
foreach (var message in messages) {
object messageId = message.Id;
messageParams.Set(":message_id", messageId);
messageParams.Set(":sender_id", message.Sender);
messageParams.Set(":channel_id", message.Channel);
messageParams.Set(":text", message.Text);
messageParams.Set(":timestamp", message.Timestamp);
messageParams.Set(":edit_timestamp", message.EditTimestamp);
messageParams.Set(":replied_to_id", message.RepliedToId);
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
messageCmd.Set(":edit_timestamp", message.EditTimestamp);
messageCmd.Set(":replied_to_id", message.RepliedToId);
messageCmd.ExecuteNonQuery();
deleteAttachmentsParams.Set(":message_id", messageId);
deleteAttachmentsCmd.Set(":message_id", messageId);
deleteAttachmentsCmd.ExecuteNonQuery();
deleteEmbedsParams.Set(":message_id", messageId);
deleteEmbedsCmd.Set(":message_id", messageId);
deleteEmbedsCmd.ExecuteNonQuery();
deleteReactionsParams.Set(":message_id", messageId);
deleteReactionsCmd.Set(":message_id", messageId);
deleteReactionsCmd.ExecuteNonQuery();
if (!message.Attachments.IsEmpty) {
foreach (var attachment in message.Attachments) {
attachmentParams.Set(":message_id", messageId);
attachmentParams.Set(":attachment_id", attachment.Id);
attachmentParams.Set(":name", attachment.Name);
attachmentParams.Set(":type", attachment.Type);
attachmentParams.Set(":url", attachment.Url);
attachmentParams.Set(":size", attachment.Size);
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":url", attachment.Url);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.ExecuteNonQuery();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedParams.Set(":message_id", messageId);
embedParams.Set(":json", embed.Json);
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
embedCmd.ExecuteNonQuery();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionParams.Set(":message_id", messageId);
reactionParams.Set(":emoji_id", reaction.EmojiId);
reactionParams.Set(":emoji_name", reaction.EmojiName);
reactionParams.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionParams.Set(":count", reaction.Count);
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
reactionCmd.ExecuteNonQuery();
}
}
@ -284,16 +268,16 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong id = (ulong) reader.GetInt64(0);
ulong id = reader.GetUint64(0);
list.Add(new Message {
Id = id,
Sender = (ulong) reader.GetInt64(1),
Channel = (ulong) reader.GetInt64(2),
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : (ulong) reader.GetInt64(6),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty,
Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty,
Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty
@ -328,14 +312,14 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0);
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Attachment {
Id = (ulong) reader.GetInt64(1),
Id = reader.GetUint64(1),
Name = reader.GetString(2),
Type = reader.IsDBNull(3) ? null : reader.GetString(3),
Url = reader.GetString(4),
Size = (ulong) reader.GetInt64(5)
Size = reader.GetUint64(5)
});
}
@ -349,7 +333,7 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0);
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Embed {
Json = reader.GetString(1)
@ -366,10 +350,10 @@ namespace DHT.Server.Database.Sqlite {
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = (ulong) reader.GetInt64(0);
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Reaction {
EmojiId = reader.IsDBNull(1) ? null : (ulong) reader.GetInt64(1),
EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1),
EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2),
EmojiFlags = (EmojiFlags) reader.GetInt16(3),
Count = reader.GetInt32(4)
@ -380,23 +364,19 @@ namespace DHT.Server.Database.Sqlite {
}
private void UpdateServerStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM servers");
Statistics.TotalServers = cmd.ExecuteScalar() as long? ?? 0;
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
}
private void UpdateChannelStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM channels");
Statistics.TotalChannels = cmd.ExecuteScalar() as long? ?? 0;
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
}
private void UpdateUserStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM users");
Statistics.TotalUsers = cmd.ExecuteScalar() as long? ?? 0;
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
}
private void UpdateMessageStatistics() {
using var cmd = conn.Command("SELECT COUNT(*) FROM messages");
Statistics.TotalMessages = cmd.ExecuteScalar() as long? ?? 0L;
Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
}
}
}

View File

@ -10,31 +10,54 @@ namespace DHT.Server.Database.Sqlite {
return cmd;
}
public static SqliteCommand Insert(this SqliteConnection conn, string tableName, string[] columns) {
string columnNames = string.Join(',', columns);
string columnParams = string.Join(',', columns.Select(static c => ':' + c));
return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")");
public static object? SelectScalar(this SqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
}
public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, string[] columns) {
string columnNames = string.Join(',', columns);
string columnParams = string.Join(',', columns.Select(static c => ':' + c));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c + " = excluded." + c));
public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
return conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")" +
"ON CONFLICT (" + columns[0] + ")" +
"DO UPDATE SET " + columnUpdates);
var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")");
CreateParameters(cmd, columns);
return cmd;
}
public static void AddAndSet(this SqliteParameterCollection parameters, string key, object? value) {
parameters.AddWithValue(key, value ?? DBNull.Value);
public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name));
var cmd = conn.Command("INSERT INTO " + tableName + " (" + columnNames + ")" +
"VALUES (" + columnParams + ")" +
"ON CONFLICT (" + columns[0].Name + ")" +
"DO UPDATE SET " + columnUpdates);
CreateParameters(cmd, columns);
return cmd;
}
public static void Set(this SqliteParameterCollection parameters, string key, object? value) {
parameters[key].Value = value ?? DBNull.Value;
public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column });
return cmd;
}
private static void CreateParameters(SqliteCommand cmd, (string Name, SqliteType Type)[] columns) {
foreach (var (name, type) in columns) {
cmd.Parameters.Add(":" + name, type);
}
}
public static void Set(this SqliteCommand cmd, string key, object? value) {
cmd.Parameters[key].Value = value ?? DBNull.Value;
}
public static ulong GetUint64(this SqliteDataReader reader, int ordinal) {
return (ulong) reader.GetInt64(ordinal);
}
}
}