diff --git a/app/Server/Database/Sqlite/Repositories/SqliteChannelRepository.cs b/app/Server/Database/Sqlite/Repositories/SqliteChannelRepository.cs index e0ab2b5..b308fa1 100644 --- a/app/Server/Database/Sqlite/Repositories/SqliteChannelRepository.cs +++ b/app/Server/Database/Sqlite/Repositories/SqliteChannelRepository.cs @@ -19,9 +19,9 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository } public async Task Add(IReadOnlyList channels) { - await using var conn = await pool.Take(); - - await using (var tx = await conn.BeginTransactionAsync()) { + await using (var conn = await pool.Take()) { + await conn.BeginTransactionAsync(); + await using var cmd = conn.Upsert("channels", [ ("id", SqliteType.Integer), ("server", SqliteType.Integer), @@ -43,7 +43,7 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository await cmd.ExecuteNonQueryAsync(); } - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); } UpdateTotalCount(); diff --git a/app/Server/Database/Sqlite/Repositories/SqliteDownloadRepository.cs b/app/Server/Database/Sqlite/Repositories/SqliteDownloadRepository.cs index e82f591..4495b58 100644 --- a/app/Server/Database/Sqlite/Repositories/SqliteDownloadRepository.cs +++ b/app/Server/Database/Sqlite/Repositories/SqliteDownloadRepository.cs @@ -70,8 +70,8 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor var (download, data) = item; await using (var conn = await pool.Take()) { - var tx = await conn.BeginTransactionAsync(); - + await conn.BeginTransactionAsync(); + await using var metadataCmd = conn.Upsert("download_metadata", [ ("normalized_url", SqliteType.Text), ("download_url", SqliteType.Text), @@ -103,7 +103,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor await upsertBlobCmd.ExecuteNonQueryAsync(); } - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); } UpdateTotalCount(); diff --git a/app/Server/Database/Sqlite/Repositories/SqliteMessageRepository.cs b/app/Server/Database/Sqlite/Repositories/SqliteMessageRepository.cs index 1b989f3..5eacf84 100644 --- a/app/Server/Database/Sqlite/Repositories/SqliteMessageRepository.cs +++ b/app/Server/Database/Sqlite/Repositories/SqliteMessageRepository.cs @@ -39,7 +39,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository } await using (var conn = await pool.Take()) { - await using var tx = await conn.BeginTransactionAsync(); + await conn.BeginTransactionAsync(); await using var messageCmd = conn.Upsert("messages", [ ("message_id", SqliteType.Integer), @@ -167,7 +167,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository } } - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); downloadCollector.OnCommitted(); } diff --git a/app/Server/Database/Sqlite/Repositories/SqliteServerRepository.cs b/app/Server/Database/Sqlite/Repositories/SqliteServerRepository.cs index a969efc..9932b7f 100644 --- a/app/Server/Database/Sqlite/Repositories/SqliteServerRepository.cs +++ b/app/Server/Database/Sqlite/Repositories/SqliteServerRepository.cs @@ -19,9 +19,9 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository { } public async Task Add(IReadOnlyList servers) { - await using var conn = await pool.Take(); - - await using (var tx = await conn.BeginTransactionAsync()) { + await using (var conn = await pool.Take()) { + await conn.BeginTransactionAsync(); + await using var cmd = conn.Upsert("servers", [ ("id", SqliteType.Integer), ("name", SqliteType.Text), @@ -35,7 +35,7 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository { await cmd.ExecuteNonQueryAsync(); } - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); } UpdateTotalCount(); diff --git a/app/Server/Database/Sqlite/Repositories/SqliteUserRepository.cs b/app/Server/Database/Sqlite/Repositories/SqliteUserRepository.cs index d935fa4..c3854dd 100644 --- a/app/Server/Database/Sqlite/Repositories/SqliteUserRepository.cs +++ b/app/Server/Database/Sqlite/Repositories/SqliteUserRepository.cs @@ -23,7 +23,7 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository { public async Task Add(IReadOnlyList users) { await using (var conn = await pool.Take()) { - await using var tx = await conn.BeginTransactionAsync(); + await conn.BeginTransactionAsync(); await using var cmd = conn.Upsert("users", [ ("id", SqliteType.Integer), @@ -46,7 +46,7 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository { } } - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); downloadCollector.OnCommitted(); } diff --git a/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo6.cs b/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo6.cs index 418ff17..a180ff6 100644 --- a/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo6.cs +++ b/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo6.cs @@ -1,5 +1,4 @@ using System.Collections.Generic; -using System.Data.Common; using System.Threading.Tasks; using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Download; @@ -23,7 +22,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url"); await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url"); } - + private async Task NormalizeAttachmentUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) { await reporter.SubWork("Preparing attachments...", 0, 0); @@ -39,7 +38,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { } } - await using var tx = await conn.BeginTransactionAsync(); + await conn.BeginTransactionAsync(); int totalUrls = normalizedUrls.Count; int processedUrls = -1; @@ -61,7 +60,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { await reporter.SubWork("Updating URLs...", totalUrls, totalUrls); - await tx.CommitAsync(); + await conn.CommitTransactionAsync(); } private async Task NormalizeDownloadUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) { @@ -84,26 +83,23 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { } await conn.ExecuteAsync("PRAGMA cache_size = -20000"); + await conn.BeginTransactionAsync(); + + await reporter.SubWork("Deleting duplicates...", 0, 0); - DbTransaction tx; - - await using (tx = await conn.BeginTransactionAsync()) { - await reporter.SubWork("Deleting duplicates...", 0, 0); - - await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) { - foreach (var duplicateUrl in duplicateUrlsToDelete) { - deleteCmd.Set(":url", duplicateUrl); - await deleteCmd.ExecuteNonQueryAsync(); - } + await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) { + foreach (var duplicateUrl in duplicateUrlsToDelete) { + deleteCmd.Set(":url", duplicateUrl); + await deleteCmd.ExecuteNonQueryAsync(); } - - await tx.CommitAsync(); } + await conn.CommitTransactionAsync(); + int totalUrls = normalizedUrlsToOriginalUrls.Count; int processedUrls = -1; - tx = await conn.BeginTransactionAsync(); + await conn.BeginTransactionAsync(); await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) { updateCmd.Add(":normalized_url", SqliteType.Text); @@ -115,11 +111,10 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { // Not proper way of dealing with transactions, but it avoids a long commit at the end. // Schema upgrades are already non-atomic anyways, so this doesn't make it worse. - await tx.CommitAsync(); - await tx.DisposeAsync(); + await conn.CommitTransactionAsync(); - tx = await conn.BeginTransactionAsync(); - updateCmd.Transaction = (SqliteTransaction) tx; + await conn.BeginTransactionAsync(); + conn.AssignActiveTransaction(updateCmd); } updateCmd.Set(":normalized_url", normalizedUrl); @@ -130,8 +125,7 @@ sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade { await reporter.SubWork("Updating URLs...", totalUrls, totalUrls); - await tx.CommitAsync(); - await tx.DisposeAsync(); + await conn.CommitTransactionAsync(); await conn.ExecuteAsync("PRAGMA cache_size = -2000"); } diff --git a/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo7.cs b/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo7.cs index c4d81ed..ffc804b 100644 --- a/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo7.cs +++ b/app/Server/Database/Sqlite/Schema/SqliteSchemaUpgradeTo7.cs @@ -11,56 +11,55 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade { async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) { await reporter.MainWork("Applying schema changes...", 0, 6); await SqliteSchema.CreateDownloadTables(conn); - + await reporter.MainWork("Migrating download metadata...", 1, 6); await conn.ExecuteAsync("INSERT INTO download_metadata (normalized_url, download_url, status, size) SELECT normalized_url, download_url, status, size FROM downloads"); - + await reporter.MainWork("Merging attachment metadata...", 2, 6); await conn.ExecuteAsync("UPDATE download_metadata SET type = (SELECT type FROM attachments WHERE download_metadata.normalized_url = attachments.normalized_url)"); - + await reporter.MainWork("Migrating downloaded files...", 3, 6); await MigrateDownloadBlobsToNewTable(conn, reporter); - + await reporter.MainWork("Applying schema changes...", 4, 6); await conn.ExecuteAsync("DROP TABLE downloads"); - + await reporter.MainWork("Discovering downloadable links...", 5, 6); await DiscoverDownloadableLinks(conn, reporter); } private async Task MigrateDownloadBlobsToNewTable(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) { await reporter.SubWork("Listing downloaded files...", 0, 0); - + var urlsToMigrate = await GetDownloadedFileUrls(conn); int totalFiles = urlsToMigrate.Count; int processedFiles = -1; await reporter.SubWork("Processing downloaded files...", 0, totalFiles); - - var tx = await conn.BeginTransactionAsync(); + + await conn.BeginTransactionAsync(); await using (var insertCmd = conn.Command("INSERT INTO download_blobs (normalized_url, blob) SELECT normalized_url, blob FROM downloads WHERE normalized_url = :normalized_url")) await using (var deleteCmd = conn.Command("DELETE FROM downloads WHERE normalized_url = :normalized_url")) { insertCmd.Add(":normalized_url", SqliteType.Text); deleteCmd.Add(":normalized_url", SqliteType.Text); - + foreach (var url in urlsToMigrate) { if (++processedFiles % 10 == 0) { await reporter.SubWork("Processing downloaded files...", processedFiles, totalFiles); - + // Not proper way of dealing with transactions, but it avoids a long commit at the end. // Schema upgrades are already non-atomic anyways, so this doesn't make it worse. - await tx.CommitAsync(); - await tx.DisposeAsync(); - - tx = await conn.BeginTransactionAsync(); - insertCmd.Transaction = (SqliteTransaction) tx; - deleteCmd.Transaction = (SqliteTransaction) tx; + await conn.CommitTransactionAsync(); + + await conn.BeginTransactionAsync(); + conn.AssignActiveTransaction(insertCmd); + conn.AssignActiveTransaction(deleteCmd); } - + insertCmd.Set(":normalized_url", url); await insertCmd.ExecuteNonQueryAsync(); - + deleteCmd.Set(":normalized_url", url); await deleteCmd.ExecuteNonQueryAsync(); } @@ -68,8 +67,7 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade { await reporter.SubWork("Processing downloaded files...", totalFiles, totalFiles); - await tx.CommitAsync(); - await tx.DisposeAsync(); + await conn.CommitTransactionAsync(); } private async Task> GetDownloadedFileUrls(ISqliteConnection conn) { @@ -110,46 +108,46 @@ sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade { insertCmd.Set(":size", download.Size); await insertCmd.ExecuteNonQueryAsync(); } - - await using (var tx = await conn.BeginTransactionAsync()) { - await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)"); - insertCmd.Add(":normalized_url", SqliteType.Text); - insertCmd.Add(":download_url", SqliteType.Text); - insertCmd.Add(":status", SqliteType.Integer); - insertCmd.Add(":type", SqliteType.Text); - insertCmd.Add(":size", SqliteType.Integer); - await reporter.SubWork("Processing embeds...", 1, 4); - - await using (var embedCmd = conn.Command("SELECT json FROM embeds")) { - await using var reader = await embedCmd.ExecuteReaderAsync(); + await conn.BeginTransactionAsync(); - while (await reader.ReadAsync()) { - await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0))); - } + await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)"); + insertCmd.Add(":normalized_url", SqliteType.Text); + insertCmd.Add(":download_url", SqliteType.Text); + insertCmd.Add(":status", SqliteType.Integer); + insertCmd.Add(":type", SqliteType.Text); + insertCmd.Add(":size", SqliteType.Integer); + + await reporter.SubWork("Processing embeds...", 1, 4); + + await using (var embedCmd = conn.Command("SELECT json FROM embeds")) { + await using var reader = await embedCmd.ExecuteReaderAsync(); + + while (await reader.ReadAsync()) { + await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0))); } - - await reporter.SubWork("Processing users...", 2, 4); - - await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) { - await using var reader = await avatarCmd.ExecuteReaderAsync(); - - while (await reader.ReadAsync()) { - await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1))); - } - } - - await reporter.SubWork("Processing reactions...", 3, 4); - - await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) { - await using var reader = await avatarCmd.ExecuteReaderAsync(); - - while (await reader.ReadAsync()) { - await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1))); - } - } - - await tx.CommitAsync(); } + + await reporter.SubWork("Processing users...", 2, 4); + + await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) { + await using var reader = await avatarCmd.ExecuteReaderAsync(); + + while (await reader.ReadAsync()) { + await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1))); + } + } + + await reporter.SubWork("Processing reactions...", 3, 4); + + await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) { + await using var reader = await avatarCmd.ExecuteReaderAsync(); + + while (await reader.ReadAsync()) { + await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1))); + } + } + + await conn.CommitTransactionAsync(); } } diff --git a/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs b/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs index 531a9d7..39b3853 100644 --- a/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs +++ b/app/Server/Database/Sqlite/Utils/ISqliteConnection.cs @@ -1,8 +1,15 @@ using System; +using System.Threading.Tasks; using Microsoft.Data.Sqlite; namespace DHT.Server.Database.Sqlite.Utils; interface ISqliteConnection : IAsyncDisposable { SqliteConnection InnerConnection { get; } + + Task BeginTransactionAsync(); + Task CommitTransactionAsync(); + Task RollbackTransactionAsync(); + + void AssignActiveTransaction(SqliteCommand command); } diff --git a/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs b/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs index 095e8ff..cce8ba7 100644 --- a/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs +++ b/app/Server/Database/Sqlite/Utils/SqliteConnectionPool.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Data.Common; using System.Threading; using System.Threading.Tasks; using DHT.Utils.Collections; @@ -73,17 +74,48 @@ sealed class SqliteConnectionPool : IAsyncDisposable { disposalTokenSource.Dispose(); } - private sealed class PooledConnection : ISqliteConnection { - public SqliteConnection InnerConnection { get; } + private sealed class PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) : ISqliteConnection { + public SqliteConnection InnerConnection { get; } = conn; - private readonly SqliteConnectionPool pool; + private DbTransaction? activeTransaction; - public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) { - this.pool = pool; - this.InnerConnection = conn; + public async Task BeginTransactionAsync() { + if (activeTransaction != null) { + throw new InvalidOperationException("A transaction is already active."); + } + + activeTransaction = await InnerConnection.BeginTransactionAsync(); + } + + public async Task CommitTransactionAsync() { + if (activeTransaction == null) { + throw new InvalidOperationException("No active transaction to commit."); + } + + await activeTransaction.CommitAsync(); + await activeTransaction.DisposeAsync(); + activeTransaction = null; + } + + public async Task RollbackTransactionAsync() { + if (activeTransaction == null) { + throw new InvalidOperationException("No active transaction to rollback."); + } + + await activeTransaction.RollbackAsync(); + await activeTransaction.DisposeAsync(); + activeTransaction = null; + } + + public void AssignActiveTransaction(SqliteCommand command) { + command.Transaction = (SqliteTransaction?) activeTransaction; } public async ValueTask DisposeAsync() { + if (activeTransaction != null) { + await RollbackTransactionAsync(); + } + await pool.Return(this); } } diff --git a/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs b/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs index a018a33..a704e4d 100644 --- a/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs +++ b/app/Server/Database/Sqlite/Utils/SqliteExtensions.cs @@ -1,5 +1,4 @@ using System; -using System.Data.Common; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -9,10 +8,6 @@ using Microsoft.Data.Sqlite; namespace DHT.Server.Database.Sqlite.Utils; static class SqliteExtensions { - public static ValueTask BeginTransactionAsync(this ISqliteConnection conn) { - return conn.InnerConnection.BeginTransactionAsync(); - } - public static SqliteCommand Command(this ISqliteConnection conn, [LanguageInjection("sql")] string sql) { var cmd = conn.InnerConnection.CreateCommand(); cmd.CommandText = sql;