Enable JSON Schema support

Hieu Nguyen 2024-11-09 17:21:13 +07:00
parent 3d25e7bf8c
commit dd25e5fbf5
10 changed files with 73 additions and 22 deletions


@@ -80,6 +80,9 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
// JsonSchema is an optional json schema to use for this request.
JsonSchema string `json:"json_schema,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +108,9 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
// JsonSchema is an optional json schema to use for this request.
JsonSchema string `json:"json_schema,omitempty"`
}
type Tools []Tool
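The new json_schema field sits alongside the existing request options in the public API. Below is a minimal sketch of a client call that uses it; it assumes the upstream github.com/ollama/ollama/api client package, a locally available model, and (per the llm/server.go hunk further down) that Format: "json" must still be set for the schema to be forwarded to the runner. The model name and schema are placeholders, not part of this commit.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api" // assumed import path for the client package
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative schema: constrain the reply to an object with a "colors" array.
	schema := `{
	  "type": "object",
	  "properties": {
	    "colors": {"type": "array", "items": {"type": "string"}}
	  },
	  "required": ["colors"]
	}`

	stream := false
	req := &api.GenerateRequest{
		Model:      "llama3.2", // placeholder model name
		Prompt:     "List three colors.",
		Format:     "json", // the schema is only forwarded when format is "json" (see llm/server.go below)
		JsonSchema: schema, // field added by this commit
		Stream:     &stream,
	}

	if err := client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}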

llama/common.h vendored

@@ -160,6 +160,7 @@ struct gpt_sampler_params {
};
std::string grammar; // optional BNF-like grammar to constrain sampling
std::string json_schema; // optional JSON schema to constrain sampling
std::vector<llama_logit_bias> logit_bias; // logit biases to apply


@@ -621,6 +621,7 @@ type SamplingParams struct {
PenalizeNl bool
Seed uint32
Grammar string
JsonSchema string
}
func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
@@ -643,8 +644,12 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext,
grammar := C.CString(params.Grammar)
defer C.free(unsafe.Pointer(grammar))
cparams.grammar = grammar
jsonSchema := C.CString(params.JsonSchema)
defer C.free(unsafe.Pointer(jsonSchema))
cparams.json_schema = jsonSchema
context := &SamplingContext{c: C.gpt_sampler_cinit(model.c, &cparams)}
if context.c == nil {
return nil, errors.New("unable to create sampling context")


@@ -552,6 +552,7 @@ type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
JsonSchema string `json:"json_schema"`
CachePrompt bool `json:"cache_prompt"`
Options
@@ -614,6 +615,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.PenalizeNl = req.PenalizeNewline
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
samplingParams.JsonSchema = req.JsonSchema
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
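For illustration, this is roughly the completion payload the runner would now decode, built from the JSON tags in the hunk above. The values are placeholders and the snippet only sketches the wire format, not the runner's internal request type.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Field names taken from the struct tags above; values are illustrative.
	body := map[string]any{
		"prompt":       "List three colors.",
		"cache_prompt": true,
		"json_schema":  `{"type":"object","properties":{"colors":{"type":"array","items":{"type":"string"}}},"required":["colors"]}`,
	}

	b, err := json.Marshal(body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
}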

llama/sampling.cpp vendored

@@ -25,11 +25,14 @@
*/
#include "sampling.h"
#include "common.h"
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include <cmath>
#include <unordered_map>
#include <iostream>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@@ -172,7 +175,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
auto * result = new gpt_sampler {
/* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
/* .grmr = */ nullptr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
@@ -246,6 +249,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
}
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
}
if (params.json_schema != "") {
nlohmann::json jsonSchema = nlohmann::json::parse(params.json_schema);
result->grmr = llama_sampler_init_grammar(model, json_schema_to_grammar(jsonSchema).c_str(), "root");
} else {
result->grmr = llama_sampler_init_grammar(model, params.grammar.c_str(), "root");
}
return result;
}


@@ -23,6 +23,7 @@ struct gpt_sampler *gpt_sampler_cinit(
sparams.penalize_nl = params->penalize_nl;
sparams.seed = params->seed;
sparams.grammar = params->grammar;
sparams.json_schema = params->json_schema;
return gpt_sampler_init(model, sparams);
} catch (const std::exception & err) {
return nullptr;


@@ -29,6 +29,7 @@ extern "C"
bool penalize_nl;
uint32_t seed;
char *grammar;
char *json_schema;
};
struct gpt_sampler *gpt_sampler_cinit(


@@ -673,6 +673,7 @@ type CompletionRequest struct {
Format string
Images []ImageData
Options *api.Options
JsonSchema string
}
type CompletionResponse struct {
@@ -733,9 +734,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if req.Format == "json" {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
slog.Warn("Prompt does not specify that the LLM should respond in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
}
request["json_schema"] = req.JsonSchema
}
// Handling JSON marshaling with special characters unescaped.


@@ -61,8 +61,15 @@ type Usage struct {
TotalTokens int `json:"total_tokens"`
}
type JsonSchema struct {
Name string `json:"name"`
Schema json.RawMessage `json:"schema"`
Strict bool `json:"strict"`
}
type ResponseFormat struct {
Type string `json:"type"`
Type string `json:"type"`
JsonSchema JsonSchema `json:"json_schema"`
}
type EmbedRequest struct {
@@ -476,17 +483,35 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
var format string
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
var jsonSchema string
if r.ResponseFormat != nil {
format = "json"
if r.ResponseFormat.Type == "json_object" {
if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
return nil, errors.New("schema must be specified when method is not 'json_schema'")
}
} else if r.ResponseFormat.Type == "json_schema" {
if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
return nil, errors.New("schema must be specified when method is 'json_schema'")
}
jsonSchemaBytes, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
if err != nil {
return nil, errors.New("failed to marshal json_schema")
}
jsonSchema = string(jsonSchemaBytes)
} else {
return nil, errors.New("invalid response format type")
}
}
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
Model: r.Model,
Messages: messages,
Format: format,
JsonSchema: jsonSchema,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
}, nil
}
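On the OpenAI-compatible endpoint, the new response_format handling roughly mirrors OpenAI's structured-output request shape. A hedged sketch of such a request against a local server follows; the /v1/chat/completions route and default port are assumed from the existing compatibility layer, and the model name and schema are placeholders.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// "name", "schema" and "strict" match the JsonSchema struct added above.
	body := map[string]any{
		"model": "llama3.2", // placeholder model name
		"messages": []map[string]string{
			{"role": "user", "content": "List three colors."},
		},
		"response_format": map[string]any{
			"type": "json_schema",
			"json_schema": map[string]any{
				"name":   "colors",
				"strict": true,
				"schema": map[string]any{
					"type": "object",
					"properties": map[string]any{
						"colors": map[string]any{
							"type":  "array",
							"items": map[string]string{"type": "string"},
						},
					},
					"required": []string{"colors"},
				},
			},
		},
		"stream": false,
	}

	b, err := json.Marshal(body)
	if err != nil {
		log.Fatal(err)
	}

	// Assumes a local server on the default port with the OpenAI compatibility routes enabled.
	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(b))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}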


@@ -275,10 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
JsonSchema: req.JsonSchema,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -1460,10 +1461,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
JsonSchema: req.JsonSchema,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
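To close the loop, the chat path above forwards JsonSchema from the public API into llm.CompletionRequest. A minimal end-to-end sketch under the same assumptions as the generate example near the top (upstream api package, placeholder model and schema, Format: "json" still required):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api" // assumed import path for the client package
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.ChatRequest{
		Model: "llama3.2", // placeholder model name
		Messages: []api.Message{
			{Role: "user", Content: "List three colors."},
		},
		Format:     "json", // the schema is only forwarded when format is "json"
		JsonSchema: `{"type":"object","properties":{"colors":{"type":"array","items":{"type":"string"}}},"required":["colors"]}`,
		Stream:     &stream,
	}

	if err := client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}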