diff --git a/db/db.go b/db/db.go index b3e793202..1023ce827 100644 --- a/db/db.go +++ b/db/db.go @@ -3,6 +3,7 @@ package db import ( "database/sql" "os" + "sync" "github.com/deluan/navidrome/conf" _ "github.com/deluan/navidrome/db/migrations" @@ -11,19 +12,35 @@ import ( "github.com/pressly/goose" ) -const driver = "sqlite3" +var ( + once sync.Once + Driver = "sqlite3" + Path string +) -func EnsureDB() { - db, err := sql.Open(driver, conf.Server.DbPath) +func Init() { + once.Do(func() { + Path = conf.Server.DbPath + if Path == ":memory:" { + Path = "file::memory:?cache=shared" + conf.Server.DbPath = Path + } + log.Debug("Opening DataBase", "dbPath", Path, "driver", Driver) + }) +} + +func EnsureLatestVersion() { + Init() + db, err := sql.Open(Driver, Path) defer db.Close() if err != nil { log.Error("Failed to open DB", err) os.Exit(1) } - err = goose.SetDialect(driver) + err = goose.SetDialect(Driver) if err != nil { - log.Error("Invalid DB driver", "driver", driver, err) + log.Error("Invalid DB driver", "driver", Driver, err) os.Exit(1) } err = goose.Run("up", db, "./") diff --git a/main.go b/main.go index 88014b698..de3c96fa8 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,7 @@ func main() { } conf.Load() - db.EnsureDB() + db.EnsureLatestVersion() a := CreateServer(conf.Server.MusicFolder) a.MountRouter("/rest", CreateSubsonicAPIRouter()) diff --git a/persistence/persistence.go b/persistence/persistence.go index eb0f013e5..fc2ee1714 100644 --- a/persistence/persistence.go +++ b/persistence/persistence.go @@ -3,18 +3,16 @@ package persistence import ( "context" "reflect" - "strings" "sync" "github.com/astaxie/beego/orm" - "github.com/deluan/navidrome/conf" + "github.com/deluan/navidrome/db" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" ) var ( - once sync.Once - driver = "sqlite3" + once sync.Once ) type SQLStore struct { @@ -23,13 +21,7 @@ type SQLStore struct { func New() model.DataStore { once.Do(func() { - dbPath := conf.Server.DbPath - if dbPath == ":memory:" { - dbPath = "file::memory:?cache=shared" - } - log.Debug("Opening DataBase", "dbPath", dbPath, "driver", driver) - - err := initORM(dbPath) + err := orm.RegisterDataBase("default", db.Driver, db.Path) if err != nil { panic(err) } @@ -147,10 +139,3 @@ func (db *SQLStore) getOrmer() orm.Ormer { } return db.orm } - -func initORM(dbPath string) error { - if strings.Contains(dbPath, "postgres") { - driver = "postgres" - } - return orm.RegisterDataBase("default", driver, dbPath) -} diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go index c137e4ef9..9fa315cd3 100644 --- a/persistence/persistence_suite_test.go +++ b/persistence/persistence_suite_test.go @@ -21,10 +21,11 @@ func TestPersistence(t *testing.T) { tests.Init(t, true) //os.Remove("./test-123.db") - //conf.Server.DbPath = "./test-123.db" - conf.Server.DbPath = "file::memory:?cache=shared" + //conf.Server.Path = "./test-123.db" + conf.Server.DbPath = ":memory:" + db.Init() New() - db.EnsureDB() + db.EnsureLatestVersion() log.SetLevel(log.LevelCritical) RegisterFailHandler(Fail) RunSpecs(t, "Persistence Suite")