Fix not rolling back database transactions after unhandled exceptions

This commit is contained in:
chylex 2024-04-17 11:34:11 +02:00
parent daafdbbfaf
commit c8d8d95daa
No known key found for this signature in database
GPG Key ID: 4DE42C8F19A80548
10 changed files with 132 additions and 106 deletions

View File

@ -19,9 +19,9 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
}
public async Task Add(IReadOnlyList<Channel> channels) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
@ -43,7 +43,7 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
UpdateTotalCount();

View File

@ -70,8 +70,8 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
var (download, data) = item;
await using (var conn = await pool.Take()) {
var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using var metadataCmd = conn.Upsert("download_metadata", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
@ -103,7 +103,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
await upsertBlobCmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
UpdateTotalCount();

View File

@ -39,7 +39,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
@ -167,7 +167,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
}

View File

@ -19,9 +19,9 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
@ -35,7 +35,7 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
UpdateTotalCount();

View File

@ -23,7 +23,7 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
public async Task Add(IReadOnlyList<User> users) {
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
@ -46,7 +46,7 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
}
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
}

View File

@ -1,5 +1,4 @@
using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
@ -23,7 +22,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
}
private async Task NormalizeAttachmentUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
@ -39,7 +38,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
}
}
await using var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
@ -61,7 +60,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
private async Task NormalizeDownloadUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
@ -84,26 +83,23 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
}
await conn.ExecuteAsync("PRAGMA cache_size = -20000");
await conn.BeginTransactionAsync();
await reporter.SubWork("Deleting duplicates...", 0, 0);
DbTransaction tx;
await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
await deleteCmd.ExecuteNonQueryAsync();
}
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
await deleteCmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await conn.CommitTransactionAsync();
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
@ -115,11 +111,10 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await tx.CommitAsync();
await tx.DisposeAsync();
await conn.CommitTransactionAsync();
tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = (SqliteTransaction) tx;
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(updateCmd);
}
updateCmd.Set(":normalized_url", normalizedUrl);
@ -130,8 +125,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
await tx.DisposeAsync();
await conn.CommitTransactionAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000");
}

View File

@ -11,56 +11,55 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 6);
await SqliteSchema.CreateDownloadTables(conn);
await reporter.MainWork("Migrating download metadata...", 1, 6);
await conn.ExecuteAsync("INSERT INTO download_metadata (normalized_url, download_url, status, size) SELECT normalized_url, download_url, status, size FROM downloads");
await reporter.MainWork("Merging attachment metadata...", 2, 6);
await conn.ExecuteAsync("UPDATE download_metadata SET type = (SELECT type FROM attachments WHERE download_metadata.normalized_url = attachments.normalized_url)");
await reporter.MainWork("Migrating downloaded files...", 3, 6);
await MigrateDownloadBlobsToNewTable(conn, reporter);
await reporter.MainWork("Applying schema changes...", 4, 6);
await conn.ExecuteAsync("DROP TABLE downloads");
await reporter.MainWork("Discovering downloadable links...", 5, 6);
await DiscoverDownloadableLinks(conn, reporter);
}
private async Task MigrateDownloadBlobsToNewTable(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Listing downloaded files...", 0, 0);
var urlsToMigrate = await GetDownloadedFileUrls(conn);
int totalFiles = urlsToMigrate.Count;
int processedFiles = -1;
await reporter.SubWork("Processing downloaded files...", 0, totalFiles);
var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using (var insertCmd = conn.Command("INSERT INTO download_blobs (normalized_url, blob) SELECT normalized_url, blob FROM downloads WHERE normalized_url = :normalized_url"))
await using (var deleteCmd = conn.Command("DELETE FROM downloads WHERE normalized_url = :normalized_url")) {
insertCmd.Add(":normalized_url", SqliteType.Text);
deleteCmd.Add(":normalized_url", SqliteType.Text);
foreach (var url in urlsToMigrate) {
if (++processedFiles % 10 == 0) {
await reporter.SubWork("Processing downloaded files...", processedFiles, totalFiles);
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await tx.CommitAsync();
await tx.DisposeAsync();
tx = await conn.BeginTransactionAsync();
insertCmd.Transaction = (SqliteTransaction) tx;
deleteCmd.Transaction = (SqliteTransaction) tx;
await conn.CommitTransactionAsync();
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(insertCmd);
conn.AssignActiveTransaction(deleteCmd);
}
insertCmd.Set(":normalized_url", url);
await insertCmd.ExecuteNonQueryAsync();
deleteCmd.Set(":normalized_url", url);
await deleteCmd.ExecuteNonQueryAsync();
}
@ -68,8 +67,7 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade {
await reporter.SubWork("Processing downloaded files...", totalFiles, totalFiles);
await tx.CommitAsync();
await tx.DisposeAsync();
await conn.CommitTransactionAsync();
}
private async Task<List<string>> GetDownloadedFileUrls(ISqliteConnection conn) {
@ -110,46 +108,46 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade {
insertCmd.Set(":size", download.Size);
await insertCmd.ExecuteNonQueryAsync();
}
await using (var tx = await conn.BeginTransactionAsync()) {
await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)");
insertCmd.Add(":normalized_url", SqliteType.Text);
insertCmd.Add(":download_url", SqliteType.Text);
insertCmd.Add(":status", SqliteType.Integer);
insertCmd.Add(":type", SqliteType.Text);
insertCmd.Add(":size", SqliteType.Integer);
await reporter.SubWork("Processing embeds...", 1, 4);
await using (var embedCmd = conn.Command("SELECT json FROM embeds")) {
await using var reader = await embedCmd.ExecuteReaderAsync();
await conn.BeginTransactionAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0)));
}
await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)");
insertCmd.Add(":normalized_url", SqliteType.Text);
insertCmd.Add(":download_url", SqliteType.Text);
insertCmd.Add(":status", SqliteType.Integer);
insertCmd.Add(":type", SqliteType.Text);
insertCmd.Add(":size", SqliteType.Integer);
await reporter.SubWork("Processing embeds...", 1, 4);
await using (var embedCmd = conn.Command("SELECT json FROM embeds")) {
await using var reader = await embedCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0)));
}
await reporter.SubWork("Processing users...", 2, 4);
await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1)));
}
}
await reporter.SubWork("Processing reactions...", 3, 4);
await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1)));
}
}
await tx.CommitAsync();
}
await reporter.SubWork("Processing users...", 2, 4);
await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1)));
}
}
await reporter.SubWork("Processing reactions...", 3, 4);
await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1)));
}
}
await conn.CommitTransactionAsync();
}
}

View File

@ -1,8 +1,15 @@
using System;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IAsyncDisposable {
SqliteConnection InnerConnection { get; }
Task BeginTransactionAsync();
Task CommitTransactionAsync();
Task RollbackTransactionAsync();
void AssignActiveTransaction(SqliteCommand command);
}

View File

@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Collections;
@ -73,17 +74,48 @@ sealed class SqliteConnectionPool : IAsyncDisposable {
disposalTokenSource.Dispose();
}
private sealed class PooledConnection : ISqliteConnection {
public SqliteConnection InnerConnection { get; }
private sealed class PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) : ISqliteConnection {
public SqliteConnection InnerConnection { get; } = conn;
private readonly SqliteConnectionPool pool;
private DbTransaction? activeTransaction;
public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) {
this.pool = pool;
this.InnerConnection = conn;
public async Task BeginTransactionAsync() {
if (activeTransaction != null) {
throw new InvalidOperationException("A transaction is already active.");
}
activeTransaction = await InnerConnection.BeginTransactionAsync();
}
public async Task CommitTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to commit.");
}
await activeTransaction.CommitAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public async Task RollbackTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to rollback.");
}
await activeTransaction.RollbackAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public void AssignActiveTransaction(SqliteCommand command) {
command.Transaction = (SqliteTransaction?) activeTransaction;
}
public async ValueTask DisposeAsync() {
if (activeTransaction != null) {
await RollbackTransactionAsync();
}
await pool.Return(this);
}
}

View File

@ -1,5 +1,4 @@
using System;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@ -9,10 +8,6 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions {
public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransactionAsync();
}
public static SqliteCommand Command(this ISqliteConnection conn, [LanguageInjection("sql")] string sql) {
var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql;