Merge 8ce63ea9e6e3144896c31e9632b11892bc5e131f into d7eb05b9361febead29a74e71ddffc2ebeff5302
commit cccbaacaea
api/data/README.md (new file, 7 lines)
@@ -0,0 +1,7 @@
## Model information

This directory contains a tiny llama model that is used to test the API client in our example tests.

The model comes from [aladar/llama-2-tiny-random-GGUF](https://huggingface.co/aladar/llama-2-tiny-random-GGUF/), thanks to [aladar on Hugging Face](https://huggingface.co/aladar), and was likely inspired by this [tweet from @StasBekman](https://twitter.com/StasBekman/status/1724519457790509310) and its [accompanying Hugging Face model](https://huggingface.co/stas/tiny-random-llama-2).

It's a simple model that mostly produces gibberish, but for unit-testing purposes it works great.
api/data/llama-2-tiny-random.gguf (new binary file, not shown)
api/data/llama-2-tiny-random.model (new file, 5 lines)
@@ -0,0 +1,5 @@
FROM llama-2-tiny-random.gguf
PARAMETER seed 42
PARAMETER top_p 0.9
PARAMETER stop foo
PARAMETER stop bar
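For comparison, the same sampling parameters can also be supplied per request instead of being baked into a Modelfile. Below is a minimal sketch of that, assuming `api.ChatRequest` exposes an `Options` field of type `map[string]interface{}` (check the current `api` types before relying on the exact field names); the `buildChatRequest` helper is hypothetical.

```go
package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

// buildChatRequest is a hypothetical helper showing how the Modelfile
// parameters above (seed, top_p, stop) could instead be supplied per request,
// assuming api.ChatRequest exposes an Options map of type map[string]interface{}.
func buildChatRequest() *api.ChatRequest {
	return &api.ChatRequest{
		Model:    "llama-2-tiny-random",
		Messages: []api.Message{{Role: "user", Content: "Hello!"}},
		Options: map[string]interface{}{
			"seed":  42,
			"top_p": 0.9,
			"stop":  []string{"foo", "bar"},
		},
	}
}

func main() {
	req := buildChatRequest()
	fmt.Println(req.Model) // the request can then be passed to client.Chat
}
```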
api/example_test.go (new file, 395 lines)
@@ -0,0 +1,395 @@
package api_test

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"sort"
	"strings"
	"testing"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/server"
)

const (
	testModelName     = "llama-2-tiny-random"
	testGgufFilePath  = "data/llama-2-tiny-random.gguf"
	testModelFilePath = "data/llama-2-tiny-random.model"
)

func TestMain(m *testing.M) {
	// Here we start a server in the background to run with our tests.
	host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
	if err != nil {
		host, port = "127.0.0.1", "11434"
		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
			host = ip.String()
		}
	}

	ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
	if err != nil {
		log.Fatal(err)
	}

	go server.Serve(ln)

	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// load our test model
	if err := loadTestModel(ctx, client, testModelName); err != nil {
		log.Fatal(err)
	}

	result := m.Run()

	// Clean up the test model explicitly: a defer would never run here because
	// os.Exit skips deferred calls. Whether the model should be deleted at all
	// is a question of what developers would prefer.
	if err := deleteTestModel(ctx, client, testModelName); err != nil {
		log.Println(err)
	}

	ln.Close()

	os.Exit(result)
}

func ExampleClient_Heartbeat() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Get a heartbeat from the server. If no error is returned, we know the server is up.
	if err := client.Heartbeat(ctx); err != nil {
		log.Fatal(err)
	}

	fmt.Println("server is up")
	// Output:
	// server is up
}

func ExampleClient_Version() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Get the version of the server
	version, err := client.Version(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// if we get an empty version, something is wrong
	if version == "" {
		log.Fatal("version is empty")
	} else {
		fmt.Println("version is not empty")
	}
	// Output:
	// version is not empty
}

func ExampleClient_List() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// List all the models on the server as tags
	tags, err := client.List(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// if we get an empty list it means that there are no models on the server
	if len(tags.Models) == 0 {
		log.Fatal("no tags returned")
	} else {
		fmt.Println("tags returned")
	}
	// Output:
	// tags returned
}

func ExampleClient_Chat() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// List all the models on the server as tags
	tags, err := client.List(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// if we get an empty list it means that there are no models on the server
	if len(tags.Models) == 0 {
		log.Fatal("no tags returned")
	}

	// sort tags.Models by Size, least to greatest
	sort.Slice(tags.Models, func(i, j int) bool {
		return tags.Models[i].Size < tags.Models[j].Size
	})

	// Get the first (smallest) model in the list (this will likely be the model defined in the testModelName constant)
	model := tags.Models[0].Model

	// Now we're ready to send a chat request to that model.

	// Create a new chat message to send to the server. Role must be defined and can be "user", "system", or "assistant".
	// These could be defined as constants in the api package.
	chatMessage := api.Message{
		Content: "Hello! Tell me about yourself in exactly one sentence.",
		Role:    "user",
	}

	// Create a new chat request to send to the server. This specifies the model and the messages to send.
	req := &api.ChatRequest{
		Model:    model,
		Messages: []api.Message{chatMessage},
	}

	// Create a function to handle the response from the server. This is a callback function that will be called for each response from the server.
	var modelResponseText string // we will store the response from the server in this variable
	fn := func(response api.ChatResponse) error { // this callback function fires for each token returned from the server
		modelResponseText += response.Message.Content // append the response to the modelResponseText variable to retrieve the whole message
		return nil
	}

	if err := client.Chat(ctx, req, fn); err != nil { // send the request to the server and if there is an error, log it
		if errors.Is(err, context.Canceled) {
			log.Fatal("context was canceled")
		}
		log.Fatal(err)
	}

	if modelResponseText == "" {
		log.Fatal("modelResponseText is empty")
	} else {
		fmt.Println("model responded")
	}

	// Output:
	// model responded
}

func ExampleClient_Create() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	ggufFilePath := testGgufFilePath
	modelFilePath := testModelFilePath
	modelName := "tiny-spoof"

	// get the model file bytes from the testModel file
	modelBytes, err := os.ReadFile(modelFilePath)
	if err != nil {
		log.Fatal(err)
	}

	modelString := string(modelBytes)

	// Create a new create request
	req := &api.CreateRequest{
		Path:      ggufFilePath,
		Model:     modelName,
		Modelfile: modelString,
	}

	var progressMessages []string // we will store the status messages from the server in this slice
	// Create a function to handle the response from the server. This is a callback function that will be called for each response from the server.
	fn := func(response api.ProgressResponse) error { // this callback function fires for each response returned from the server
		progressMessages = append(progressMessages, response.Status) // collect each status so the whole progress history is available
		return nil
	}

	// Create the model
	if err := client.Create(ctx, req, fn); err != nil { // send the request to the server and if there is an error, log it
		if errors.Is(err, context.Canceled) {
			log.Fatal("context was canceled")
		}
		log.Fatal(err)
	}

	// List all models on the server
	tags, err := client.List(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// Iterate through the tags and check if the model was created
	modelCreatedFlag := false
	for _, tag := range tags.Models {
		if strings.Contains(tag.Model, modelName) {
			modelCreatedFlag = true
		}
	}

	// delete the model
	deleteRequest := &api.DeleteRequest{
		Model: modelName,
	}

	if err := client.Delete(ctx, deleteRequest); err != nil { // send the request to the server and if there is an error, log it
		if errors.Is(err, context.Canceled) {
			log.Fatal("context was canceled")
		}
		log.Fatal(err)
	}

	// check if the model was deleted
	tags, err = client.List(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// iterate through the tags and check if the model was deleted
	for _, tag := range tags.Models {
		if strings.Contains(tag.Model, modelName) {
			log.Fatal("model was not deleted")
		}
	}

	if modelCreatedFlag {
		fmt.Println("model created")
	} else {
		log.Fatal("model not created")
	}
	// Output:
	// model created
}

func ExampleClient_Show() {
	// Create a new context
	ctx := context.Background()

	// Create a new client
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// List all the models on the server as tags
	tags, err := client.List(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// if we get an empty list it means that there are no models on the server
	if len(tags.Models) == 0 {
		log.Fatal("no tags returned")
	}

	// Get the first model in the list
	model := tags.Models[0].Model

	request := api.ShowRequest{
		Model: model,
	}

	// Show the model
	show, err := client.Show(ctx, &request)
	if err != nil {
		log.Fatal(err)
	}

	// if we get an empty show, something is wrong
	if show == nil {
		log.Fatal("show is empty")
	} else {
		fmt.Println("show is not empty")
	}
	// Output:
	// show is not empty
}

func loadTestModel(ctx context.Context, client *api.Client, modelName string) error {
	ggufFilePath := testGgufFilePath
	modelFilePath := testModelFilePath

	// get the model file bytes from the testModel file
	modelBytes, err := os.ReadFile(modelFilePath)
	if err != nil {
		return err
	}

	modelString := string(modelBytes)

	// Create a new create request
	req := &api.CreateRequest{
		Path:      ggufFilePath,
		Model:     modelName,
		Modelfile: modelString,
	}

	var progressMessages []string // we will store the status messages from the server in this slice
	// Create a function to handle the response from the server. This is a callback function that will be called for each response from the server.
	fn := func(response api.ProgressResponse) error { // this callback function fires for each response returned from the server
		progressMessages = append(progressMessages, response.Status) // collect each status so the whole progress history is available
		return nil
	}

	// Create the model
	if err := client.Create(ctx, req, fn); err != nil { // send the request to the server and if there is an error, return it
		if errors.Is(err, context.Canceled) {
			log.Fatal("context was canceled")
		}
		return err
	}

	return nil
}

func deleteTestModel(ctx context.Context, client *api.Client, modelName string) error {
	deleteRequest := &api.DeleteRequest{
		Model: modelName,
	}

	if err := client.Delete(ctx, deleteRequest); err != nil { // send the request to the server and if there is an error, return it
		if errors.Is(err, context.Canceled) {
			log.Fatal("context was canceled")
		}
		return err
	}

	return nil
}
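As a side note, the examples above are verified by Go's testable-example mechanism: `go test` runs each `ExampleXxx` function and compares its standard output against the trailing `// Output:` comment. A minimal, standalone illustration of that mechanism (the package name `demo` and the `ExampleHello` function are hypothetical, not part of this change):

```go
package demo

import "fmt"

// ExampleHello illustrates the mechanism the test file above relies on:
// `go test` runs each ExampleXxx function and fails it if the printed output
// does not exactly match the trailing "// Output:" comment.
func ExampleHello() {
	fmt.Println("hello")
	// Output:
	// hello
}
```

Running something like `go test ./api -run Example` should execute only these example functions, assuming `OLLAMA_HOST` points at an address the embedded test server can bind.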