Compare commits


8 Commits

Author     SHA1        Message              Date
Josh Yan   4da5d5beaa  lint                 2024-08-28 10:23:41 -07:00
Josh Yan   cc17b02b23  update               2024-08-28 09:58:23 -07:00
Josh Yan   73d69bc90b  remove types         2024-08-27 16:45:07 -07:00
Josh Yan   9bc42f532b  rmv api type         2024-08-27 16:45:07 -07:00
Josh Yan   07c0f66f5e  rm print             2024-08-27 16:45:04 -07:00
Josh Yan   4a7bfca902  change progress msg  2024-08-27 16:44:38 -07:00
Josh Yan   04f2154505  fixed cgo            2024-08-27 16:44:38 -07:00
Josh Yan   de9b21b472  quantize progress    2024-08-27 16:44:32 -07:00
6 changed files with 1342 additions and 5 deletions

.gitattributes (vendored, +1)

@@ -1,3 +1,4 @@
 llm/ext_server/* linguist-vendored
+llm/*.h linguist-vendored
 * text=auto
 *.go text eol=lf

cmd/cmd.go

@@ -124,6 +124,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}

 	bars := make(map[string]*progress.Bar)
+	var quantizeSpin *progress.Spinner
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -136,6 +137,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}

 			bar.Set(resp.Completed)
+		} else if strings.Contains(resp.Status, "quantizing") {
+			spinner.Stop()
+
+			if quantizeSpin != nil {
+				quantizeSpin.SetMessage(resp.Status)
+			} else {
+				quantizeSpin = progress.NewSpinner(resp.Status)
+				p.Add("quantize", quantizeSpin)
+			}
 		} else if status != resp.Status {
 			spinner.Stop()
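
The handler above creates the quantize spinner lazily on the first status containing "quantizing" and thereafter only updates its message, so repeated progress callbacks do not stack spinners. A minimal standalone sketch of that lazy-creation pattern (the spinner type here is an illustrative stub standing in for ollama's progress.Spinner, not the real implementation):

    package main

    import (
        "fmt"
        "strings"
    )

    // spinner is an illustrative stub for progress.Spinner.
    type spinner struct{ msg string }

    func (s *spinner) SetMessage(m string) {
        s.msg = m
        fmt.Println("spinner:", m)
    }

    func main() {
        var quantizeSpin *spinner
        for _, status := range []string{
            "quantizing model 0%",
            "quantizing model 50%",
            "quantizing model 100%",
        } {
            if strings.Contains(status, "quantizing") {
                if quantizeSpin == nil {
                    quantizeSpin = &spinner{} // first update: create the spinner once
                }
                quantizeSpin.SetMessage(status)
            }
        }
    }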

llm/llama.h (vendored, new file, +1227)

File diff suppressed because it is too large.

llm/llm.go

@@ -1,6 +1,6 @@
 package llm

-// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
+// #cgo CPPFLAGS: -Illama.cpp/ggml/include
 // #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
 // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
 // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
@@ -9,12 +9,24 @@ package llm
 // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
+// #include <stdatomic.h>
 // #include "llama.h"
+// bool update_quantize_progress(float progress, void* data) {
+//     atomic_int* atomicData = (atomic_int*)data;
+//     int intProgress = *((int*)&progress);
+//     atomic_store(atomicData, intProgress);
+//     return true;
+// }
 import "C"

 import (
 	"errors"
+	"fmt"
+	"sync/atomic"
+	"time"
 	"unsafe"
+
+	"github.com/ollama/ollama/api"
 )

 // SystemInfo is an unused example of calling llama.cpp functions using CGo
@@ -22,17 +34,49 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }

-func Quantize(infile, outfile string, ftype fileType) error {
+func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))

 	coutfile := C.CString(outfile)
 	defer C.free(unsafe.Pointer(coutfile))

 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
 	params.ftype = ftype.Value()
+
+	// Initialize "global" to store progress
+	store := (*int32)(C.malloc(C.sizeof_int))
+	defer C.free(unsafe.Pointer(store))
+
+	// Initialize store value, e.g., setting initial progress to 0
+	atomic.StoreInt32(store, 0)
+	params.quantize_callback_data = unsafe.Pointer(store)
+	params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
+
+	ticker := time.NewTicker(30 * time.Millisecond)
+	done := make(chan struct{})
+	defer close(done)
+
+	go func() {
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				progressInt := atomic.LoadInt32(store)
+				progress := *(*float32)(unsafe.Pointer(&progressInt))
+				fn(api.ProgressResponse{
+					Status: fmt.Sprintf("quantizing model %d%%", 100*int(progress)/tensorCount),
+				})
+			case <-done:
+				fn(api.ProgressResponse{
+					Status: fmt.Sprintf("quantizing model 100%%"),
+				})
+				return
+			}
+		}
+	}()

 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
 		return errors.New("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
 	}
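
The bridge above works by bit-punning: the C callback receives the tensor index as a float, stores its raw bit pattern into an atomic_int living in C-allocated memory, and the Go ticker goroutine loads the bits back and reinterprets them as a float32. A self-contained sketch of the same trick in pure Go, using math.Float32bits/Float32frombits instead of unsafe pointer casts (illustrative only, not ollama code):

    package main

    import (
        "fmt"
        "math"
        "sync/atomic"
        "time"
    )

    func main() {
        var store atomic.Int32 // holds the bit pattern of a float32

        // Producer: report each "tensor index" as a float32, mirroring
        // how the patched llama.cpp invokes the callback with i.
        go func() {
            for i := 0; i <= 100; i++ {
                store.Store(int32(math.Float32bits(float32(i))))
                time.Sleep(2 * time.Millisecond)
            }
        }()

        // Consumer: poll on a ticker and decode the bits back to float32.
        ticker := time.NewTicker(30 * time.Millisecond)
        defer ticker.Stop()
        for j := 0; j < 5; j++ {
            <-ticker.C
            progress := math.Float32frombits(uint32(store.Load()))
            fmt.Printf("quantizing model %d%%\n", int(progress))
        }
    }

The production code instead allocates the counter with C.malloc, so the pointer handed to llama.cpp is C memory rather than a Go pointer, which keeps the callback legal under cgo's pointer-passing rules.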

llm/patches (new patch file, +52)

@@ -0,0 +1,52 @@
From ed941590d59fc07b1ad21d6aa458588e47d1e446 Mon Sep 17 00:00:00 2001
From: Josh Yan <jyan00017@gmail.com>
Date: Wed, 10 Jul 2024 13:39:39 -0700
Subject: [PATCH] quantize progress
---
include/llama.h | 3 +++
src/llama.cpp | 8 ++++++++
2 files changed, 11 insertions(+)
diff --git a/include/llama.h b/include/llama.h
index bb4b05ba..613db68e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -349,6 +349,9 @@ extern "C" {
         bool keep_split;      // quantize to the same number of shards
         void * imatrix;       // pointer to importance matrix data
         void * kv_overrides;  // pointer to vector containing overrides
+
+        llama_progress_callback quantize_callback;  // callback to report quantization progress
+        void * quantize_callback_data;              // user data for the callback
     } llama_model_quantize_params;

     // grammar types
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..ac640c02 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18252,6 +18252,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     const auto tn = LLM_TN(model.arch);
     new_ofstream(0);
     for (int i = 0; i < ml.n_tensors; ++i) {
+        if (params->quantize_callback){
+            if (!params->quantize_callback(i, params->quantize_callback_data)) {
+                return;
+            }
+        }
+
         auto weight = ml.get_weight(i);
         struct ggml_tensor * tensor = weight->tensor;
         if (weight->idx != cur_split && params->keep_split) {
@@ -18789,6 +18795,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.keep_split =*/ false,
         /*.imatrix =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
+        /*.quantize_callback =*/ nullptr,
+        /*.quantize_callback_data =*/ nullptr,
     };

     return result;
--
2.39.3 (Apple Git-146)
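
The patch reuses llama.cpp's existing llama_progress_callback type (a bool(float, void*) function pointer), checks it once per tensor, and treats a false return as a cancellation request, in which case llama_model_quantize_internal simply returns early without surfacing an error. A rough Go sketch of that contract (hypothetical names; the real contract is the C++ above):

    package main

    import "fmt"

    // progressCallback mirrors llama_progress_callback: it receives the
    // current tensor index and returns false to abort the operation.
    type progressCallback func(progress float32, userData any) bool

    // quantizeAll is a toy stand-in for llama_model_quantize_internal.
    func quantizeAll(nTensors int, cb progressCallback) {
        for i := 0; i < nTensors; i++ {
            if cb != nil && !cb(float32(i), nil) {
                return // callback requested cancellation; no error is surfaced
            }
            // ... quantize tensor i here ...
        }
    }

    func main() {
        quantizeAll(5, func(p float32, _ any) bool {
            fmt.Printf("starting tensor %.0f\n", p)
            return p < 3 // cancel once we reach tensor 3
        })
    }

Note that the update_quantize_progress shim in llm.go always returns true, so cancellation is wired up here but unused in this change.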

server/images.go

@@ -435,11 +435,14 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				return err
 			}

+			tensorCount := len(baseLayer.GGML.Tensors().Items)
 			ft := baseLayer.GGML.KV().FileType()
 			if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
 				return errors.New("quantization is only supported for F16 and F32 models")
 			} else if want != ft {
-				fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
+				fn(api.ProgressResponse{
+					Status: "quantizing model tensors",
+				})

 				blob, err := GetBlobsPath(baseLayer.Digest)
 				if err != nil {
@@ -453,7 +456,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			defer temp.Close()
 			defer os.Remove(temp.Name())

-			if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+			if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
 				return err
 			}
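
One detail worth noting about the reporting math: the callback delivers the raw tensor index, and the ticker goroutine converts it with 100*int(progress)/tensorCount, which is integer division and therefore truncates. The polled percentage can never reach 100 (the last index is tensorCount-1); the final "quantizing model 100%" comes only from the done branch after llama_model_quantize returns. A quick check of the arithmetic (the tensor count here is made up):

    package main

    import "fmt"

    func main() {
        tensorCount := 293 // hypothetical tensor count
        for _, i := range []int{0, 146, 292} {
            fmt.Printf("quantizing model %d%%\n", 100*i/tensorCount) // 0%, 49%, 99%
        }
        // 100*146/293 = 49 (truncated); 100*292/293 = 99 — the done
        // branch, not the ticker, is what reports the final 100%.
    }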