From d218e6ce29c5dc9fdf07fdf27d65433ab8504bfb Mon Sep 17 00:00:00 2001 From: zhanluxianshen Date: Mon, 21 Oct 2024 06:36:41 +0800 Subject: [PATCH] Reuse type InvalidModelNameErrMsg, unify the const parameters. Signed-off-by: zhanluxianshen --- server/routes.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/server/routes.go b/server/routes.go index 7aff9235..ebadf32e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -124,7 +124,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { switch { case os.IsNotExist(err): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == "invalid model name": + case err.Error() == errtypes.InvalidModelNameErrMsg: c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) default: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -134,6 +134,18 @@ func (s *Server) GenerateHandler(c *gin.Context) { // expire the runner if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { + model, err := GetModel(req.Model) + if err != nil { + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case err.Error() == errtypes.InvalidModelNameErrMsg: + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + } s.sched.expireRunner(model) c.JSON(http.StatusOK, api.GenerateResponse{ @@ -517,7 +529,7 @@ func (s *Server) PullHandler(c *gin.Context) { name := model.ParseName(cmp.Or(req.Model, req.Name)) if !name.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) return } @@ -751,7 +763,7 @@ func (s *Server) ShowHandler(c *gin.Context) { switch { case os.IsNotExist(err): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == "invalid model name": + case err.Error() == errtypes.InvalidModelNameErrMsg: c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) default: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -788,7 +800,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { n := model.ParseName(req.Model) if !n.IsValid() { - return nil, errors.New("invalid model name") + return nil, errors.New(errtypes.InvalidModelNameErrMsg) } manifest, err := ParseNamedManifest(n) @@ -1374,7 +1386,7 @@ func (s *Server) ChatHandler(c *gin.Context) { switch { case os.IsNotExist(err): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == "invalid model name": + case err.Error() == errtypes.InvalidModelNameErrMsg: c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) default: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})