diff --git a/llama/llama.cpp b/llama/llama.cpp
index 87d0148b..34970e54 100644
--- a/llama/llama.cpp
+++ b/llama/llama.cpp
@@ -2699,7 +2699,7 @@ struct llama_hparams {
         GGML_ABORT("fatal error");
     }
 
-    bool cross_attention_layer(uint32_t il) const {
+    bool cross_attention_layers(uint32_t il) const {
         return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
     }
 };
@@ -2731,6 +2731,9 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    bool cross_attn = false;
 
     enum llama_pooling_type pooling_type;
 
@@ -3542,10 +3545,6 @@ struct llama_context {
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 
-    // TODO (jmorganca): this should most likely be passed in as part of a batch
-    // and not set on the context for all batches.
-    float * cross_attn_state = nullptr;
-    bool cross_attn_state_first_pass = true;
     struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
 };
 
@@ -3782,7 +3781,7 @@ static bool llama_kv_cache_init(
 
     for (int i = 0; i < (int) n_layer; i++) {
         // for cross attention layers
-        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
             struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
             ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
             ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
@@ -7389,7 +7388,7 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i];
 
-                        if (hparams.cross_attention_layer(i)) {
+                        if (hparams.cross_attention_layers(i)) {
                             layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128});
                             layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024});
                             layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd});
@@ -9346,7 +9345,7 @@ static struct ggml_tensor * llm_build_inp_embd(
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -9368,11 +9367,10 @@ static struct ggml_tensor * llm_build_inp_cross_attn_state(
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
-    struct ggml_tensor * inpCAS;
-    lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
-    cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
-    ggml_set_input(lctx.inp_cross_attn_state);
-    inpCAS = lctx.inp_cross_attn_state;
+    struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
+    cb(inpCAS, "inp_cross_attn_state", -1);
+    ggml_set_input(inpCAS);
+    lctx.inp_cross_attn_state = inpCAS;
 
     return inpCAS;
 }
@@ -10979,8 +10977,8 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            if (hparams.cross_attention_layer(il)) {
-                if (!lctx.cross_attn_state) {
+            if (hparams.cross_attention_layers(il)) {
+                if (!batch.embd && !cparams.cross_attn) {
                     continue;
                 }
 
@@ -10991,42 +10989,28 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                cb(Qcur, "Qcur", il);
-
-                // TODO: is this required?
-                Qcur = ggml_cont(ctx0, Qcur);
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
                 cb(Qcur, "Qcur", il);
 
                 Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur;
-                if (lctx.cross_attn_state_first_pass) {
+                struct ggml_tensor * Kcur, * Vcur;
+                if (batch.embd) {
                     Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
                     cb(Kcur, "Kcur", il);
 
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
                     cb(Kcur, "Kcur", il);
 
-                    Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
-                    cb(Kcur, "Kcur", il);
-
-                    // TODO: is this required?
-                    Kcur = ggml_cont(ctx0, Kcur);
+                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
                     cb(Kcur, "Kcur", il);
 
                     Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
                     cb(Kcur, "Kcur", il);
 
                     ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
-                } else {
-                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
-                    cb(Kcur, "Kcur (view)", il);
-                }
 
-                struct ggml_tensor * Vcur;
-                if (lctx.cross_attn_state_first_pass) {
                     Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
                     cb(Vcur, "Vcur", il);
 
@@ -11038,6 +11022,9 @@ struct llm_build_context {
 
                     ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
                 } else {
+                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
+                    cb(Kcur, "Kcur (view)", il);
+
                     Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
                     cb(Vcur, "Vcur (view)", il);
                 }
@@ -11045,11 +11032,8 @@ struct llm_build_context {
                 struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
                 cb(kq, "kq", il);
 
-                kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
-                cb(kq, "kq_scaled", il);
-
                 // TODO: apply causal masks
-                struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
+                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
                 cb(kq_soft_max, "kq_soft_max", il);
 
                 Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
@@ -11139,8 +11123,8 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                    model.layers[il].wo, model.layers[il].bo,
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 
 
                 if (il == n_layer - 1) {
@@ -17197,10 +17181,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (batch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
+        if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
+            ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
+            // zero out inp_embd since it's not used
+            float * inp_embd_data = (float *)lctx.inp_embd->data;
+            for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
+                inp_embd_data[i] = 0.0f;
+            }
+        } else {
+            const int64_t n_embd   = hparams.n_embd;
+            const int64_t n_tokens = batch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+        }
     }
 
     if (batch.pos && lctx.inp_pos) {
@@ -17209,14 +17202,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    // TODO (jmorganca): this might copy a lot of data on every request of a
-    // single generation even though it doesn't change, so we should
-    // find a way to not set this more than one time per image
-    if (lctx.inp_cross_attn_state &&
-        lctx.inp_cross_attn_state->buffer) {
-        ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
-    }
-
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
@@ -17789,7 +17774,7 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd,
+    lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
         /* simple_split */ !kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
@@ -17899,10 +17884,6 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        // TODO: replace with something better to find out if its
-        // our first actual pass
-        lctx.cross_attn_state_first_pass = false;
-
         llama_graph_compute(lctx, gf, n_threads, threadpool);
 
         // update the kv ring buffer
@@ -18086,7 +18067,7 @@ static int llama_encode_internal(
 
     const int64_t n_embd = hparams.n_embd;
 
-    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
 
@@ -20194,11 +20175,6 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
-void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
-    ctx->cross_attn_state_first_pass = true;
-    ctx->cross_attn_state = cross_attn_state;
-}
-
 void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
@@ -21686,6 +21662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
 
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
+    ctx->cparams.cross_attn = cross_attention;
+}
+
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens,
@@ -21695,6 +21675,7 @@ struct llama_batch llama_batch_get_one(
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -21710,6 +21691,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_tokens       =*/ 0,
         /*tokens         =*/ nullptr,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -21721,6 +21703,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch.n_embd = embd;
     } else {
         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
diff --git a/llama/llama.go b/llama/llama.go
index 7663e446..2fb19ae7 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -111,6 +111,28 @@ func PrintSystemInfo() string {
 	return C.GoString(C.llama_print_system_info()) + compiler
 }
 
+func GetModelArch(modelPath string) (string, error) {
+	mp := C.CString(modelPath)
+	defer C.free(unsafe.Pointer(mp))
+
+	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
+	if gguf_ctx == nil {
+		return "", errors.New("unable to load model file")
+	}
+	defer C.gguf_free(gguf_ctx)
+
+	key := C.CString("general.architecture")
+	defer C.free(unsafe.Pointer(key))
+	arch_index := C.gguf_find_key(gguf_ctx, key)
+	if int(arch_index) < 0 {
+		return "", errors.New("unknown model architecture")
+	}
+
+	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
+
+	return C.GoString(arch), nil
+}
+
 type ContextParams struct {
 	c C.struct_llama_context_params
 }
@@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error {
 	return nil
 }
 
-// llava
+// vision processing
 type ClipContext struct {
-	c        *C.struct_clip_ctx
-	m        *C.struct_mllama_ctx
-	IsMllama bool
-	embedPin runtime.Pinner
-	pinned   bool
+	c *C.struct_clip_ctx
 }
 
-func getVisionArch(mp *C.char) (string, error) {
-	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
-	if gguf_ctx == nil {
-		return "", errors.New("unable to load vision projector")
-	}
-	defer C.gguf_free(gguf_ctx)
-
-	arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
-	if int(arch_index) < 0 {
-		return "", errors.New("unknown vision model architecture")
-	}
-
-	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
-
-	return C.GoString(arch), nil
-}
-
-func NewClipContext(modelPath string) (*ClipContext, error) {
+func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
+	c := C.clip_model_load(mp, 1)
 
-	arch, err := getVisionArch(mp)
-	if err != nil {
-		return nil, err
+	projEmbedSize := int(C.clip_n_mmproj_embd(c))
+	modelEmbedSize := llamaContext.Model().NEmbd()
+	if projEmbedSize != modelEmbedSize {
+		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
 	}
 
-	var cc ClipContext
-	if arch == "clip" {
-		cc.c = C.clip_model_load(mp, 1)
-	} else if arch == "mllama" {
-		cc.m = C.mllama_model_load(mp, 1)
-		cc.IsMllama = true
-	} else {
-		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
-	}
-
-	// XXX: check embedding size?
-	return &cc, nil
+	return &ClipContext{c: c}, nil
 }
 
 func (c *ClipContext) Free() {
-	if c.c != nil {
-		C.clip_free(c.c)
-	}
-	if c.m != nil {
-		C.mllama_free(c.m)
-	}
+	C.clip_free(c.c)
 }
 
-func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
-	c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
+func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
+	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
 
-	numTokens := int(c.n_image_pos)
+	numTokens := int(l.n_image_pos)
 	numEmbed := llamaContext.Model().NEmbd()
 
-	s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens)
+	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
 
 	embed := make([][]float32, numTokens)
 	rows := make([]float32, len(s))
@@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
 	}
 
-	C.llava_image_embed_free(c)
+	C.llava_image_embed_free(l)
 
 	return embed
 }
 
-func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
+type MllamaContext struct {
+	c *C.struct_mllama_ctx
+}
+
+func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
+	mp := C.CString(modelPath)
+	defer C.free(unsafe.Pointer(mp))
+	c := C.mllama_model_load(mp, 1)
+
+	projEmbedSize := int(C.mllama_n_embd(c))
+	modelEmbedSize := llamaContext.Model().NEmbd()
+	if projEmbedSize != modelEmbedSize {
+		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
+	}
+
+	return &MllamaContext{c: c}, nil
+}
+
+func (m *MllamaContext) Free() {
+	C.mllama_free(m.c)
+}
+
+func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
 	img := C.mllama_image_init()
 	defer C.mllama_image_free(img)
 
 	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
 
-	numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
-	numEmbed := llamaContext.Model().NEmbd()
+	rows := make([]float32, m.EmbedSize(llamaContext))
+	C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
 
-	rows := make([]float32, numEmbed*numTokens)
-	C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
-
-	embed := make([][]float32, numTokens)
-	for i := range embed {
-		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
-	}
+	embed := make([][]float32, 1)
+	embed[0] = rows
 
 	return embed
 }
 
-// This really needs to be set on a batch instead
-func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
-	if embed != nil {
-		if clipContext.pinned {
-			panic("Cross attention state already pinned")
-		}
+func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
+	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
+	numEmbed := llamaContext.Model().NEmbd()
 
-		embedData := &embed[0][0]
-		clipContext.embedPin.Pin(embedData)
-		clipContext.pinned = true
+	return numTokens * numEmbed
+}
 
-		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
-	} else {
-		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
-
-		if clipContext.pinned {
-			clipContext.embedPin.Unpin()
-			clipContext.pinned = false
-		}
-	}
+func (c *Context) SetCrossAttention(state bool) {
+	C.llama_set_cross_attention(c.c, C.bool(state))
 }
 
 // sampling
diff --git a/llama/llama.h b/llama/llama.h
index 5f04fc86..dea03f76 100644
--- a/llama/llama.h
+++ b/llama/llama.h
@@ -266,6 +266,7 @@ extern "C" {
 
         llama_token  *  token;
         float        *  embd;
+        int32_t         n_embd;
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
@@ -451,7 +452,7 @@ extern "C" {
 
     // TODO (jmorganca): this should most likely be passed in as part of a batch
     // and not set on the context for all batches.
-    LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
+    LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
 
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
diff --git a/llama/llava.cpp b/llama/llava.cpp
index 9839de93..e759900e 100644
--- a/llama/llava.cpp
+++ b/llama/llava.cpp
@@ -435,7 +435,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
         if (llama_decode(ctx_llama, batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
diff --git a/llama/patches/0010-add-mllama-support.patch b/llama/patches/0010-add-mllama-support.patch
index c6dd72a7..de8e919c 100644
--- a/llama/patches/0010-add-mllama-support.patch
+++ b/llama/patches/0010-add-mllama-support.patch
@@ -12,27 +12,49 @@ kv cache once per run
 
 remaining is to implement the cross attention mask
 ---
- include/llama.h |   4 +
- src/llama.cpp   | 456 ++++++++++++++++++++++++++++++++++++++++++++++--
- 2 files changed, 447 insertions(+), 13 deletions(-)
+ examples/llava/llava.cpp |   2 +-
+ include/llama.h          |   5 +
+ src/llama.cpp            | 447 +++++++++++++++++++++++++++++++++++++--
+ 3 files changed, 436 insertions(+), 18 deletions(-)
 
+diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
+index 8558c6bd..37b2f2e2 100644
+--- a/examples/llava/llava.cpp
++++ b/examples/llava/llava.cpp
+@@ -409,7 +409,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
+         if (n_eval > n_batch) {
+             n_eval = n_batch;
+         }
+-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
++        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), n_embd, nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+         if (llama_decode(ctx_llama, batch)) {
+             LOG_ERR("%s : failed to eval\n", __func__);
+             return false;
 diff --git a/include/llama.h b/include/llama.h
-index 7cae1bbe..122e3cf1 100644
+index 7cae1bbe..aca09310 100644
 --- a/include/llama.h
 +++ b/include/llama.h
-@@ -423,6 +423,10 @@ extern "C" {
+@@ -240,6 +240,7 @@ extern "C" {
+ 
+         llama_token  *  token;
+         float        *  embd;
++        int32_t         n_embd;
+         llama_pos    *  pos;
+         int32_t      *  n_seq_id;
+         llama_seq_id ** seq_id;
+@@ -423,6 +424,10 @@ extern "C" {
                       struct llama_model * model,
              struct llama_context_params   params);
  
 +    // TODO (jmorganca): this should most likely be passed in as part of a batch
 +    // and not set on the context for all batches.
-+    LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
++    LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
 +
      // Frees all allocated memory
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama.cpp b/src/llama.cpp
-index 83b80b59..b189a19a 100644
+index 83b80b59..35748488 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
 @@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
@@ -160,13 +182,23 @@ index 83b80b59..b189a19a 100644
          GGML_ABORT("fatal error");
      }
 +
-+    bool cross_attention_layer(uint32_t il) const {
++    bool cross_attention_layers(uint32_t il) const {
 +        return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
 +    }
  };
  
  static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-@@ -2806,6 +2859,16 @@ struct llama_layer {
+@@ -2652,6 +2705,9 @@ struct llama_cparams {
+     bool offload_kqv;
+     bool flash_attn;
+     bool no_perf;
++    // TODO (jmorganca): this should most likely be passed in as part of a batch
++    // and not set on the context for all batches.
++    bool cross_attn = false;
+ 
+     enum llama_pooling_type pooling_type;
+ 
+@@ -2806,6 +2862,16 @@ struct llama_layer {
      struct ggml_tensor * ffn_down_scale;
  
      struct ggml_tensor * bskcn_tv;
@@ -183,25 +215,21 @@ index 83b80b59..b189a19a 100644
  };
  
  // very similar to llama_batch,
-@@ -3452,6 +3515,12 @@ struct llama_context {
+@@ -3452,6 +3518,8 @@ struct llama_context {
      struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
      struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
      struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 +
-+    // TODO (jmorganca): this should most likely be passed in as part of a batch
-+    // and not set on the context for all batches.
-+    float * cross_attn_state = nullptr;
-+    bool cross_attn_state_first_pass = true;
 +    struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
  };
  
  struct llama_lora_weight {
-@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init(
+@@ -3686,6 +3754,18 @@ static bool llama_kv_cache_init(
      cache.v_l.reserve(n_layer);
  
      for (int i = 0; i < (int) n_layer; i++) {
 +        // for cross attention layers
-+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
++        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
 +            struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
 +            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
 +            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
@@ -215,7 +243,7 @@ index 83b80b59..b189a19a 100644
          const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
          const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
  
-@@ -5460,12 +5541,14 @@ static void llm_load_hparams(
+@@ -5460,12 +5540,14 @@ static void llm_load_hparams(
      }
  
      // zero-out the per-layer hparams
@@ -235,7 +263,7 @@ index 83b80b59..b189a19a 100644
  
      // n_head_kv is optional, default to n_head
      hparams.n_head_kv_arr = hparams.n_head_arr;
-@@ -5514,7 +5597,7 @@ static void llm_load_hparams(
+@@ -5514,7 +5596,7 @@ static void llm_load_hparams(
  
          ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  
@@ -244,7 +272,7 @@ index 83b80b59..b189a19a 100644
              if (hparams.n_rot != hparams.n_embd_head_k) {
                  throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
              }
-@@ -5554,6 +5637,16 @@ static void llm_load_hparams(
+@@ -5554,6 +5636,16 @@ static void llm_load_hparams(
                      }
                  }
              } break;
@@ -261,7 +289,7 @@ index 83b80b59..b189a19a 100644
          case LLM_ARCH_MINICPM:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-@@ -7249,6 +7342,55 @@ static bool llm_load_tensors(
+@@ -7249,6 +7341,55 @@ static bool llm_load_tensors(
                          layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                      }
                  } break;
@@ -286,7 +314,7 @@ index 83b80b59..b189a19a 100644
 +
 +                        auto & layer = model.layers[i];
 +
-+                        if (hparams.cross_attention_layer(i)) {
++                        if (hparams.cross_attention_layers(i)) {
 +                            layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128});
 +                            layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024});
 +                            layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd});
@@ -317,7 +345,7 @@ index 83b80b59..b189a19a 100644
              case LLM_ARCH_GROK:
                  {
                      if (n_expert == 0) {
-@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+@@ -9093,7 +9234,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
  
          if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
              model.hparams.n_vocab != model.vocab.id_to_token.size()) {
@@ -326,16 +354,7 @@ index 83b80b59..b189a19a 100644
          }
  
          if (params.vocab_only) {
-@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd(
- 
-         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
-     } else {
--       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
-+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
-         inpL = lctx.inp_embd;
-         ggml_set_input(lctx.inp_embd);
-     }
-@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
+@@ -9193,6 +9334,21 @@ static struct ggml_tensor * llm_build_inp_embd(
      return inpL;
  }
  
@@ -346,11 +365,10 @@ index 83b80b59..b189a19a 100644
 +         const llm_build_cb & cb) {
 +    const int64_t n_embd = hparams.n_embd;
 +
-+    struct ggml_tensor * inpCAS;
-+    lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
-+    cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
-+    ggml_set_input(lctx.inp_cross_attn_state);
-+    inpCAS = lctx.inp_cross_attn_state;
++    struct ggml_tensor * inpCAS = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
++    cb(inpCAS, "inp_cross_attn_state", -1);
++    ggml_set_input(inpCAS);
++    lctx.inp_cross_attn_state = inpCAS;
 +
 +    return inpCAS;
 +}
@@ -358,7 +376,7 @@ index 83b80b59..b189a19a 100644
  static void llm_build_kv_store(
          struct ggml_context * ctx,
          const llama_hparams & hparams,
-@@ -10167,6 +10325,7 @@ struct llm_build_context {
+@@ -10167,6 +10323,7 @@ struct llm_build_context {
          lctx.inp_pos_bucket    = nullptr;
          lctx.inp_embd_enc      = nullptr;
          lctx.inp_KQ_mask_cross = nullptr;
@@ -366,7 +384,7 @@ index 83b80b59..b189a19a 100644
      }
  
      void free() {
-@@ -10754,6 +10913,253 @@ struct llm_build_context {
+@@ -10754,6 +10911,239 @@ struct llm_build_context {
                  LLM_NORM_RMS, cb, -1);
          cb(cur, "result_norm", -1);
  
@@ -410,8 +428,8 @@ index 83b80b59..b189a19a 100644
 +                    LLM_NORM_RMS, cb, il);
 +            cb(cur, "attn_norm", il);
 +
-+            if (hparams.cross_attention_layer(il)) {
-+                if (!lctx.cross_attn_state) {
++            if (hparams.cross_attention_layers(il)) {
++                if (!batch.embd && !cparams.cross_attn) {
 +                    continue;
 +                }
 +
@@ -422,42 +440,28 @@ index 83b80b59..b189a19a 100644
 +                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 +                cb(Qcur, "Qcur", il);
 +
-+                Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-+                cb(Qcur, "Qcur", il);
-+
-+                // TODO: is this required?
-+                Qcur = ggml_cont(ctx0, Qcur);
++                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
 +                cb(Qcur, "Qcur", il);
 +
 +                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
 +                cb(Qcur, "Qcur", il);
 +
-+                struct ggml_tensor * Kcur;
-+                if (lctx.cross_attn_state_first_pass) {
++                struct ggml_tensor * Kcur, * Vcur;
++                if (batch.embd) {
 +                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
 +                    cb(Kcur, "Kcur", il);
 +
 +                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
 +                    cb(Kcur, "Kcur", il);
 +
-+                    Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
-+                    cb(Kcur, "Kcur", il);
-+
-+                    // TODO: is this required?
-+                    Kcur = ggml_cont(ctx0, Kcur);
++                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 +                    cb(Kcur, "Kcur", il);
 +
 +                    Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
 +                    cb(Kcur, "Kcur", il);
 +
 +                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
-+                } else {
-+                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
-+                    cb(Kcur, "Kcur (view)", il);
-+                }
 +
-+                struct ggml_tensor * Vcur;
-+                if (lctx.cross_attn_state_first_pass) {
 +                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
 +                    cb(Vcur, "Vcur", il);
 +
@@ -469,6 +473,9 @@ index 83b80b59..b189a19a 100644
 +
 +                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
 +                } else {
++                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
++                    cb(Kcur, "Kcur (view)", il);
++
 +                    Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
 +                    cb(Vcur, "Vcur (view)", il);
 +                }
@@ -476,11 +483,8 @@ index 83b80b59..b189a19a 100644
 +                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
 +                cb(kq, "kq", il);
 +
-+                kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
-+                cb(kq, "kq_scaled", il);
-+
 +                // TODO: apply causal masks
-+                struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
++                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
 +                cb(kq_soft_max, "kq_soft_max", il);
 +
 +                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
@@ -570,8 +574,8 @@ index 83b80b59..b189a19a 100644
 +                cb(Kcur, "Kcur", il);
 +
 +                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-+                        model.layers[il].wo, model.layers[il].bo,
-+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
++                    model.layers[il].wo, model.layers[il].bo,
++                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
 +
 +
 +                if (il == n_layer - 1) {
@@ -620,7 +624,7 @@ index 83b80b59..b189a19a 100644
          // lm_head
          cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
          cb(cur, "result_output", -1);
-@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -16501,6 +16891,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_llama();
              } break;
@@ -631,33 +635,48 @@ index 83b80b59..b189a19a 100644
          case LLM_ARCH_BAICHUAN:
              {
                  result = llm.build_baichuan();
-@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
-         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+@@ -16761,10 +17155,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
      }
  
-+    // TODO (jmorganca): this might copy a lot of data on every request of a
-+    // single generation even though it doesn't change, so we should
-+    // find a way to not set this more than one time per image
-+    if (lctx.inp_cross_attn_state &&
-+        lctx.inp_cross_attn_state->buffer) {
-+        ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
-+    }
-+
-     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-         const int64_t n_tokens = batch.n_tokens;
-@@ -17455,6 +17873,10 @@ static int llama_decode_internal(
+     if (batch.embd) {
+-        const int64_t n_embd   = hparams.n_embd;
+-        const int64_t n_tokens = batch.n_tokens;
++        if (lctx.inp_cross_attn_state && lctx.inp_cross_attn_state->buffer) {
++            ggml_backend_tensor_set(lctx.inp_cross_attn_state, batch.embd, 0, ggml_nbytes(lctx.inp_cross_attn_state));
++            // zero out inp_embd since it's not used
++            float * inp_embd_data = (float *)lctx.inp_embd->data;
++            for (int i = 0; i < ggml_nelements(lctx.inp_embd); ++i) {
++                inp_embd_data[i] = 0.0f;
++            }
++        } else {
++            const int64_t n_embd   = hparams.n_embd;
++            const int64_t n_tokens = batch.n_tokens;
  
-         llama_set_inputs(lctx, ubatch);
+-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
++            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
++        }
+     }
  
-+        // TODO: replace with something better to find out if its
-+        // our first actual pass
-+        lctx.cross_attn_state_first_pass = false;
-+
-         llama_graph_compute(lctx, gf, n_threads, threadpool);
+     if (batch.pos && lctx.inp_pos) {
+@@ -17345,7 +17748,7 @@ static int llama_decode_internal(
+         n_outputs = 1;
+     }
  
-         // update the kv ring buffer
-@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+-    lctx.sbatch.from_batch(batch_all, n_embd,
++    lctx.sbatch.from_batch(batch_all, batch_all.n_embd,
+         /* simple_split */ !kv_self.recurrent,
+         /* logits_all   */ n_outputs == n_tokens_all);
+ 
+@@ -17638,7 +18041,7 @@ static int llama_encode_internal(
+ 
+     const int64_t n_embd = hparams.n_embd;
+ 
+-    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
++    lctx.sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
+ 
+     const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
+ 
+@@ -18648,7 +19051,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
          if (llama_model_has_encoder(&model)) {
              n_attn_layer *= 3;
          }
@@ -668,19 +687,7 @@ index 83b80b59..b189a19a 100644
      }
  
      size_t total_size_org = 0;
-@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model(
-     return ctx;
- }
- 
-+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
-+    ctx->cross_attn_state_first_pass = true;
-+    ctx->cross_attn_state = cross_attn_state;
-+}
-+
- void llama_free(struct llama_context * ctx) {
-     delete ctx;
- }
-@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+@@ -19814,6 +20219,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
  
          // use what we call a normal RoPE, operating on pairs of consecutive head values
          case LLM_ARCH_LLAMA:
@@ -688,3 +695,38 @@ index 83b80b59..b189a19a 100644
          case LLM_ARCH_BAICHUAN:
          case LLM_ARCH_STARCODER:
          case LLM_ARCH_PLAMO:
+@@ -21230,6 +21636,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+     ctx->cparams.causal_attn = causal_attn;
+ }
+ 
++void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
++    ctx->cparams.cross_attn = cross_attention;
++}
++
+ struct llama_batch llama_batch_get_one(
+              llama_token * tokens,
+                  int32_t   n_tokens,
+@@ -21239,6 +21649,7 @@ struct llama_batch llama_batch_get_one(
+         /*n_tokens       =*/ n_tokens,
+         /*tokens         =*/ tokens,
+         /*embd           =*/ nullptr,
++        /*n_embd         =*/ 0,
+         /*pos            =*/ nullptr,
+         /*n_seq_id       =*/ nullptr,
+         /*seq_id         =*/ nullptr,
+@@ -21254,6 +21665,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
+         /*n_tokens       =*/ 0,
+         /*tokens         =*/ nullptr,
+         /*embd           =*/ nullptr,
++        /*n_embd         =*/ 0,
+         /*pos            =*/ nullptr,
+         /*n_seq_id       =*/ nullptr,
+         /*seq_id         =*/ nullptr,
+@@ -21265,6 +21677,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
+ 
+     if (embd) {
+         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
++        batch.n_embd = embd;
+     } else {
+         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
+     }
diff --git a/llama/runner/cache.go b/llama/runner/cache.go
index ef8f6cfb..75c1d874 100644
--- a/llama/runner/cache.go
+++ b/llama/runner/cache.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"errors"
-	"hash/maphash"
 	"log/slog"
 	"reflect"
 	"time"
@@ -20,10 +19,6 @@ type InputCache struct {
 	// optimize cache eviction for multiple users
 	multiUserCache bool
 
-	// cache of images to embeddings
-	images    []imageCache
-	imageHash maphash.Hash
-
 	lc *llama.Context
 }
 
@@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
 		numCtx:         kvSize / numSlots,
 		slots:          slots,
 		multiUserCache: multiUserCache,
-		images:         make([]imageCache, numSlots),
 		lc:             lc,
 	}
 }
@@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar
 	}
 	slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
 }
-
-// Locking: Lookup and store operations on imageCache require a lock
-// to be held that serializes these with each other. Hash does not
-// require a lock nor they need to be serialized with InputCacheSlot.
-
-type imageCache struct {
-	key      uint64
-	val      [][]float32
-	lastUsed time.Time
-}
-
-func (c *InputCache) HashImage(image []byte) uint64 {
-	c.imageHash.Reset()
-	_, _ = c.imageHash.Write(image)
-	return c.imageHash.Sum64()
-}
-
-var ErrImageNotFound = errors.New("image not found in cache")
-
-func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
-	for i := range c.images {
-		if c.images[i].key == hash {
-			slog.Debug("loading image embeddings from cache", "entry", i)
-			c.images[i].lastUsed = time.Now()
-			return c.images[i].val, nil
-		}
-	}
-
-	return nil, ErrImageNotFound
-}
-
-func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
-	best := time.Now()
-	var bestImage int
-
-	for i := range c.images {
-		if c.images[i].key == hash {
-			bestImage = i
-			break
-		}
-
-		if c.images[i].lastUsed.Compare(best) < 0 {
-			best = c.images[i].lastUsed
-			bestImage = i
-		}
-	}
-
-	slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
-	c.images[bestImage].key = hash
-	c.images[bestImage].val = embed
-	c.images[bestImage].lastUsed = time.Now()
-}
diff --git a/llama/runner/cache_test.go b/llama/runner/cache_test.go
index cc13b5f2..0e38c67d 100644
--- a/llama/runner/cache_test.go
+++ b/llama/runner/cache_test.go
@@ -1,7 +1,6 @@
 package main
 
 import (
-	"reflect"
 	"testing"
 	"time"
 )
@@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
 		})
 	}
 }
-
-func TestImageCache(t *testing.T) {
-	cache := NewInputCache(nil, 2048, 4, false)
-
-	valA := [][]float32{{0.1, 0.2}, {0.3}}
-	valB := [][]float32{{0.4}, {0.5}, {0.6}}
-	valC := [][]float32{{0.7}}
-	valD := [][]float32{{0.8}}
-	valE := [][]float32{{0.9}}
-
-	// Empty cache
-	result, err := cache.FindImage(0x5adb61d31933a946)
-	if err != ErrImageNotFound {
-		t.Errorf("found result in empty cache: result %v, err %v", result, err)
-	}
-
-	// Insert A
-	cache.AddImage(0x5adb61d31933a946, valA)
-
-	result, err = cache.FindImage(0x5adb61d31933a946)
-	if !reflect.DeepEqual(result, valA) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-
-	// Insert B
-	cache.AddImage(0x011551369a34a901, valB)
-
-	result, err = cache.FindImage(0x5adb61d31933a946)
-	if !reflect.DeepEqual(result, valA) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0x011551369a34a901)
-	if !reflect.DeepEqual(result, valB) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-
-	// Replace B with C
-	cache.AddImage(0x011551369a34a901, valC)
-
-	result, err = cache.FindImage(0x5adb61d31933a946)
-	if !reflect.DeepEqual(result, valA) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0x011551369a34a901)
-	if !reflect.DeepEqual(result, valC) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-
-	// Evict A
-	cache.AddImage(0x756b218a517e7353, valB)
-	cache.AddImage(0x75e5e8d35d7e3967, valD)
-	cache.AddImage(0xd96f7f268ca0646e, valE)
-
-	result, err = cache.FindImage(0x5adb61d31933a946)
-	if reflect.DeepEqual(result, valA) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0x756b218a517e7353)
-	if !reflect.DeepEqual(result, valB) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0x011551369a34a901)
-	if !reflect.DeepEqual(result, valC) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0x75e5e8d35d7e3967)
-	if !reflect.DeepEqual(result, valD) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-	result, err = cache.FindImage(0xd96f7f268ca0646e)
-	if !reflect.DeepEqual(result, valE) {
-		t.Errorf("failed to find expected value: result %v, err %v", result, err)
-	}
-}
diff --git a/llama/runner/image.go b/llama/runner/image.go
new file mode 100644
index 00000000..d50645e8
--- /dev/null
+++ b/llama/runner/image.go
@@ -0,0 +1,145 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"hash/maphash"
+	"log/slog"
+	"sync"
+	"time"
+
+	"github.com/ollama/ollama/llama"
+)
+
+const imageCacheSize = 4
+
+type ImageContext struct {
+	// mu is required to be held when generating embeddings or accessing the cache
+	mu sync.Mutex
+
+	clip   *llama.ClipContext
+	mllama *llama.MllamaContext
+
+	// cache of images to embeddings
+	images    []imageCache
+	imageHash maphash.Hash
+}
+
+func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
+	arch, err := llama.GetModelArch(modelPath)
+	if err != nil {
+		return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
+	}
+
+	var c ImageContext
+	if arch == "clip" {
+		c.clip, err = llama.NewClipContext(llamaContext, modelPath)
+	} else if arch == "mllama" {
+		c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
+	} else {
+		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	c.images = make([]imageCache, imageCacheSize)
+
+	return &c, nil
+}
+
+func (c *ImageContext) Free(modelPath string) {
+	if c == nil {
+		return
+	}
+
+	if c.clip != nil {
+		c.clip.Free()
+	}
+	if c.mllama != nil {
+		c.mllama.Free()
+	}
+}
+
+func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
+	if c == nil {
+		return nil
+	}
+
+	hash := c.hashImage(data)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	embed, err := c.findImage(hash)
+	if err != nil {
+		if c.mllama != nil {
+			embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
+		} else if c.clip != nil {
+			embed = c.clip.NewEmbed(llamaContext, data)
+		} else {
+			return nil
+		}
+
+		c.addImage(hash, embed)
+	}
+
+	return embed
+}
+
+func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
+	if c != nil && c.mllama != nil {
+		return c.mllama.EmbedSize(llamaContext)
+	} else {
+		return llamaContext.Model().NEmbd()
+	}
+}
+
+type imageCache struct {
+	key      uint64
+	val      [][]float32
+	lastUsed time.Time
+}
+
+func (c *ImageContext) hashImage(image []byte) uint64 {
+	c.imageHash.Reset()
+	_, _ = c.imageHash.Write(image)
+	return c.imageHash.Sum64()
+}
+
+var errImageNotFound = errors.New("image not found in cache")
+
+func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
+	for i := range c.images {
+		if c.images[i].key == hash {
+			slog.Debug("loading image embeddings from cache", "entry", i)
+			c.images[i].lastUsed = time.Now()
+			return c.images[i].val, nil
+		}
+	}
+
+	return nil, errImageNotFound
+}
+
+func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
+	best := time.Now()
+	var bestImage int
+
+	for i := range c.images {
+		if c.images[i].key == hash {
+			bestImage = i
+			break
+		}
+
+		if c.images[i].lastUsed.Compare(best) < 0 {
+			best = c.images[i].lastUsed
+			bestImage = i
+		}
+	}
+
+	slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
+	c.images[bestImage].key = hash
+	c.images[bestImage].val = embed
+	c.images[bestImage].lastUsed = time.Now()
+}
diff --git a/llama/runner/image_test.go b/llama/runner/image_test.go
new file mode 100644
index 00000000..4f1d265a
--- /dev/null
+++ b/llama/runner/image_test.go
@@ -0,0 +1,80 @@
+package main
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestImageCache(t *testing.T) {
+	cache := ImageContext{images: make([]imageCache, 4)}
+
+	valA := [][]float32{{0.1, 0.2}, {0.3}}
+	valB := [][]float32{{0.4}, {0.5}, {0.6}}
+	valC := [][]float32{{0.7}}
+	valD := [][]float32{{0.8}}
+	valE := [][]float32{{0.9}}
+
+	// Empty cache
+	result, err := cache.findImage(0x5adb61d31933a946)
+	if err != errImageNotFound {
+		t.Errorf("found result in empty cache: result %v, err %v", result, err)
+	}
+
+	// Insert A
+	cache.addImage(0x5adb61d31933a946, valA)
+
+	result, err = cache.findImage(0x5adb61d31933a946)
+	if !reflect.DeepEqual(result, valA) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+
+	// Insert B
+	cache.addImage(0x011551369a34a901, valB)
+
+	result, err = cache.findImage(0x5adb61d31933a946)
+	if !reflect.DeepEqual(result, valA) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0x011551369a34a901)
+	if !reflect.DeepEqual(result, valB) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+
+	// Replace B with C
+	cache.addImage(0x011551369a34a901, valC)
+
+	result, err = cache.findImage(0x5adb61d31933a946)
+	if !reflect.DeepEqual(result, valA) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0x011551369a34a901)
+	if !reflect.DeepEqual(result, valC) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+
+	// Evict A
+	cache.addImage(0x756b218a517e7353, valB)
+	cache.addImage(0x75e5e8d35d7e3967, valD)
+	cache.addImage(0xd96f7f268ca0646e, valE)
+
+	result, err = cache.findImage(0x5adb61d31933a946)
+	if reflect.DeepEqual(result, valA) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0x756b218a517e7353)
+	if !reflect.DeepEqual(result, valB) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0x011551369a34a901)
+	if !reflect.DeepEqual(result, valC) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0x75e5e8d35d7e3967)
+	if !reflect.DeepEqual(result, valD) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+	result, err = cache.findImage(0xd96f7f268ca0646e)
+	if !reflect.DeepEqual(result, valE) {
+		t.Errorf("failed to find expected value: result %v, err %v", result, err)
+	}
+}
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index bbd1c0fb..a137f879 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 				return nil, fmt.Errorf("invalid image index: %d", n)
 			}
 
-			hash := s.cache.HashImage(images[imageIndex].Data)
-
-			// Vision models cannot be accessed concurrently
-			s.clip.mu.Lock()
-			embed, err := s.cache.FindImage(hash)
-			if err != nil {
-				embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
-				s.cache.AddImage(hash, embed)
-			}
-			s.clip.mu.Unlock()
-
+			embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
 			for _, e := range embed {
 				inputs = append(inputs, input{embed: e})
 			}
 		}
 	}
 
-	if s.clip.cc != nil {
-		var embed [][]float32
-
-		if s.clip.cc.IsMllama && len(images) >= 1 {
-			hash := s.cache.HashImage(images[0].Data)
-
-			s.clip.mu.Lock()
-			var err error
-			embed, err = s.cache.FindImage(hash)
-			if err != nil {
-				embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
-				s.cache.AddImage(hash, embed)
-			}
-			s.clip.mu.Unlock()
-		}
-		s.mu.Lock()
-		llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
-		s.mu.Unlock()
-	}
-
 	return inputs, nil
 }
 
-type clip struct {
-	cc *llama.ClipContext
-	mu sync.Mutex
-}
-
 type Server struct {
 	model *llama.Model
 	lc    *llama.Context
 
 	// required for image embeddings
-	clip clip
+	image *ImageContext
 
 	batchSize int
 
@@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
+	s.lc.SetCrossAttention(false)
 	flushPending(seq)
 	seq.doneReason = reason
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
-	if s.clip.cc != nil {
-		llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
-	}
 	s.seqs[seqIndex] = nil
 }
 
@@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
 	tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
 	defer tokenBatch.Free()
 
-	embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs))
+	embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
 	defer embedBatch.Free()
 
 	for {
@@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
+			for _, input := range seq.inputs {
+				if input.embed != nil {
+					s.lc.SetCrossAttention(true)
+					break
+				}
+			}
+
 			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				return
 			}
+
 			s.seqs[i] = seq
 			s.cond.Signal()
 			break
@@ -815,7 +786,7 @@ func (s *Server) loadModel(
 
 	if ppath != "" {
 		var err error
-		s.clip.cc, err = llama.NewClipContext(ppath)
+		s.image, err = NewImageContext(s.lc, ppath)
 		if err != nil {
 			panic(err)
 		}
diff --git a/server/prompt.go b/server/prompt.go
index 1d6f5cdb..f91b94d8 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	currMsgIdx := n
 
-	if isMllama {
-		lastMsgIdx := len(msgs) - 1
-		for i := lastMsgIdx; i >= currMsgIdx; i-- {
-			if len(msgs[i].Images) > 0 {
-				data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
+	for cnt, msg := range msgs[currMsgIdx:] {
+		prefix := ""
+		imgPrompt := ""
+		prompt := msg.Content
+
+		for _, i := range msg.Images {
+			var imgData llm.ImageData
+
+			if isMllama {
+				data, aspectRatioID, err := imageproc.Preprocess(i)
 				if err != nil {
 					return "", nil, err
 				}
@@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 					return "", nil, err
 				}
 
-				imgData := llm.ImageData{
+				imgData = llm.ImageData{
+					ID:            len(images),
 					Data:          buf.Bytes(),
 					AspectRatioID: aspectRatioID,
 				}
-
-				msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
-				images = append(images, imgData)
-				break
-			}
-		}
-	} else {
-		for cnt, msg := range msgs[currMsgIdx:] {
-			prefix := ""
-			prompt := msg.Content
-			for _, i := range msg.Images {
-				imgData := llm.ImageData{
+				imgPrompt = "<|image|>"
+			} else {
+				imgData = llm.ImageData{
 					ID:   len(images),
 					Data: i,
 				}
-
-				imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
-				if !strings.Contains(prompt, "[img]") {
-					prefix += imgTag
-				} else {
-					prompt = strings.Replace(prompt, "[img]", imgTag, 1)
-				}
-
-				images = append(images, imgData)
+				imgPrompt = " "
 			}
-			msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
+
+			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
+			if !strings.Contains(prompt, "[img]") {
+				prefix += imgTag
+			} else {
+				prompt = strings.Replace(prompt, "[img]", imgTag, 1)
+			}
+
+			images = append(images, imgData)
 		}
+		msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
 	}
 
 	// truncate any messages that do not fit into the context window
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 123a2081..6d04db53 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
 			},
 			expect: expect{
-				prompt:        "<|image|>How many hotdogs are in this image? ",
+				prompt:        "[img-0]<|image|>How many hotdogs are in this image? ",
 				images:        [][]byte{imgBuf},
 				aspectRatioID: 1,
 			},
@@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
 			},
 			expect: expect{
-				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
+				prompt:        "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
 				images:        [][]byte{imgBuf},
 				aspectRatioID: 1,
 			},
@@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
 			},
 			expect: expect{
-				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
-				images:        [][]byte{imgBuf2},
+				prompt:        "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
+				images:        [][]byte{imgBuf, imgBuf2},
 				aspectRatioID: 1,
 			},
 		},
@@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) {
 				{Role: "user", Content: "Which ones have mustard?"},
 			},
 			expect: expect{
-				prompt:        "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
+				prompt:        "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
 				images:        [][]byte{imgBuf},
 				aspectRatioID: 1,
 			},
diff --git a/server/routes.go b/server/routes.go
index eb2268c7..d5c4172a 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				return
 			}
 
-			images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID}
+			images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID}
 		} else {
 			images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
 		}
@@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 
 			for _, i := range images {
+				imgPrompt := ""
 				if isMllama {
-					msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
-				} else {
-					msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+					imgPrompt = "<|image|>"
 				}
+				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
 			}
 
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})