Add SQLite connection pooling and fix not releasing file lock after closing database

Closes #167
This commit is contained in:
chylex 2022-03-13 12:59:05 +01:00
parent 1bddde7ccd
commit ab7f5d0a41
No known key found for this signature in database
GPG Key ID: 4DE42C8F19A80548
5 changed files with 180 additions and 42 deletions

View File

@ -1,8 +1,8 @@
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database.Exceptions; using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite { namespace DHT.Server.Database.Sqlite {
sealed class Schema { sealed class Schema {
@ -10,20 +10,14 @@ namespace DHT.Server.Database.Sqlite {
private static readonly Log Log = Log.ForType<Schema>(); private static readonly Log Log = Log.ForType<Schema>();
private readonly SqliteConnection conn; private readonly ISqliteConnection conn;
public Schema(SqliteConnection conn) { public Schema(ISqliteConnection conn) {
this.conn = conn; this.conn = conn;
} }
private SqliteCommand Sql(string sql) {
var cmd = conn.CreateCommand();
cmd.CommandText = sql;
return cmd;
}
private void Execute(string sql) { private void Execute(string sql) {
Sql(sql).ExecuteNonQuery(); conn.Command(sql).ExecuteNonQuery();
} }
public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) { public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) {

View File

@ -5,46 +5,57 @@ using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Collections; using DHT.Utils.Collections;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite { namespace DHT.Server.Database.Sqlite {
public sealed class SqliteDatabaseFile : IDatabaseFile { public sealed class SqliteDatabaseFile : IDatabaseFile {
private const int DefaultPoolSize = 5;
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, Func<Task<bool>> checkCanUpgradeSchemas) { public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, Func<Task<bool>> checkCanUpgradeSchemas) {
string connectionString = new SqliteConnectionStringBuilder { var connectionString = new SqliteConnectionStringBuilder {
DataSource = path, DataSource = path,
Mode = SqliteOpenMode.ReadWriteCreate Mode = SqliteOpenMode.ReadWriteCreate,
}.ToString(); };
var conn = new SqliteConnection(connectionString); var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize);
conn.Open();
return await new Schema(conn).Setup(checkCanUpgradeSchemas) ? new SqliteDatabaseFile(path, conn) : null; using (var conn = pool.Take()) {
if (!await new Schema(conn).Setup(checkCanUpgradeSchemas)) {
return null;
}
}
return new SqliteDatabaseFile(path, pool);
} }
public string Path { get; } public string Path { get; }
public DatabaseStatistics Statistics { get; } public DatabaseStatistics Statistics { get; }
private readonly Log log; private readonly Log log;
private readonly SqliteConnection conn; private readonly SqliteConnectionPool pool;
private SqliteDatabaseFile(string path, SqliteConnection conn) { private SqliteDatabaseFile(string path, SqliteConnectionPool pool) {
this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path)); this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path));
this.conn = conn; this.pool = pool;
this.Path = path; this.Path = path;
this.Statistics = new DatabaseStatistics(); this.Statistics = new DatabaseStatistics();
UpdateServerStatistics();
UpdateChannelStatistics(); using var conn = pool.Take();
UpdateUserStatistics(); UpdateServerStatistics(conn);
UpdateMessageStatistics(); UpdateChannelStatistics(conn);
UpdateUserStatistics(conn);
UpdateMessageStatistics(conn);
} }
public void Dispose() { public void Dispose() {
conn.Dispose(); pool.Dispose();
} }
public void AddServer(Data.Server server) { public void AddServer(Data.Server server) {
using var conn = pool.Take();
using var cmd = conn.Upsert("servers", new[] { using var cmd = conn.Upsert("servers", new[] {
("id", SqliteType.Integer), ("id", SqliteType.Integer),
("name", SqliteType.Text), ("name", SqliteType.Text),
@ -55,13 +66,14 @@ namespace DHT.Server.Database.Sqlite {
cmd.Set(":name", server.Name); cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type)); cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
UpdateServerStatistics(); UpdateServerStatistics(conn);
} }
public List<Data.Server> GetAllServers() { public List<Data.Server> GetAllServers() {
var perf = log.Start(); var perf = log.Start();
var list = new List<Data.Server>(); var list = new List<Data.Server>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, type FROM servers"); using var cmd = conn.Command("SELECT id, name, type FROM servers");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -78,6 +90,7 @@ namespace DHT.Server.Database.Sqlite {
} }
public void AddChannel(Channel channel) { public void AddChannel(Channel channel) {
using var conn = pool.Take();
using var cmd = conn.Upsert("channels", new[] { using var cmd = conn.Upsert("channels", new[] {
("id", SqliteType.Integer), ("id", SqliteType.Integer),
("server", SqliteType.Integer), ("server", SqliteType.Integer),
@ -96,12 +109,13 @@ namespace DHT.Server.Database.Sqlite {
cmd.Set(":topic", channel.Topic); cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw); cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
UpdateChannelStatistics(); UpdateChannelStatistics(conn);
} }
public List<Channel> GetAllChannels() { public List<Channel> GetAllChannels() {
var list = new List<Channel>(); var list = new List<Channel>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels"); using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -121,6 +135,7 @@ namespace DHT.Server.Database.Sqlite {
} }
public void AddUsers(User[] users) { public void AddUsers(User[] users) {
using var conn = pool.Take();
using var tx = conn.BeginTransaction(); using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] { using var cmd = conn.Upsert("users", new[] {
("id", SqliteType.Integer), ("id", SqliteType.Integer),
@ -138,13 +153,14 @@ namespace DHT.Server.Database.Sqlite {
} }
tx.Commit(); tx.Commit();
UpdateUserStatistics(); UpdateUserStatistics(conn);
} }
public List<User> GetAllUsers() { public List<User> GetAllUsers() {
var perf = log.Start(); var perf = log.Start();
var list = new List<User>(); var list = new List<User>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users"); using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -162,7 +178,7 @@ namespace DHT.Server.Database.Sqlite {
} }
public void AddMessages(Message[] messages) { public void AddMessages(Message[] messages) {
static SqliteCommand DeleteByMessageId(SqliteConnection conn, string tableName) { static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer)); return conn.Delete(tableName, ("message_id", SqliteType.Integer));
} }
@ -171,6 +187,7 @@ namespace DHT.Server.Database.Sqlite {
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
} }
using var conn = pool.Take();
using var tx = conn.BeginTransaction(); using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] { using var messageCmd = conn.Upsert("messages", new[] {
@ -282,10 +299,11 @@ namespace DHT.Server.Database.Sqlite {
} }
tx.Commit(); tx.Commit();
UpdateMessageStatistics(); UpdateMessageStatistics(conn);
} }
public int CountMessages(MessageFilter? filter = null) { public int CountMessages(MessageFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause()); using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -300,6 +318,7 @@ namespace DHT.Server.Database.Sqlite {
var embeds = GetAllEmbeds(); var embeds = GetAllEmbeds();
var reactions = GetAllReactions(); var reactions = GetAllReactions();
using var conn = pool.Take();
using var cmd = conn.Command(@" using var cmd = conn.Command(@"
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m FROM messages m
@ -342,16 +361,18 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
.Append("FROM messages") .Append("FROM messages")
.Append(whereClause); .Append(whereClause);
using var conn = pool.Take();
using var cmd = conn.Command(build.ToString()); using var cmd = conn.Command(build.ToString());
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
UpdateMessageStatistics(); UpdateMessageStatistics(conn);
perf.End(); perf.End();
} }
private MultiDictionary<ulong, Attachment> GetAllAttachments() { private MultiDictionary<ulong, Attachment> GetAllAttachments() {
var dict = new MultiDictionary<ulong, Attachment>(); var dict = new MultiDictionary<ulong, Attachment>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, url, size FROM attachments"); using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, url, size FROM attachments");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -373,6 +394,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
private MultiDictionary<ulong, Embed> GetAllEmbeds() { private MultiDictionary<ulong, Embed> GetAllEmbeds() {
var dict = new MultiDictionary<ulong, Embed>(); var dict = new MultiDictionary<ulong, Embed>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, json FROM embeds"); using var cmd = conn.Command("SELECT message_id, json FROM embeds");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -390,6 +412,7 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
private MultiDictionary<ulong, Reaction> GetAllReactions() { private MultiDictionary<ulong, Reaction> GetAllReactions() {
var dict = new MultiDictionary<ulong, Reaction>(); var dict = new MultiDictionary<ulong, Reaction>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions"); using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions");
using var reader = cmd.ExecuteReader(); using var reader = cmd.ExecuteReader();
@ -407,19 +430,19 @@ LEFT JOIN replied_to rt ON m.message_id = rt.message_id" + filter.GenerateWhereC
return dict; return dict;
} }
private void UpdateServerStatistics() { private void UpdateServerStatistics(ISqliteConnection conn) {
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0; Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
} }
private void UpdateChannelStatistics() { private void UpdateChannelStatistics(ISqliteConnection conn) {
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0; Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
} }
private void UpdateUserStatistics() { private void UpdateUserStatistics(ISqliteConnection conn) {
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0; Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
} }
private void UpdateMessageStatistics() { private void UpdateMessageStatistics(ISqliteConnection conn) {
Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L; Statistics.TotalMessages = conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
} }
} }

View File

@ -0,0 +1,8 @@
using System;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils {
interface ISqliteConnection : IDisposable {
SqliteConnection InnerConnection { get; }
}
}

View File

@ -0,0 +1,109 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils {
sealed class SqliteConnectionPool : IDisposable {
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString();
}
private readonly object monitor = new ();
private volatile bool isDisposed;
private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>());
private readonly List<PooledConnection> used;
public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString);
conn.Open();
free.Add(new PooledConnection(this, conn));
}
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
if (isDisposed) {
throw new ObjectDisposedException(nameof(SqliteConnectionPool));
}
}
public ISqliteConnection Take() {
PooledConnection? conn = null;
while (conn == null) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out conn, TimeSpan.FromMilliseconds(100))) {
used.Add(conn);
break;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Thread.CurrentThread.ManagedThreadId + " is starving for connections.");
}
}
}
return conn;
}
private void Return(PooledConnection conn) {
ThrowIfDisposed();
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
}
public void Dispose() {
if (isDisposed) {
return;
}
isDisposed = true;
lock (monitor) {
while (free.TryTake(out var conn)) {
Close(conn.InnerConnection);
}
foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
}
}
private static void Close(SqliteConnection conn) {
conn.Close();
conn.Dispose();
}
private sealed class PooledConnection : ISqliteConnection {
public SqliteConnection InnerConnection { get; }
private readonly SqliteConnectionPool pool;
public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) {
this.pool = pool;
this.InnerConnection = conn;
}
void IDisposable.Dispose() {
pool.Return(this);
}
}
}
}

View File

@ -2,20 +2,24 @@ using System;
using System.Linq; using System.Linq;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite { namespace DHT.Server.Database.Sqlite.Utils {
static class SqliteUtils { static class SqliteExtensions {
public static SqliteCommand Command(this SqliteConnection conn, string sql) { public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
var cmd = conn.CreateCommand(); var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql; cmd.CommandText = sql;
return cmd; return cmd;
} }
public static object? SelectScalar(this SqliteConnection conn, string sql) { public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransaction();
}
public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql); using var cmd = conn.Command(sql);
return cmd.ExecuteScalar(); return cmd.ExecuteScalar();
} }
public static SqliteCommand Insert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name)); string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
@ -26,7 +30,7 @@ namespace DHT.Server.Database.Sqlite {
return cmd; return cmd;
} }
public static SqliteCommand Upsert(this SqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { public static SqliteCommand Upsert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name)); string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name)); string columnUpdates = string.Join(',', columns.Skip(1).Select(static c => c.Name + " = excluded." + c.Name));
@ -40,7 +44,7 @@ namespace DHT.Server.Database.Sqlite {
return cmd; return cmd;
} }
public static SqliteCommand Delete(this SqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column }); CreateParameters(cmd, new [] { column });
return cmd; return cmd;