Enable JSON Schema support
This commit is contained in: parent 3d25e7bf8c, commit dd25e5fbf5
@@ -80,6 +80,9 @@ type GenerateRequest struct {
     // Options lists model-specific options. For example, temperature can be
     // set through this field, if the model supports it.
     Options map[string]interface{} `json:"options"`
+
+    // JsonSchema is an optional json schema to use for this request.
+    JsonSchema string `json:"json_schema,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +108,9 @@ type ChatRequest struct {
 
     // Options lists model-specific options.
     Options map[string]interface{} `json:"options"`
+
+    // JsonSchema is an optional json schema to use for this request.
+    JsonSchema string `json:"json_schema,omitempty"`
 }
 
 type Tools []Tool
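For context, a client could exercise the new field like this. This is a minimal sketch, not part of the commit: it assumes the repository's Go client package (shown here with the import path github.com/ollama/ollama/api), a locally running server, and an illustrative model name and schema.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api" // assumed import path for this repo's client package
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    // Illustrative schema: constrain the reply to an object with one string field.
    schema := `{"type":"object","properties":{"capital":{"type":"string"}},"required":["capital"]}`

    req := &api.GenerateRequest{
        Model:      "llama3.1", // illustrative model name
        Prompt:     "What is the capital of France? Answer in JSON.",
        Format:     "json",     // the schema is applied when JSON format is requested (see the llm server change below)
        JsonSchema: schema,     // new field introduced by this commit
    }

    // Stream the schema-constrained response to stdout.
    err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
        fmt.Print(resp.Response)
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println()
}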
llama/common.h (vendored): 1 changed line
@@ -160,6 +160,7 @@ struct gpt_sampler_params {
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
+    std::string json_schema; // optional JSON schema to constrain sampling
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
@@ -621,6 +621,7 @@ type SamplingParams struct {
     PenalizeNl bool
     Seed uint32
     Grammar string
+    JsonSchema string
 }
 
 func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
@@ -643,8 +644,12 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext,
     grammar := C.CString(params.Grammar)
     defer C.free(unsafe.Pointer(grammar))
 
     cparams.grammar = grammar
+
+    jsonSchema := C.CString(params.JsonSchema)
+    defer C.free(unsafe.Pointer(jsonSchema))
+    cparams.json_schema = jsonSchema
 
     context := &SamplingContext{c: C.gpt_sampler_cinit(model.c, &cparams)}
     if context.c == nil {
         return nil, errors.New("unable to create sampling context")
@@ -552,6 +552,7 @@ type CompletionRequest struct {
     Prompt string `json:"prompt"`
     Images []ImageData `json:"image_data"`
     Grammar string `json:"grammar"`
+    JsonSchema string `json:"json_schema"`
     CachePrompt bool `json:"cache_prompt"`
 
     Options
@@ -614,6 +615,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
     samplingParams.PenalizeNl = req.PenalizeNewline
     samplingParams.Seed = uint32(req.Seed)
     samplingParams.Grammar = req.Grammar
+    samplingParams.JsonSchema = req.JsonSchema
 
     seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
         numPredict: req.NumPredict,
llama/sampling.cpp (vendored): 13 changed lines
@@ -25,11 +25,14 @@
  */
 
 #include "sampling.h"
 
 #include "common.h"
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
 
 #include <cmath>
 #include <unordered_map>
+#include <iostream>
 
 // the ring buffer works similarly to std::deque, but with a fixed capacity
 // TODO: deduplicate with llama-impl.h
@@ -172,7 +175,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr = */ nullptr,
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur = */ {},
@@ -246,6 +249,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
+    if (params.json_schema != "") {
+        nlohmann::json jsonSchema = nlohmann::json::parse(params.json_schema);
+        result->grmr = llama_sampler_init_grammar(model, json_schema_to_grammar(jsonSchema).c_str(), "root");
+    } else {
+        result->grmr = llama_sampler_init_grammar(model, params.grammar.c_str(), "root");
+    }
 
     return result;
 }
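Note on the hunk above: json_schema_to_grammar comes from the vendored llama.cpp sources (hence the new include) and converts a JSON Schema document into an equivalent GBNF grammar, so when params.json_schema is non-empty it takes precedence over params.grammar when the grammar sampler is created.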
llama/sampling_ext.cpp (vendored): 1 changed line
@@ -23,6 +23,7 @@ struct gpt_sampler *gpt_sampler_cinit(
         sparams.penalize_nl = params->penalize_nl;
         sparams.seed = params->seed;
         sparams.grammar = params->grammar;
+        sparams.json_schema = params->json_schema;
         return gpt_sampler_init(model, sparams);
     } catch (const std::exception & err) {
         return nullptr;
llama/sampling_ext.h (vendored): 1 changed line
@@ -29,6 +29,7 @@ extern "C"
         bool penalize_nl;
         uint32_t seed;
         char *grammar;
+        char *json_schema;
     };
 
     struct gpt_sampler *gpt_sampler_cinit(
@@ -673,6 +673,7 @@ type CompletionRequest struct {
     Format string
     Images []ImageData
     Options *api.Options
+    JsonSchema string
 }
 
 type CompletionResponse struct {
@@ -733,9 +734,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
     if req.Format == "json" {
         request["grammar"] = jsonGrammar
 
-        if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-            slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
-        }
+        request["json_schema"] = req.JsonSchema
     }
 
     // Handling JSON marshaling with special characters unescaped.
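Note that the schema only reaches the runner when req.Format is "json", so API callers setting json_schema also need to request JSON format.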
@@ -61,8 +61,15 @@ type Usage struct {
     TotalTokens int `json:"total_tokens"`
 }
 
+type JsonSchema struct {
+    Name   string          `json:"name"`
+    Schema json.RawMessage `json:"schema"`
+    Strict bool            `json:"strict"`
+}
+
 type ResponseFormat struct {
-    Type string `json:"type"`
+    Type       string     `json:"type"`
+    JsonSchema JsonSchema `json:"json_schema"`
 }
 
 type EmbedRequest struct {
@@ -476,17 +483,35 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
     }
 
     var format string
-    if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
+    var jsonSchema string
+    if r.ResponseFormat != nil {
         format = "json"
+        if r.ResponseFormat.Type == "json_object" {
+            if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
+                return nil, errors.New("schema must be specified when method is not 'json_schema'")
+            }
+        } else if r.ResponseFormat.Type == "json_schema" {
+            if len(r.ResponseFormat.JsonSchema.Schema) == 0 {
+                return nil, errors.New("schema must be specified when method is 'json_schema'")
+            }
+            jsonSchemaBytes, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+            if err != nil {
+                return nil, errors.New("failed to marshal json_schema")
+            }
+            jsonSchema = string(jsonSchemaBytes)
+        } else {
+            return nil, errors.New("invalid response format type")
+        }
     }
 
     return &api.ChatRequest{
-        Model:    r.Model,
-        Messages: messages,
-        Format:   format,
-        Options:  options,
-        Stream:   &r.Stream,
-        Tools:    r.Tools,
+        Model:      r.Model,
+        Messages:   messages,
+        Format:     format,
+        JsonSchema: jsonSchema,
+        Options:    options,
+        Stream:     &r.Stream,
+        Tools:      r.Tools,
     }, nil
 }
 
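To make the new OpenAI-compatible shape concrete, here is a small standalone sketch (not part of the commit) that marshals a response_format body which would take the json_schema branch above. The struct definitions are local copies of the ones added in this commit, trimmed to the relevant fields, and the name and schema content are illustrative.

package main

import (
    "encoding/json"
    "fmt"
    "log"
)

// Local copies of the request structs added above, trimmed for the example.
type JsonSchema struct {
    Name   string          `json:"name"`
    Schema json.RawMessage `json:"schema"`
    Strict bool            `json:"strict"`
}

type ResponseFormat struct {
    Type       string     `json:"type"`
    JsonSchema JsonSchema `json:"json_schema"`
}

func main() {
    rf := ResponseFormat{
        Type: "json_schema",
        JsonSchema: JsonSchema{
            Name:   "capital_answer", // illustrative name
            Schema: json.RawMessage(`{"type":"object","properties":{"capital":{"type":"string"}},"required":["capital"]}`),
            Strict: true,
        },
    }

    // Print the JSON body a client would send as response_format.
    body, err := json.MarshalIndent(rf, "", "  ")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(string(body))
}

With the validation added in the hunk above, a request whose response_format type is json_schema (or json_object) but whose schema is empty is rejected.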
@@ -275,10 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
         var sb strings.Builder
         defer close(ch)
         if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-            Prompt:  prompt,
-            Images:  images,
-            Format:  req.Format,
-            Options: opts,
+            Prompt:     prompt,
+            Images:     images,
+            Format:     req.Format,
+            Options:    opts,
+            JsonSchema: req.JsonSchema,
         }, func(cr llm.CompletionResponse) {
             res := api.GenerateResponse{
                 Model: req.Model,
@@ -1460,10 +1461,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
         go func() {
             defer close(ch)
             if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-                Prompt:  prompt,
-                Images:  images,
-                Format:  req.Format,
-                Options: opts,
+                Prompt:     prompt,
+                Images:     images,
+                Format:     req.Format,
+                Options:    opts,
+                JsonSchema: req.JsonSchema,
             }, func(r llm.CompletionResponse) {
                 res := api.ChatResponse{
                     Model: req.Model,