diff --git a/app/Server/Database/Sqlite/Schema.cs b/app/Server/Database/Sqlite/Schema.cs index 9deee74..b9be12f 100644 --- a/app/Server/Database/Sqlite/Schema.cs +++ b/app/Server/Database/Sqlite/Schema.cs @@ -1,8 +1,8 @@ using System; using System.Threading.Tasks; using DHT.Server.Database.Exceptions; +using DHT.Server.Database.Sqlite.Utils; using DHT.Utils.Logging; -using Microsoft.Data.Sqlite; namespace DHT.Server.Database.Sqlite { sealed class Schema { @@ -10,20 +10,14 @@ namespace DHT.Server.Database.Sqlite { private static readonly Log Log = Log.ForType(); - private readonly SqliteConnection conn; + private readonly ISqliteConnection conn; - public Schema(SqliteConnection conn) { + public Schema(ISqliteConnection conn) { this.conn = conn; } - private SqliteCommand Sql(string sql) { - var cmd = conn.CreateCommand(); - cmd.CommandText = sql; - return cmd; - } - private void Execute(string sql) { - Sql(sql).ExecuteNonQuery(); + conn.Command(sql).ExecuteNonQuery(); } public async Task Setup(Func> checkCanUpgradeSchemas) { diff --git a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs index 0a05c42..9c11bcf 100644 --- a/app/Server/Database/Sqlite/SqliteDatabaseFile.cs +++ b/app/Server/Database/Sqlite/SqliteDatabaseFile.cs @@ -5,46 +5,57 @@ using System.Text; using System.Threading.Tasks; using DHT.Server.Data; using DHT.Server.Data.Filters; +using DHT.Server.Database.Sqlite.Utils; using DHT.Utils.Collections; using DHT.Utils.Logging; using Microsoft.Data.Sqlite; namespace DHT.Server.Database.Sqlite { public sealed class SqliteDatabaseFile : IDatabaseFile { + private const int DefaultPoolSize = 5; + public static async Task OpenOrCreate(string path, Func> checkCanUpgradeSchemas) { - string connectionString = new SqliteConnectionStringBuilder { + var connectionString = new SqliteConnectionStringBuilder { DataSource = path, - Mode = SqliteOpenMode.ReadWriteCreate - }.ToString(); + Mode = SqliteOpenMode.ReadWriteCreate, + }; - var conn = new SqliteConnection(connectionString); - conn.Open(); + var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize); - return await new Schema(conn).Setup(checkCanUpgradeSchemas) ? new SqliteDatabaseFile(path, conn) : null; + using (var conn = pool.Take()) { + if (!await new Schema(conn).Setup(checkCanUpgradeSchemas)) { + return null; + } + } + + return new SqliteDatabaseFile(path, pool); } public string Path { get; } public DatabaseStatistics Statistics { get; } private readonly Log log; - private readonly SqliteConnection conn; + private readonly SqliteConnectionPool pool; - private SqliteDatabaseFile(string path, SqliteConnection conn) { + private SqliteDatabaseFile(string path, SqliteConnectionPool pool) { this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path)); - this.conn = conn; + this.pool = pool; this.Path = path; this.Statistics = new DatabaseStatistics(); - UpdateServerStatistics(); - UpdateChannelStatistics(); - UpdateUserStatistics(); - UpdateMessageStatistics(); + + using var conn = pool.Take(); + UpdateServerStatistics(conn); + UpdateChannelStatistics(conn); + UpdateUserStatistics(conn); + UpdateMessageStatistics(conn); } public void Dispose() { - conn.Dispose(); + pool.Dispose(); } public void AddServer(Data.Server server) { + using var conn = pool.Take(); using var cmd = conn.Upsert("servers", new[] { ("id", SqliteType.Integer), ("name", SqliteType.Text), @@ -55,13 +66,14 @@ namespace DHT.Server.Database.Sqlite { cmd.Set(":name", server.Name); cmd.Set(":type", ServerTypes.ToString(server.Type)); cmd.ExecuteNonQuery(); - UpdateServerStatistics(); + UpdateServerStatistics(conn); } public List GetAllServers() { var perf = log.Start(); var list = new List(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT id, name, type FROM servers"); using var reader = cmd.ExecuteReader(); @@ -78,6 +90,7 @@ namespace DHT.Server.Database.Sqlite { } public void AddChannel(Channel channel) { + using var conn = pool.Take(); using var cmd = conn.Upsert("channels", new[] { ("id", SqliteType.Integer), ("server", SqliteType.Integer), @@ -96,12 +109,13 @@ namespace DHT.Server.Database.Sqlite { cmd.Set(":topic", channel.Topic); cmd.Set(":nsfw", channel.Nsfw); cmd.ExecuteNonQuery(); - UpdateChannelStatistics(); + UpdateChannelStatistics(conn); } public List GetAllChannels() { var list = new List(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels"); using var reader = cmd.ExecuteReader(); @@ -121,6 +135,7 @@ namespace DHT.Server.Database.Sqlite { } public void AddUsers(User[] users) { + using var conn = pool.Take(); using var tx = conn.BeginTransaction(); using var cmd = conn.Upsert("users", new[] { ("id", SqliteType.Integer), @@ -138,13 +153,14 @@ namespace DHT.Server.Database.Sqlite { } tx.Commit(); - UpdateUserStatistics(); + UpdateUserStatistics(conn); } public List GetAllUsers() { var perf = log.Start(); var list = new List(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users"); using var reader = cmd.ExecuteReader(); @@ -162,7 +178,7 @@ namespace DHT.Server.Database.Sqlite { } public void AddMessages(Message[] messages) { - static SqliteCommand DeleteByMessageId(SqliteConnection conn, string tableName) { + static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) { return conn.Delete(tableName, ("message_id", SqliteType.Integer)); } @@ -171,6 +187,7 @@ namespace DHT.Server.Database.Sqlite { cmd.ExecuteNonQuery(); } + using var conn = pool.Take(); using var tx = conn.BeginTransaction(); using var messageCmd = conn.Upsert("messages", new[] { @@ -282,10 +299,11 @@ namespace DHT.Server.Database.Sqlite { } tx.Commit(); - UpdateMessageStatistics(); + UpdateMessageStatistics(conn); } public int CountMessages(MessageFilter? filter = null) { + using var conn = pool.Take(); using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause()); using var reader = cmd.ExecuteReader(); @@ -300,6 +318,7 @@ namespace DHT.Server.Database.Sqlite { var embeds = GetAllEmbeds(); var reactions = GetAllReactions(); + using var conn = pool.Take(); using var cmd = conn.Command(@" SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id FROM messages m @@ -342,16 +361,18 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC .Append("FROM messages") .Append(whereClause); + using var conn = pool.Take(); using var cmd = conn.Command(build.ToString()); cmd.ExecuteNonQuery(); - UpdateMessageStatistics(); + UpdateMessageStatistics(conn); perf.End(); } private MultiDictionary GetAllAttachments() { var dict = new MultiDictionary(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, url, size FROM attachments"); using var reader = cmd.ExecuteReader(); @@ -373,6 +394,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC private MultiDictionary GetAllEmbeds() { var dict = new MultiDictionary(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT message_id, json FROM embeds"); using var reader = cmd.ExecuteReader(); @@ -390,6 +412,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC private MultiDictionary GetAllReactions() { var dict = new MultiDictionary(); + using var conn = pool.Take(); using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions"); using var reader = cmd.ExecuteReader(); @@ -407,19 +430,19 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC return dict; } - private void UpdateServerStatistics() { + private void UpdateServerStatistics(ISqliteConnection conn) { Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0; } - private void UpdateChannelStatistics() { + private void UpdateChannelStatistics(ISqliteConnection conn) { Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0; } - private void UpdateUserStatistics() { + private void UpdateUserStatistics(ISqliteConnection conn) { Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0; } - private void UpdateMessageStatistics() { + private void UpdateMessageStatistics(ISqliteConnection conn) { Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L; } } diff --git a/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs b/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs new file mode 100644 index 0000000..0f31c1a --- /dev/null +++ b/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs @@ -0,0 +1,8 @@ +using System; +using Microsoft.Data.Sqlite; + +namespace DHT.Server.Database.Sqlite.Utils { + interface ISqliteConnection : IDisposable { + SqliteConnection InnerConnection { get; } + } +} diff --git a/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs b/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs new file mode 100644 index 0000000..4501123 --- /dev/null +++ b/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using DHT.Utils.Logging; +using Microsoft.Data.Sqlite; + +namespace DHT.Server.Database.Sqlite.Utils { + sealed class SqliteConnectionPool : IDisposable { + private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) { + connectionStringBuilder.Pooling = false; + return connectionStringBuilder.ToString(); + } + + private readonly object monitor = new (); + private volatile bool isDisposed; + + private readonly BlockingCollection free = new (new ConcurrentStack()); + private readonly List used; + + public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) { + var connectionString = GetConnectionString(connectionStringBuilder); + + for (int i = 0; i < poolSize; i++) { + var conn = new SqliteConnection(connectionString); + conn.Open(); + free.Add(new PooledConnection(this, conn)); + } + + used = new List(poolSize); + } + + private void ThrowIfDisposed() { + if (isDisposed) { + throw new ObjectDisposedException(nameof(SqliteConnectionPool)); + } + } + + public ISqliteConnection Take() { + PooledConnection? conn = null; + + while (conn == null) { + ThrowIfDisposed(); + lock (monitor) { + if (free.TryTake(out conn, TimeSpan.FromMilliseconds(100))) { + used.Add(conn); + break; + } + else { + Log.ForType().Warn("Thread " + Thread.CurrentThread.ManagedThreadId + " is starving for connections."); + } + } + } + + return conn; + } + + private void Return(PooledConnection conn) { + ThrowIfDisposed(); + + lock (monitor) { + if (used.Remove(conn)) { + free.Add(conn); + } + } + } + + public void Dispose() { + if (isDisposed) { + return; + } + + isDisposed = true; + + lock (monitor) { + while (free.TryTake(out var conn)) { + Close(conn.InnerConnection); + } + + foreach (var conn in used) { + Close(conn.InnerConnection); + } + + free.Dispose(); + used.Clear(); + } + } + + private static void Close(SqliteConnection conn) { + conn.Close(); + conn.Dispose(); + } + + private sealed class PooledConnection : ISqliteConnection { + public SqliteConnection InnerConnection { get; } + + private readonly SqliteConnectionPool pool; + + public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) { + this.pool = pool; + this.InnerConnection = conn; + } + + void IDisposable.Dispose() { + pool.Return(this); + } + } + } +} diff --git a/app/Server/Database/Sqlite/SqliteUtils.cs b/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs similarity index 69% rename from app/Server/Database/Sqlite/SqliteUtils.cs rename to app/Server/Database/Sqlite/Utils/SqliteExtensions.cs index 53e7c2c..c159753 100644 --- a/app/Server/Database/Sqlite/SqliteUtils.cs +++ b/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs @@ -2,20 +2,24 @@ using System; using System.Linq; using Microsoft.Data.Sqlite; -namespace DHT.Server.Database.Sqlite { - static class SqliteUtils { - public static SqliteCommand Command(this SqliteConnection conn, string sql) { - var cmd = conn.CreateCommand(); +namespace DHT.Server.Database.Sqlite.Utils { + static class SqliteExtensions { + public static SqliteCommand Command(this ISqliteConnection conn, string sql) { + var cmd = conn.InnerConnection.CreateCommand(); cmd.CommandText = sql; return cmd; } - public static object? SelectScalar(this SqliteConnection conn, string sql) { + public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) { + return conn.InnerConnection.BeginTransaction(); + } + + public static object? SelectScalar(this ISqliteConnection conn, string sql) { using var cmd = conn.Command(sql); return cmd.ExecuteScalar(); } - public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { + public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { string columnNames = string.Join(',', columns.Select(static c => c.Name)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); @@ -26,7 +30,7 @@ namespace DHT.Server.Database.Sqlite { return cmd; } - public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { + public static SqliteCommand Upsert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { string columnNames = string.Join(',', columns.Select(static c => c.Name)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name)); @@ -40,7 +44,7 @@ namespace DHT.Server.Database.Sqlite { return cmd; } - public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { + public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); CreateParameters(cmd, new [] { column }); return cmd;