diff --git a/api/types.go b/api/types.go
index e5291a02..e67e3f9c 100644
--- a/api/types.go
+++ b/api/types.go
@@ -80,6 +80,9 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	// JsonSchema is an optional json schema to use for this request.
+	JsonSchema string `json:"json_schema,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +108,9 @@ type ChatRequest struct {
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	// JsonSchema is an optional json schema to use for this request.
+	JsonSchema string `json:"json_schema,omitempty"`
 }
 
 type Tools []Tool
diff --git a/llama/common.h b/llama/common.h
index a4a9e1ff..f9d9bcf6 100644
--- a/llama/common.h
+++ b/llama/common.h
@@ -160,6 +160,7 @@
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
+    std::string json_schema; // optional JSON schema to constrain sampling
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
diff --git a/llama/llama.go b/llama/llama.go
index dbb02768..115913f4 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -623,6 +623,7 @@ type SamplingParams struct {
 	PenalizeNl bool
 	Seed       uint32
 	Grammar    string
+	JsonSchema string
 }
 
 func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
@@ -645,8 +646,12 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext,
 	grammar := C.CString(params.Grammar)
 	defer C.free(unsafe.Pointer(grammar))
-
 	cparams.grammar = grammar
+
+	jsonSchema := C.CString(params.JsonSchema)
+	defer C.free(unsafe.Pointer(jsonSchema))
+	cparams.json_schema = jsonSchema
+
 	context := &SamplingContext{c: C.gpt_sampler_cinit(model.c, &cparams)}
 	if context.c == nil {
 		return nil, errors.New("unable to create sampling context")
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index cff7d148..565e195d 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -534,6 +534,7 @@ type CompletionRequest struct {
 	Prompt      string      `json:"prompt"`
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
+	JsonSchema  string      `json:"json_schema"`
 	CachePrompt bool        `json:"cache_prompt"`
 
 	Options
@@ -596,6 +597,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.PenalizeNl = req.PenalizeNewline
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
+	samplingParams.JsonSchema = req.JsonSchema
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
diff --git a/llama/sampling.cpp b/llama/sampling.cpp
index d993dc2b..4a25cd78 100644
--- a/llama/sampling.cpp
+++ b/llama/sampling.cpp
@@ -25,11 +25,14 @@
  */
 
 #include "sampling.h"
-
 #include "common.h"
+#include "json.hpp"
+
+#include "json-schema-to-grammar.h"
 
 #include <cmath>
 #include <unordered_map>
+#include
 
 // the ring buffer works similarly to std::deque, but with a fixed capacity
 // TODO: deduplicate with llama-impl.h
@@ -172,7 +175,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr   = */ nullptr,
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
         /* .cur_p  = */ {},
@@ -246,6 +249,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
+    if (params.json_schema != "") {
+        nlohmann::ordered_json jsonSchema = nlohmann::ordered_json::parse(params.json_schema);
+        result->grmr = llama_sampler_init_grammar(model, json_schema_to_grammar(jsonSchema).c_str(), "root");
+    } else {
+        result->grmr = llama_sampler_init_grammar(model, params.grammar.c_str(), "root");
+    }
 
     return result;
 }
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 3dd7edf4..380293c7 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -23,6 +23,7 @@ struct gpt_sampler *gpt_sampler_cinit(
         sparams.penalize_nl = params->penalize_nl;
         sparams.seed = params->seed;
         sparams.grammar = params->grammar;
+        sparams.json_schema = params->json_schema;
         return gpt_sampler_init(model, sparams);
     } catch (const std::exception & err) {
         return nullptr;
diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index ec919a48..649b836a 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -29,6 +29,7 @@ extern "C"
         bool penalize_nl;
         uint32_t seed;
         char *grammar;
+        char *json_schema;
     };
 
     struct gpt_sampler *gpt_sampler_cinit(
diff --git a/llm/server.go b/llm/server.go
index 96815826..ecfa575a 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -673,6 +673,7 @@ type CompletionRequest struct {
 	Format  string
 	Images  []ImageData
 	Options *api.Options
+	JsonSchema string
 }
 
 type CompletionResponse struct {
@@ -733,9 +734,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 
 	if req.Format == "json" {
 		request["grammar"] = jsonGrammar
-		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-			slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
-		}
+		request["json_schema"] = req.JsonSchema
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
diff --git a/openai/openai.go b/openai/openai.go
index 2bf9b9f9..fba4c634 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -61,8 +61,15 @@ type Usage struct {
 	TotalTokens      int `json:"total_tokens"`
 }
 
+type JsonSchema struct {
+	Name   string          `json:"name"`
+	Schema json.RawMessage `json:"schema"`
+	Strict bool            `json:"strict"`
+}
+
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string     `json:"type"`
+	JsonSchema JsonSchema `json:"json_schema"`
 }
 
 type EmbedRequest struct {
@@ -476,17 +483,35 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 	}
 
 	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
+	var jsonSchema string
+	if r.ResponseFormat != nil {
 		format = "json"
+		if r.ResponseFormat.Type == "json_object" {
+			if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
+				return nil, errors.New("schema must be specified when method is not 'json_schema'")
+			}
+		} else if r.ResponseFormat.Type == "json_schema" {
+			if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
+				return nil, errors.New("schema must be specified when method is 'json_schema'")
+			}
+			jsonSchemaBytes, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+			if err != nil {
+				return nil, errors.New("failed to marshal json_schema")
+			}
+			jsonSchema = string(jsonSchemaBytes)
+		} else {
+			return nil, errors.New("invalid response format type")
+		}
 	}
 
 	return &api.ChatRequest{
-		Model:    r.Model,
-		Messages: messages,
-		Format:   format,
-		Options:  options,
-		Stream:   &r.Stream,
-		Tools:    r.Tools,
+		Model:      r.Model,
+		Messages:   messages,
+		Format:     format,
+		JsonSchema: jsonSchema,
+		Options:    options,
+		Stream:     &r.Stream,
+		Tools:      r.Tools,
 	}, nil
 }
 
diff --git a/server/routes.go b/server/routes.go
index c5fd3293..c9d840ac 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -275,10 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:     prompt,
+			Images:     images,
+			Format:     req.Format,
+			Options:    opts,
+			JsonSchema: req.JsonSchema,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,
@@ -1460,10 +1461,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:     prompt,
+			Images:     images,
+			Format:     req.Format,
+			Options:    opts,
+			JsonSchema: req.JsonSchema,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,