basic progress
This commit is contained in:
parent
20afaae020
commit
43efc893d7
@ -4,10 +4,10 @@ package llama
|
||||
// #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
|
||||
// #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
// #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
|
||||
// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
|
||||
// #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
|
||||
// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
|
||||
// #cgo darwin,amd64 LDFLAGS: -framework Foundation -framework Accelerate
|
||||
// #cgo linux CFLAGS: -D_GNU_SOURCE
|
||||
// #cgo linux CXXFLAGS: -D_GNU_SOURCE
|
||||
// #cgo windows LDFLAGS: -lmsvcrt
|
||||
@ -29,11 +29,14 @@ package llama
|
||||
// #include "clip.h"
|
||||
// #include "llava.h"
|
||||
// #include "sampling_ext.h"
|
||||
//
|
||||
// bool llamaProgressCallback(float progress, void *user_data);
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"runtime/cgo"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
@ -65,10 +68,26 @@ type ModelParams struct {
|
||||
c C.struct_llama_model_params
|
||||
}
|
||||
|
||||
func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
|
||||
//export llamaProgressCallback
|
||||
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
|
||||
handle := cgo.Handle(userData)
|
||||
callback := handle.Value().(func(float32))
|
||||
callback(float32(progress))
|
||||
return true
|
||||
}
|
||||
|
||||
func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
|
||||
params := C.llama_model_default_params()
|
||||
params.n_gpu_layers = C.int(numGpuLayers)
|
||||
params.main_gpu = C.int32_t(mainGpu)
|
||||
|
||||
handle := cgo.NewHandle(callback)
|
||||
params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
|
||||
params.progress_callback_user_data = unsafe.Pointer(handle)
|
||||
runtime.SetFinalizer(¶ms, func(p *C.struct_llama_model_params) {
|
||||
handle.Delete()
|
||||
})
|
||||
|
||||
return ModelParams{c: params}
|
||||
}
|
||||
|
||||
@ -233,7 +252,8 @@ func (m *Model) TokenToPiece(token int) string {
|
||||
return strings.TrimRight(string(buf), "\x00")
|
||||
}
|
||||
|
||||
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
|
||||
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
|
||||
maxTokens := len(text) + 2
|
||||
cTokens := make([]C.llama_token, maxTokens)
|
||||
cText := C.CString(text)
|
||||
defer C.free(unsafe.Pointer(cText))
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@ -28,6 +29,9 @@ type Sequence struct {
|
||||
// channel to send responses over
|
||||
responses chan string
|
||||
|
||||
// number of tokens to predict
|
||||
numPredict int
|
||||
|
||||
samplingCtx *llama.SamplingContext
|
||||
|
||||
// channel to send back the embedding if embedding only
|
||||
@ -38,6 +42,8 @@ type Sequence struct {
|
||||
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
doneReason string
|
||||
}
|
||||
|
||||
// prompt returns true if the prompt is still being processed
|
||||
@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
|
||||
tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
|
||||
tokens, err := s.lc.Model().Tokenize(prompt, false, true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// truncate to last n tokens
|
||||
// TODO: this shouldn't happen and will severely impact generation
|
||||
// quality. instead we should ensure to cut prompt in the API.
|
||||
if len(tokens) > s.numCtx {
|
||||
tokens = tokens[:s.numCtx]
|
||||
}
|
||||
|
||||
var sc *llama.SamplingContext
|
||||
if params != nil {
|
||||
sc = llama.NewSamplingContext(*params)
|
||||
@ -83,9 +96,16 @@ type Server struct {
|
||||
// TODO (jmorganca): this can probably be moved into run()
|
||||
seqs []*Sequence
|
||||
|
||||
// context window size
|
||||
numCtx int
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cond *sync.Cond
|
||||
|
||||
progress float32
|
||||
|
||||
status string
|
||||
}
|
||||
|
||||
func (s *Server) allNil() bool {
|
||||
@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
// we've reached the context limit
|
||||
if seq.nPast > s.numCtx {
|
||||
seq.doneReason = "limit"
|
||||
close(seq.responses)
|
||||
s.lc.KvCacheSeqRm(i, 0, -1)
|
||||
s.seqs[i] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
for j, t := range seq.tokens {
|
||||
// todo: make this n_batch
|
||||
if j > s.batchSize {
|
||||
@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
seq.doneReason = "stop"
|
||||
close(seq.responses)
|
||||
seq.samplingCtx.Free()
|
||||
pieces[i] = []string{}
|
||||
@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
}
|
||||
|
||||
s.lc.KvCacheSeqRm(i, 0, -1)
|
||||
seq.doneReason = "stop"
|
||||
close(seq.responses)
|
||||
seq.samplingCtx.Free()
|
||||
pieces[i] = []string{}
|
||||
@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
// TODO (jmorganca): is it safe to do this concurrently with decoding?
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
log.Println("Failed to encode result:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
mpath := flag.String("model", "", "Path to model binary file")
|
||||
ppath := flag.String("projector", "", "Path to projector binary file")
|
||||
@ -425,36 +474,31 @@ func main() {
|
||||
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
flag.Parse()
|
||||
|
||||
// load the model
|
||||
llama.BackendInit()
|
||||
params := llama.NewModelParams(*nGpuLayers, *mainGpu)
|
||||
model := llama.LoadModelFromFile(*mpath, params)
|
||||
|
||||
if *lpath != "" {
|
||||
model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
|
||||
}
|
||||
|
||||
ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
|
||||
lc := llama.NewContextWithModel(model, ctxParams)
|
||||
if lc == nil {
|
||||
panic("Failed to create context")
|
||||
}
|
||||
|
||||
var cc *llama.ClipContext
|
||||
if *ppath != "" {
|
||||
cc = llama.NewClipContext(*ppath)
|
||||
if cc == nil {
|
||||
panic("Failed to create clip context")
|
||||
}
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
lc: lc,
|
||||
cc: cc,
|
||||
numCtx: *numCtx,
|
||||
batchSize: *batchSize,
|
||||
parallel: *parallel,
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
status: "loading",
|
||||
}
|
||||
|
||||
// load the model
|
||||
llama.BackendInit()
|
||||
params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
|
||||
slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
|
||||
server.progress = progress
|
||||
})
|
||||
server.model = llama.LoadModelFromFile(*mpath, params)
|
||||
|
||||
if *lpath != "" {
|
||||
server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
|
||||
}
|
||||
|
||||
ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
|
||||
server.lc = llama.NewContextWithModel(server.model, ctxParams)
|
||||
|
||||
if *ppath != "" {
|
||||
server.cc = llama.NewClipContext(*ppath)
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
@ -473,11 +517,14 @@ func main() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/embeddings", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
|
||||
httpServer := http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
server.status = "ready"
|
||||
|
||||
log.Println("Server listening on", addr)
|
||||
if err := httpServer.Serve(listener); err != nil {
|
||||
log.Fatal("server error:", err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user