diff --git a/persistence/share_repository.go b/persistence/share_repository.go index 795257b1f..a47c2dafe 100644 --- a/persistence/share_repository.go +++ b/persistence/share_repository.go @@ -48,7 +48,7 @@ func (r *shareRepository) Put(s *model.Share) error { func (r *shareRepository) Update(entity interface{}, cols ...string) error { s := entity.(*model.Share) - _, err := r.put(s.ID, s) + _, err := r.put(s.ID, s, cols...) if err == model.ErrNotFound { return rest.ErrNotFound } diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index d25501d43..fd89f9fd0 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/navidrome/navidrome/utils" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/google/uuid" @@ -184,30 +186,31 @@ func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOpt return res.Count, err } -func (r sqlRepository) put(id string, m interface{}) (newId string, err error) { +func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) { values, _ := toSqlArgs(m) - // Remove created_at from args and save it for later, if needed for insert - createdAt := values["created_at"] - delete(values, "created_at") + // If there's an ID, try to update first if id != "" { - update := Update(r.tableName).Where(Eq{"id": id}).SetMap(values) + updateValues := map[string]interface{}{} + for k, v := range values { + if len(colsToUpdate) == 0 || utils.StringInSlice(k, colsToUpdate) { + updateValues[k] = v + } + } + delete(updateValues, "created_at") + update := Update(r.tableName).Where(Eq{"id": id}).SetMap(updateValues) count, err := r.executeSQL(update) if err != nil { return "", err } if count > 0 { - return id, err + return id, nil } } - // If does not have an id OR could not update (new record with predefined id) + // If does not have an ID OR the ID was not found (when it is a new record with predefined id) if id == "" { id = uuid.NewString() values["id"] = id } - // It is a insert. if there was a created_at, add it back to args - if createdAt != nil { - values["created_at"] = createdAt - } insert := Insert(r.tableName).SetMap(values) _, err = r.executeSQL(insert) return id, err