diff --git a/api/types.go b/api/types.go index 609c4a8a..702d070c 100644 --- a/api/types.go +++ b/api/types.go @@ -279,85 +279,20 @@ func (m *Metrics) Summary() { var ErrInvalidOpts = fmt.Errorf("invalid options") func (opts *Options) FromMap(m map[string]interface{}) error { - valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct - typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct + data, err := json.Marshal(m) + if err != nil { + return err + } - // build map of json struct tags to their types - jsonOpts := make(map[string]reflect.StructField) - for _, field := range reflect.VisibleFields(typeOpts) { - jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] - if jsonTag != "" { - jsonOpts[jsonTag] = field + err = json.Unmarshal(data, opts) + if err != nil { + // Custom error handling + if jsonErr, ok := err.(*json.UnmarshalTypeError); ok { + return fmt.Errorf("invalid type for option '%v': expected %v, got %v", jsonErr.Field, jsonErr.Type, jsonErr.Value) } + return err } - invalidOpts := []string{} - for key, val := range m { - if opt, ok := jsonOpts[key]; ok { - field := valueOpts.FieldByName(opt.Name) - if field.IsValid() && field.CanSet() { - if val == nil { - continue - } - - switch field.Kind() { - case reflect.Int: - switch t := val.(type) { - case int64: - field.SetInt(t) - case float64: - // when JSON unmarshals numbers, it uses float64, not int - field.SetInt(int64(t)) - default: - return fmt.Errorf("option %q must be of type integer", key) - } - case reflect.Bool: - val, ok := val.(bool) - if !ok { - return fmt.Errorf("option %q must be of type boolean", key) - } - field.SetBool(val) - case reflect.Float32: - // JSON unmarshals to float64 - val, ok := val.(float64) - if !ok { - return fmt.Errorf("option %q must be of type float32", key) - } - field.SetFloat(val) - case reflect.String: - val, ok := val.(string) - if !ok { - return fmt.Errorf("option %q must be of type string", key) - } - field.SetString(val) - case reflect.Slice: - // JSON unmarshals to []interface{}, not []string - val, ok := val.([]interface{}) - if !ok { - return fmt.Errorf("option %q must be of type array", key) - } - // convert []interface{} to []string - slice := make([]string, len(val)) - for i, item := range val { - str, ok := item.(string) - if !ok { - return fmt.Errorf("option %q must be of an array of strings", key) - } - slice[i] = str - } - field.Set(reflect.ValueOf(slice)) - default: - return fmt.Errorf("unknown type loading config params: %v", field.Kind()) - } - } - } else { - invalidOpts = append(invalidOpts, key) - } - } - - if len(invalidOpts) > 0 { - return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", ")) - } return nil } diff --git a/server/routes.go b/server/routes.go index 7d1f9dfb..3d13ebe3 100644 --- a/server/routes.go +++ b/server/routes.go @@ -178,11 +178,7 @@ func GenerateHandler(c *gin.Context) { opts, err := modelOptions(model, req.Options) if err != nil { - if errors.Is(err, api.ErrInvalidOpts) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -396,11 +392,7 @@ func EmbeddingHandler(c *gin.Context) { opts, err := modelOptions(model, req.Options) if err != nil { - if errors.Is(err, api.ErrInvalidOpts) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -1112,11 +1104,7 @@ func ChatHandler(c *gin.Context) { opts, err := modelOptions(model, req.Options) if err != nil { - if errors.Is(err, api.ErrInvalidOpts) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return }