basic progress

jmorganca 2024-05-28 23:11:48 -07:00
parent 20afaae020
commit 43efc893d7
2 changed files with 98 additions and 31 deletions

@@ -4,10 +4,10 @@ package llama
// #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
// #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
// #cgo darwin,amd64 LDFLAGS: -framework Foundation -framework Accelerate
// #cgo linux CFLAGS: -D_GNU_SOURCE
// #cgo linux CXXFLAGS: -D_GNU_SOURCE
// #cgo windows LDFLAGS: -lmsvcrt
@@ -29,11 +29,14 @@ package llama
// #include "clip.h"
// #include "llava.h"
// #include "sampling_ext.h"
//
// bool llamaProgressCallback(float progress, void *user_data);
import "C"
import (
"errors"
"fmt"
"runtime"
"runtime/cgo"
"strings"
"unsafe"
)
@@ -65,10 +68,26 @@ type ModelParams struct {
c C.struct_llama_model_params
}
func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
handle := cgo.Handle(userData)
callback := handle.Value().(func(float32))
callback(float32(progress))
return true
}
func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
params := C.llama_model_default_params()
params.n_gpu_layers = C.int(numGpuLayers)
params.main_gpu = C.int32_t(mainGpu)
handle := cgo.NewHandle(callback)
params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
params.progress_callback_user_data = unsafe.Pointer(handle)
runtime.SetFinalizer(&params, func(p *C.struct_llama_model_params) {
handle.Delete()
})
return ModelParams{c: params}
}
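
Aside: a minimal standalone sketch of the cgo.Handle pattern the new progress plumbing relies on. A Go closure is handed to C as an opaque void* and recovered inside an exported Go function; fake_load below is a stand-in invented for the sketch, not a llama.cpp call.

// progress_sketch.go — standalone sketch; fake_load stands in for the C loader.
package main

/*
#include <stdbool.h>

bool goProgressCallback(float progress, void *user_data);

// stand-in for a long-running C loader that reports progress
static void fake_load(void *user_data) {
    for (int i = 1; i <= 4; i++) {
        goProgressCallback(i / 4.0f, user_data);
    }
}
*/
import "C"

import (
    "fmt"
    "runtime/cgo"
    "unsafe"
)

//export goProgressCallback
func goProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
    // recover the Go closure from the opaque handle carried through C
    callback := cgo.Handle(userData).Value().(func(float32))
    callback(float32(progress))
    return true
}

func main() {
    handle := cgo.NewHandle(func(p float32) {
        fmt.Printf("progress: %.0f%%\n", p*100)
    })
    defer handle.Delete()

    C.fake_load(unsafe.Pointer(handle))
}

NewModelParams above uses the same shape: the handle is stored in progress_callback_user_data and the exported llamaProgressCallback unwraps it on every progress update.
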
@@ -233,7 +252,8 @@ func (m *Model) TokenToPiece(token int) string {
return strings.TrimRight(string(buf), "\x00")
}
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
maxTokens := len(text) + 2
cTokens := make([]C.llama_token, maxTokens)
cText := C.CString(text)
defer C.free(unsafe.Pointer(cText))

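Note on the Tokenize change above: the dropped maxTokens argument is replaced with an internal bound of len(text) + 2, on the assumption that every token covers at least one byte of input, with two extra slots left for special tokens such as BOS/EOS. A standalone sketch of the same preallocate-then-slice pattern, with a made-up fake_tokenize standing in for the llama.cpp tokenizer:

// tokenize_sketch.go — standalone sketch; fake_tokenize stands in for llama.cpp.
package main

/*
#include <stdlib.h>

// stand-in tokenizer: emits one "token" per space-separated word and
// returns how many it wrote (never more than max_tokens)
static int fake_tokenize(const char *text, int *tokens, int max_tokens) {
    int n = 0;
    int in_word = 0;
    for (const char *p = text; *p != '\0' && n < max_tokens; p++) {
        if (*p != ' ' && !in_word) {
            tokens[n++] = (int)(p - text); // token id = byte offset of the word
            in_word = 1;
        } else if (*p == ' ') {
            in_word = 0;
        }
    }
    return n;
}
*/
import "C"

import (
    "fmt"
    "unsafe"
)

func tokenize(text string) []int {
    maxTokens := len(text) + 2 // worst case: one token per byte, plus BOS/EOS
    cTokens := make([]C.int, maxTokens)
    cText := C.CString(text)
    defer C.free(unsafe.Pointer(cText))

    // C fills at most maxTokens entries and reports how many it wrote
    n := C.fake_tokenize(cText, &cTokens[0], C.int(maxTokens))

    tokens := make([]int, n)
    for i := range tokens {
        tokens[i] = int(cTokens[i])
    }
    return tokens
}

func main() {
    fmt.Println(tokenize("the quick brown fox"))
}
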

@@ -7,6 +7,7 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"runtime"
@@ -28,6 +29,9 @@ type Sequence struct {
// channel to send responses over
responses chan string
// number of tokens to predict
numPredict int
samplingCtx *llama.SamplingContext
// channel to send back the embedding if embedding only
@@ -38,6 +42,8 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
doneReason string
}
// prompt returns true if the prompt is still being processed
@@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
}
func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
tokens, err := s.lc.Model().Tokenize(prompt, false, true)
if err != nil {
panic(err)
}
// truncate to last n tokens
// TODO: this shouldn't happen and will severely impact generation
// quality. instead we should ensure to cut prompt in the API.
if len(tokens) > s.numCtx {
tokens = tokens[:s.numCtx]
}
var sc *llama.SamplingContext
if params != nil {
sc = llama.NewSamplingContext(*params)
@@ -83,9 +96,16 @@ type Server struct {
// TODO (jmorganca): this can probably be moved into run()
seqs []*Sequence
// context window size
numCtx int
mu sync.Mutex
cond *sync.Cond
progress float32
status string
}
func (s *Server) allNil() bool {
@@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
continue
}
// we've reached the context limit
if seq.nPast > s.numCtx {
seq.doneReason = "limit"
close(seq.responses)
s.lc.KvCacheSeqRm(i, 0, -1)
s.seqs[i] = nil
continue
}
for j, t := range seq.tokens {
// todo: make this n_batch
if j > s.batchSize {
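
The hunk above introduces the pattern used for finishing a sequence: record a doneReason ("limit" here, "stop" in the hunks below) and close its responses channel so the consumer's loop ends. A reduced standalone sketch of that shape; only the two fields come from the diff, the rest is assumed:

// done_sketch.go — standalone sketch of the doneReason / close(responses) shape.
package main

import "fmt"

type sequence struct {
    responses  chan string
    doneReason string
}

func main() {
    seq := &sequence{responses: make(chan string)}

    // producer: stream a few pieces, then record why generation stopped and
    // close the channel so the consumer's range loop terminates
    go func() {
        for _, piece := range []string{"hello", " ", "world"} {
            seq.responses <- piece
        }
        // doneReason is written before close, so reading it after the
        // range loop below finishes is race-free
        seq.doneReason = "stop" // or "limit" when the context window is exceeded
        close(seq.responses)
    }()

    for piece := range seq.responses {
        fmt.Print(piece)
    }
    fmt.Println("\ndone:", seq.doneReason)
}
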
@@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
// as it's important for the /api/generate context
// seq.responses <- piece
seq.doneReason = "stop"
close(seq.responses)
seq.samplingCtx.Free()
pieces[i] = []string{}
@@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
}
s.lc.KvCacheSeqRm(i, 0, -1)
seq.doneReason = "stop"
close(seq.responses)
seq.samplingCtx.Free()
pieces[i] = []string{}
@@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
// TODO (jmorganca): is it safe to do this concurrently with decoding?
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
log.Println("Failed to encode result:", err)
return
}
}
func main() {
mpath := flag.String("model", "", "Path to model binary file")
ppath := flag.String("projector", "", "Path to projector binary file")
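
The new /health endpoint surfaces the load progress fed in by the model callback: status starts out as "loading", progress rises as the callback fires, and status becomes "ready" just before the HTTP server starts serving. A minimal client-side sketch of polling it; the address and polling interval are assumptions, not part of this change:

// healthcheck_sketch.go — standalone sketch; the runner's address is assumed.
package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "time"
)

type HealthResponse struct {
    Status   string  `json:"status"`
    Progress float32 `json:"progress"`
}

func main() {
    for {
        resp, err := http.Get("http://127.0.0.1:8080/health") // assumed address
        if err != nil {
            fmt.Println("runner not reachable yet:", err)
            time.Sleep(time.Second)
            continue
        }

        var health HealthResponse
        err = json.NewDecoder(resp.Body).Decode(&health)
        resp.Body.Close()
        if err != nil {
            fmt.Println("bad health response:", err)
            return
        }

        fmt.Printf("status=%s progress=%.0f%%\n", health.Status, health.Progress*100)
        if health.Status == "ready" {
            return
        }
        time.Sleep(time.Second)
    }
}
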
@@ -425,36 +474,31 @@ func main() {
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
flag.Parse()
// load the model
llama.BackendInit()
params := llama.NewModelParams(*nGpuLayers, *mainGpu)
model := llama.LoadModelFromFile(*mpath, params)
if *lpath != "" {
model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
}
ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
lc := llama.NewContextWithModel(model, ctxParams)
if lc == nil {
panic("Failed to create context")
}
var cc *llama.ClipContext
if *ppath != "" {
cc = llama.NewClipContext(*ppath)
if cc == nil {
panic("Failed to create clip context")
}
}
server := &Server{
model: model,
lc: lc,
cc: cc,
numCtx: *numCtx,
batchSize: *batchSize,
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
status: "loading",
}
// load the model
llama.BackendInit()
params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
server.progress = progress
})
server.model = llama.LoadModelFromFile(*mpath, params)
if *lpath != "" {
server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
}
ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
server.lc = llama.NewContextWithModel(server.model, ctxParams)
if *ppath != "" {
server.cc = llama.NewClipContext(*ppath)
}
server.cond = sync.NewCond(&server.mu)
@@ -473,11 +517,14 @@ func main() {
mux := http.NewServeMux()
mux.HandleFunc("/embeddings", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
httpServer := http.Server{
Handler: mux,
}
server.status = "ready"
log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error:", err)