Compare commits
3 Commits
main
...
jyan/palig
Author | SHA1 | Date | |
---|---|---|---|
|
e6802df906 | ||
|
c631633bce | ||
|
7de230f005 |
65
llm/ext_server/server.cpp
vendored
65
llm/ext_server/server.cpp
vendored
@ -1271,8 +1271,49 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for multiple images processing
|
bool prepare_pali(server_slot &slot, int n_batch)
|
||||||
bool ingest_images(server_slot &slot, int n_batch)
|
{
|
||||||
|
int n_past = 0;
|
||||||
|
int image_idx = 0;
|
||||||
|
slot_image &img = slot.images[image_idx];
|
||||||
|
|
||||||
|
// rescale image embeddings
|
||||||
|
float *data = img.image_embedding;
|
||||||
|
for (int i = 0; i < 2048 * 256; i++)
|
||||||
|
{
|
||||||
|
data[i] = data[i] / sqrt(2048);
|
||||||
|
}
|
||||||
|
set_image_embeds(ctx, data);
|
||||||
|
|
||||||
|
// generate user_prompt -> this should contain image tokens prepended and a new line appended:
|
||||||
|
// batch.n_tokens += (int)slot.images.size() * llama_n_embd(model);
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
std::string prompt = "How much ketchup is in this image?";
|
||||||
|
std::vector<llama_token> text = ::llama_tokenize(ctx, prompt, false, true);
|
||||||
|
|
||||||
|
for (int i = 0; i < (int)slot.images.size() * 256; i++)
|
||||||
|
{
|
||||||
|
tokens.push_back(257152);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(2);
|
||||||
|
|
||||||
|
for (int i = 0; i < text.size(); i++)
|
||||||
|
{
|
||||||
|
tokens.push_back(text[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.push_back(108);
|
||||||
|
|
||||||
|
for (int i = 0; i < (int)tokens.size(); ++i)
|
||||||
|
{
|
||||||
|
llama_batch_add(batch, tokens[i], system_tokens.size() + slot.n_past, {slot.id}, true);
|
||||||
|
slot.n_past += 1;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool process_llava(server_slot &slot, int n_batch)
|
||||||
{
|
{
|
||||||
int image_idx = 0;
|
int image_idx = 0;
|
||||||
|
|
||||||
@ -1349,6 +1390,20 @@ struct llama_server_context
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// for multiple images processing based on model architecture
|
||||||
|
bool ingest_images(server_slot &slot, int n_batch)
|
||||||
|
{
|
||||||
|
switch (llama_get_architecture(model))
|
||||||
|
{
|
||||||
|
case 0:
|
||||||
|
return process_llava(slot, n_batch);
|
||||||
|
case 25:
|
||||||
|
return prepare_pali(slot, n_batch);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void request_cancel(int task_id)
|
void request_cancel(int task_id)
|
||||||
{
|
{
|
||||||
task_server task;
|
task_server task;
|
||||||
@ -1838,6 +1893,12 @@ struct llama_server_context
|
|||||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
||||||
slot_npast++;
|
slot_npast++;
|
||||||
}
|
}
|
||||||
|
LOG_DEBUG("hi gpt params processing images", {
|
||||||
|
{"gpt_params.model", params.model.c_str()},
|
||||||
|
{"model alias", params.model_alias.c_str()},
|
||||||
|
});
|
||||||
|
printf("gpt_params model is %s\n", params.model.c_str());
|
||||||
|
printf("gpt_params model is %s\n", params.model.c_str());
|
||||||
|
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
{
|
{
|
||||||
|
106
llm/patches/12-paligemma.diff
Normal file
106
llm/patches/12-paligemma.diff
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||||
|
index 7cda5f10..671806fd 100644
|
||||||
|
--- a/examples/llava/clip.cpp
|
||||||
|
+++ b/examples/llava/clip.cpp
|
||||||
|
@@ -708,11 +708,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
|
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||||
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
|
-
|
||||||
|
- embeddings = ggml_gelu(ctx0, embeddings);
|
||||||
|
- embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||||
|
- embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||||
|
-
|
||||||
|
+ if (model.mm_2_w)
|
||||||
|
+ {
|
||||||
|
+ embeddings = ggml_gelu(ctx0, embeddings);
|
||||||
|
+ embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||||
|
+ embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||||
|
+ }
|
||||||
|
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||||
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
|
@@ -2076,6 +2077,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||||
|
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||||
|
}
|
||||||
|
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||||
|
+ if (ctx->vision_model.mm_2_b == nullptr)
|
||||||
|
+ {
|
||||||
|
+ return ctx->vision_model.mm_0_b->ne[0];
|
||||||
|
+ }
|
||||||
|
return ctx->vision_model.mm_2_b->ne[0];
|
||||||
|
}
|
||||||
|
if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||||
|
diff --git a/include/llama.h b/include/llama.h
|
||||||
|
index f23355a6..e48da401 100644
|
||||||
|
--- a/include/llama.h
|
||||||
|
+++ b/include/llama.h
|
||||||
|
@@ -444,6 +444,12 @@ extern "C" {
|
||||||
|
// Frees all allocated memory
|
||||||
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||||
|
|
||||||
|
+ // Sets image embeddings
|
||||||
|
+ LLAMA_API void set_image_embeds(struct llama_context *ctx, float *data);
|
||||||
|
+
|
||||||
|
+ // Gets architecture
|
||||||
|
+ LLAMA_API int llama_get_architecture(struct llama_model *model);
|
||||||
|
+
|
||||||
|
LLAMA_API int64_t llama_time_us(void);
|
||||||
|
|
||||||
|
LLAMA_API size_t llama_max_devices(void);
|
||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index a7b1c9eb..ee067919 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -2710,6 +2710,8 @@ struct llama_context {
|
||||||
|
|
||||||
|
bool logits_all = false;
|
||||||
|
|
||||||
|
+ float *image_embeds = nullptr;
|
||||||
|
+
|
||||||
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||||
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
|
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||||
|
@@ -11599,6 +11601,15 @@ struct llm_build_context {
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
+ if (lctx.image_embeds)
|
||||||
|
+ {
|
||||||
|
+ struct ggml_tensor *image_embeds = ggml_dup_tensor(ctx0, inpL);
|
||||||
|
+ image_embeds->data = lctx.image_embeds;
|
||||||
|
+ image_embeds->ne[1] = 256;
|
||||||
|
+ inpL = ggml_set_2d_inplace(ctx0, inpL, image_embeds, inpL->nb[1], 0);
|
||||||
|
+ lctx.image_embeds = NULL;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
||||||
|
cb(inpL, "inp_scaled", -1);
|
||||||
|
|
||||||
|
@@ -14589,7 +14600,8 @@ static int llama_decode_internal(
|
||||||
|
}
|
||||||
|
|
||||||
|
// non-causal masks do not use the KV cache
|
||||||
|
- if (hparams.causal_attn) {
|
||||||
|
+ if (hparams.causal_attn || lctx.image_embeds)
|
||||||
|
+ {
|
||||||
|
llama_kv_cache_update(&lctx);
|
||||||
|
|
||||||
|
// if we have enough unused cells before the current head ->
|
||||||
|
@@ -16448,6 +16460,16 @@ void llama_free_model(struct llama_model * model) {
|
||||||
|
delete model;
|
||||||
|
}
|
||||||
|
|
||||||
|
+void set_image_embeds(llama_context *ctx, float *data)
|
||||||
|
+{
|
||||||
|
+ ctx->image_embeds = data;
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+int llama_get_architecture(llama_model *model)
|
||||||
|
+{
|
||||||
|
+ return model->arch;
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
struct llama_context * llama_new_context_with_model(
|
||||||
|
struct llama_model * model,
|
||||||
|
struct llama_context_params params) {
|
Loading…
x
Reference in New Issue
Block a user