Add SQLite connection pooling and fix not releasing file lock after closing database

Closes #167
This commit is contained in:
chylex 2022-03-13 12:59:05 +01:00
parent 1bddde7ccd
commit ab7f5d0a41
No known key found for this signature in database
GPG Key ID: 4DE42C8F19A80548
5 changed files with 180 additions and 42 deletions

View File

@ -1,8 +1,8 @@
using System;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite {
sealed class Schema {
@ -10,20 +10,14 @@ namespace DHT.Server.Database.Sqlite {
private static readonly Log Log = Log.ForType<Schema>();
private readonly SqliteConnection conn;
private readonly ISqliteConnection conn;
public Schema(SqliteConnection conn) {
public Schema(ISqliteConnection conn) {
this.conn = conn;
}
private SqliteCommand Sql(string sql) {
var cmd = conn.CreateCommand();
cmd.CommandText = sql;
return cmd;
}
private void Execute(string sql) {
Sql(sql).ExecuteNonQuery();
conn.Command(sql).ExecuteNonQuery();
}
public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) {

View File

@ -5,46 +5,57 @@ using System.Text;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Collections;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite {
public sealed class SqliteDatabaseFile : IDatabaseFile {
private const int DefaultPoolSize = 5;
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, Func<Task<bool>> checkCanUpgradeSchemas) {
string connectionString = new SqliteConnectionStringBuilder {
var connectionString = new SqliteConnectionStringBuilder {
DataSource = path,
Mode = SqliteOpenMode.ReadWriteCreate
}.ToString();
Mode = SqliteOpenMode.ReadWriteCreate,
};
var conn = new SqliteConnection(connectionString);
conn.Open();
var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize);
return await new Schema(conn).Setup(checkCanUpgradeSchemas) ? new SqliteDatabaseFile(path, conn) : null;
using (var conn = pool.Take()) {
if (!await new Schema(conn).Setup(checkCanUpgradeSchemas)) {
return null;
}
}
return new SqliteDatabaseFile(path, pool);
}
public string Path { get; }
public DatabaseStatistics Statistics { get; }
private readonly Log log;
private readonly SqliteConnection conn;
private readonly SqliteConnectionPool pool;
private SqliteDatabaseFile(string path, SqliteConnection conn) {
private SqliteDatabaseFile(string path, SqliteConnectionPool pool) {
this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path));
this.conn = conn;
this.pool = pool;
this.Path = path;
this.Statistics = new DatabaseStatistics();
UpdateServerStatistics();
UpdateChannelStatistics();
UpdateUserStatistics();
UpdateMessageStatistics();
using var conn = pool.Take();
UpdateServerStatistics(conn);
UpdateChannelStatistics(conn);
UpdateUserStatistics(conn);
UpdateMessageStatistics(conn);
}
public void Dispose() {
conn.Dispose();
pool.Dispose();
}
public void AddServer(Data.Server server) {
using var conn = pool.Take();
using var cmd = conn.Upsert("servers", new[] {
("id", SqliteType.Integer),
("name", SqliteType.Text),
@ -55,13 +66,14 @@ namespace DHT.Server.Database.Sqlite {
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery();
UpdateServerStatistics();
UpdateServerStatistics(conn);
}
public List<Data.Server> GetAllServers() {
var perf = log.Start();
var list = new List<Data.Server>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, type FROM servers");
using var reader = cmd.ExecuteReader();
@ -78,6 +90,7 @@ namespace DHT.Server.Database.Sqlite {
}
public void AddChannel(Channel channel) {
using var conn = pool.Take();
using var cmd = conn.Upsert("channels", new[] {
("id", SqliteType.Integer),
("server", SqliteType.Integer),
@ -96,12 +109,13 @@ namespace DHT.Server.Database.Sqlite {
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery();
UpdateChannelStatistics();
UpdateChannelStatistics(conn);
}
public List<Channel> GetAllChannels() {
var list = new List<Channel>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
using var reader = cmd.ExecuteReader();
@ -121,6 +135,7 @@ namespace DHT.Server.Database.Sqlite {
}
public void AddUsers(User[] users) {
using var conn = pool.Take();
using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] {
("id", SqliteType.Integer),
@ -138,13 +153,14 @@ namespace DHT.Server.Database.Sqlite {
}
tx.Commit();
UpdateUserStatistics();
UpdateUserStatistics(conn);
}
public List<User> GetAllUsers() {
var perf = log.Start();
var list = new List<User>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
using var reader = cmd.ExecuteReader();
@ -162,7 +178,7 @@ namespace DHT.Server.Database.Sqlite {
}
public void AddMessages(Message[] messages) {
static SqliteCommand DeleteByMessageId(SqliteConnection conn, string tableName) {
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
@ -171,6 +187,7 @@ namespace DHT.Server.Database.Sqlite {
cmd.ExecuteNonQuery();
}
using var conn = pool.Take();
using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] {
@ -282,10 +299,11 @@ namespace DHT.Server.Database.Sqlite {
}
tx.Commit();
UpdateMessageStatistics();
UpdateMessageStatistics(conn);
}
public int CountMessages(MessageFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
@ -300,6 +318,7 @@ namespace DHT.Server.Database.Sqlite {
var embeds = GetAllEmbeds();
var reactions = GetAllReactions();
using var conn = pool.Take();
using var cmd = conn.Command(@"
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
@ -342,16 +361,18 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
.Append("FROM messages")
.Append(whereClause);
using var conn = pool.Take();
using var cmd = conn.Command(build.ToString());
cmd.ExecuteNonQuery();
UpdateMessageStatistics();
UpdateMessageStatistics(conn);
perf.End();
}
private MultiDictionary<ulong, Attachment> GetAllAttachments() {
var dict = new MultiDictionary<ulong, Attachment>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, url, size FROM attachments");
using var reader = cmd.ExecuteReader();
@ -373,6 +394,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
private MultiDictionary<ulong, Embed> GetAllEmbeds() {
var dict = new MultiDictionary<ulong, Embed>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, json FROM embeds");
using var reader = cmd.ExecuteReader();
@ -390,6 +412,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
private MultiDictionary<ulong, Reaction> GetAllReactions() {
var dict = new MultiDictionary<ulong, Reaction>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions");
using var reader = cmd.ExecuteReader();
@ -407,19 +430,19 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
return dict;
}
private void UpdateServerStatistics() {
private void UpdateServerStatistics(ISqliteConnection conn) {
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
}
private void UpdateChannelStatistics() {
private void UpdateChannelStatistics(ISqliteConnection conn) {
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
}
private void UpdateUserStatistics() {
private void UpdateUserStatistics(ISqliteConnection conn) {
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
}
private void UpdateMessageStatistics() {
private void UpdateMessageStatistics(ISqliteConnection conn) {
Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
}
}

View File

@ -0,0 +1,8 @@
using System;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils {
interface ISqliteConnection : IDisposable {
SqliteConnection InnerConnection { get; }
}
}

View File

@ -0,0 +1,109 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils {
sealed class SqliteConnectionPool : IDisposable {
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString();
}
private readonly object monitor = new ();
private volatile bool isDisposed;
private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>());
private readonly List<PooledConnection> used;
public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString);
conn.Open();
free.Add(new PooledConnection(this, conn));
}
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
if (isDisposed) {
throw new ObjectDisposedException(nameof(SqliteConnectionPool));
}
}
public ISqliteConnection Take() {
PooledConnection? conn = null;
while (conn == null) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out conn, TimeSpan.FromMilliseconds(100))) {
used.Add(conn);
break;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Thread.CurrentThread.ManagedThreadId + " is starving for connections.");
}
}
}
return conn;
}
private void Return(PooledConnection conn) {
ThrowIfDisposed();
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
}
public void Dispose() {
if (isDisposed) {
return;
}
isDisposed = true;
lock (monitor) {
while (free.TryTake(out var conn)) {
Close(conn.InnerConnection);
}
foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
}
}
private static void Close(SqliteConnection conn) {
conn.Close();
conn.Dispose();
}
private sealed class PooledConnection : ISqliteConnection {
public SqliteConnection InnerConnection { get; }
private readonly SqliteConnectionPool pool;
public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) {
this.pool = pool;
this.InnerConnection = conn;
}
void IDisposable.Dispose() {
pool.Return(this);
}
}
}
}

View File

@ -2,20 +2,24 @@ using System;
using System.Linq;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite {
static class SqliteUtils {
public static SqliteCommand Command(this SqliteConnection conn, string sql) {
var cmd = conn.CreateCommand();
namespace DHT.Server.Database.Sqlite.Utils {
static class SqliteExtensions {
public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql;
return cmd;
}
public static object? SelectScalar(this SqliteConnection conn, string sql) {
public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransaction();
}
public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
}
public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
@ -26,7 +30,7 @@ namespace DHT.Server.Database.Sqlite {
return cmd;
}
public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
public static SqliteCommand Upsert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name));
@ -40,7 +44,7 @@ namespace DHT.Server.Database.Sqlite {
return cmd;
}
public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column });
return cmd;